Unverified Commit 07bb2a5f authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #952 from xvhuan/feat/billing-ledger-decouple-usage-log-20260312

feat: 解耦计费正确性与 usage_logs 批量写压
parents 417861a4 64b3f3ce
......@@ -35,6 +35,7 @@ type DashboardAggregationRepository interface {
UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
CleanupUsageLogs(ctx context.Context, cutoff time.Time) error
CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error
EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error
}
......@@ -296,6 +297,7 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
hourlyCutoff := now.AddDate(0, 0, -s.cfg.Retention.HourlyDays)
dailyCutoff := now.AddDate(0, 0, -s.cfg.Retention.DailyDays)
usageCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageLogsDays)
dedupCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageBillingDedupDays)
aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff)
if aggErr != nil {
......@@ -305,7 +307,11 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context,
if usageErr != nil {
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr)
}
if aggErr == nil && usageErr == nil {
dedupErr := s.repo.CleanupUsageBillingDedup(ctx, dedupCutoff)
if dedupErr != nil {
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_billing_dedup 保留清理失败: %v", dedupErr)
}
if aggErr == nil && usageErr == nil && dedupErr == nil {
s.lastRetentionCleanup.Store(now)
}
}
......
......@@ -12,12 +12,18 @@ import (
type dashboardAggregationRepoTestStub struct {
aggregateCalls int
recomputeCalls int
cleanupUsageCalls int
cleanupDedupCalls int
ensurePartitionCalls int
lastStart time.Time
lastEnd time.Time
watermark time.Time
aggregateErr error
cleanupAggregatesErr error
cleanupUsageErr error
cleanupDedupErr error
ensurePartitionErr error
}
func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, start, end time.Time) error {
......@@ -28,6 +34,7 @@ func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, s
}
func (s *dashboardAggregationRepoTestStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
s.recomputeCalls++
return s.AggregateRange(ctx, start, end)
}
......@@ -44,11 +51,18 @@ func (s *dashboardAggregationRepoTestStub) CleanupAggregates(ctx context.Context
}
func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
s.cleanupUsageCalls++
return s.cleanupUsageErr
}
func (s *dashboardAggregationRepoTestStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
s.cleanupDedupCalls++
return s.cleanupDedupErr
}
func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
return nil
s.ensurePartitionCalls++
return s.ensurePartitionErr
}
func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) {
......@@ -90,6 +104,50 @@ func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *te
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
require.Nil(t, svc.lastRetentionCleanup.Load())
require.Equal(t, 1, repo.cleanupUsageCalls)
require.Equal(t, 1, repo.cleanupDedupCalls)
}
func TestDashboardAggregationService_CleanupDedupFailure_DoesNotRecord(t *testing.T) {
repo := &dashboardAggregationRepoTestStub{cleanupDedupErr: errors.New("dedup cleanup failed")}
svc := &DashboardAggregationService{
repo: repo,
cfg: config.DashboardAggregationConfig{
Retention: config.DashboardAggregationRetentionConfig{
UsageLogsDays: 1,
HourlyDays: 1,
DailyDays: 1,
},
},
}
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
require.Nil(t, svc.lastRetentionCleanup.Load())
require.Equal(t, 1, repo.cleanupDedupCalls)
}
func TestDashboardAggregationService_PartitionFailure_DoesNotAggregate(t *testing.T) {
repo := &dashboardAggregationRepoTestStub{ensurePartitionErr: errors.New("partition failed")}
svc := &DashboardAggregationService{
repo: repo,
cfg: config.DashboardAggregationConfig{
Enabled: true,
IntervalSeconds: 60,
LookbackSeconds: 120,
Retention: config.DashboardAggregationRetentionConfig{
UsageLogsDays: 1,
UsageBillingDedupDays: 2,
HourlyDays: 1,
DailyDays: 1,
},
},
}
svc.runScheduledAggregation()
require.Equal(t, 1, repo.ensurePartitionCalls)
require.Equal(t, 1, repo.aggregateCalls)
}
func TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) {
......
......@@ -124,6 +124,10 @@ func (s *dashboardAggregationRepoStub) CleanupUsageLogs(ctx context.Context, cut
return nil
}
func (s *dashboardAggregationRepoStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
return nil
}
func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
return nil
}
......
......@@ -136,16 +136,18 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd
},
}
svc := &GatewayService{
cfg: &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
cfg := &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
deferredService: &DeferredService{},
billingCacheService: nil,
}
svc := &GatewayService{
cfg: cfg,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
deferredService: &DeferredService{},
billingCacheService: nil,
}
account := &Account{
......@@ -221,14 +223,16 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
},
}
svc := &GatewayService{
cfg: &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
cfg := &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
}
svc := &GatewayService{
cfg: cfg,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
}
account := &Account{
......@@ -727,6 +731,39 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAf
require.Equal(t, 5, result.usage.OutputTokens)
}
func TestGatewayService_AnthropicAPIKeyPassthrough_MissingTerminalEventReturnsError(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
svc := &GatewayService{
cfg: &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
},
rateLimitService: &RateLimitService{},
}
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
`data: {"type":"message_start","message":{"usage":{"input_tokens":11}}}`,
"",
`data: {"type":"message_delta","usage":{"output_tokens":5}}`,
"",
}, "\n"))),
}
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "claude-3-7-sonnet-20250219")
require.Error(t, err)
require.Contains(t, err.Error(), "missing terminal event")
require.NotNil(t, result)
}
func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuccess(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
......@@ -1074,7 +1111,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingTimeoutAfterClientDi
_ = pr.Close()
<-done
require.NoError(t, err)
require.Error(t, err)
require.Contains(t, err.Error(), "stream usage incomplete after timeout")
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.Equal(t, 9, result.usage.InputTokens)
......@@ -1103,7 +1141,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingContextCanceled(t *t
}
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 3}, time.Now(), "claude-3-7-sonnet-20250219")
require.NoError(t, err)
require.Error(t, err)
require.Contains(t, err.Error(), "stream usage incomplete")
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
}
......@@ -1133,7 +1172,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingUpstreamReadErrorAft
}
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 4}, time.Now(), "claude-3-7-sonnet-20250219")
require.NoError(t, err)
require.Error(t, err)
require.Contains(t, err.Error(), "stream usage incomplete after disconnect")
require.NotNil(t, result)
require.True(t, result.clientDisconnect)
require.Equal(t, 8, result.usage.InputTokens)
......
//go:build unit
package service
import (
"context"
"errors"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService {
cfg := &config.Config{}
cfg.Default.RateMultiplier = 1.1
return NewGatewayService(
nil,
nil,
usageRepo,
nil,
userRepo,
subRepo,
nil,
nil,
cfg,
nil,
nil,
NewBillingService(cfg, nil),
nil,
&BillingCacheService{},
nil,
nil,
&DeferredService{},
nil,
nil,
nil,
nil,
nil,
)
}
func newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo UsageLogRepository, billingRepo UsageBillingRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService {
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
svc.usageBillingRepo = billingRepo
return svc
}
type openAIRecordUsageBestEffortLogRepoStub struct {
UsageLogRepository
bestEffortErr error
createErr error
bestEffortCalls int
createCalls int
lastLog *UsageLog
lastCtxErr error
}
func (s *openAIRecordUsageBestEffortLogRepoStub) CreateBestEffort(ctx context.Context, log *UsageLog) error {
s.bestEffortCalls++
s.lastLog = log
s.lastCtxErr = ctx.Err()
return s.bestEffortErr
}
func (s *openAIRecordUsageBestEffortLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) {
s.createCalls++
s.lastLog = log
s.lastCtxErr = ctx.Err()
return false, s.createErr
}
func TestGatewayServiceRecordUsage_BillingUsesDetachedContext(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
reqCtx, cancel := context.WithCancel(context.Background())
cancel()
err := svc.RecordUsage(reqCtx, &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_detached_ctx",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 501,
Quota: 100,
},
User: &User{ID: 601},
Account: &Account{ID: 701},
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 1, userRepo.deductCalls)
require.NoError(t, userRepo.lastCtxErr)
require.Equal(t, 1, quotaSvc.quotaCalls)
require.NoError(t, quotaSvc.lastQuotaCtxErr)
}
func TestGatewayServiceRecordUsage_BillingFingerprintIncludesRequestPayloadHash(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
payloadHash := HashUsageRequestPayload([]byte(`{"messages":[{"role":"user","content":"hello"}]}`))
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_payload_hash",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 501, Quota: 100},
User: &User{ID: 601},
Account: &Account{ID: 701},
RequestPayloadHash: payloadHash,
})
require.NoError(t, err)
require.NotNil(t, billingRepo.lastCmd)
require.Equal(t, payloadHash, billingRepo.lastCmd.RequestPayloadHash)
}
func TestGatewayServiceRecordUsage_BillingFingerprintFallsBackToContextRequestID(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
ctx := context.WithValue(context.Background(), ctxkey.RequestID, "req-local-123")
err := svc.RecordUsage(ctx, &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_payload_fallback",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 501, Quota: 100},
User: &User{ID: 601},
Account: &Account{ID: 701},
})
require.NoError(t, err)
require.NotNil(t, billingRepo.lastCmd)
require.Equal(t, "local:req-local-123", billingRepo.lastCmd.RequestPayloadHash)
}
func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_not_persisted",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 503,
Quota: 100,
},
User: &User{ID: 603},
Account: &Account{ID: 703},
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 1, userRepo.deductCalls)
require.Equal(t, 1, quotaSvc.quotaCalls)
}
func TestGatewayServiceRecordUsageWithLongContext_BillingUsesDetachedContext(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
reqCtx, cancel := context.WithCancel(context.Background())
cancel()
err := svc.RecordUsageWithLongContext(reqCtx, &RecordUsageLongContextInput{
Result: &ForwardResult{
RequestID: "gateway_long_context_detached_ctx",
Usage: ClaudeUsage{
InputTokens: 12,
OutputTokens: 8,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 502,
Quota: 100,
},
User: &User{ID: 602},
Account: &Account{ID: 702},
LongContextThreshold: 200000,
LongContextMultiplier: 2,
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 1, userRepo.deductCalls)
require.NoError(t, userRepo.lastCtxErr)
require.Equal(t, 1, quotaSvc.quotaCalls)
require.NoError(t, quotaSvc.lastQuotaCtxErr)
}
func TestGatewayServiceRecordUsage_UsesFallbackRequestIDForUsageLog(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
ctx := context.WithValue(context.Background(), ctxkey.RequestID, "gateway-local-fallback")
err := svc.RecordUsage(ctx, &RecordUsageInput{
Result: &ForwardResult{
RequestID: "",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 504},
User: &User{ID: 604},
Account: &Account{ID: 704},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "local:gateway-local-fallback", usageRepo.lastLog.RequestID)
}
func TestGatewayServiceRecordUsage_PrefersClientRequestIDOverUpstreamRequestID(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
ctx := context.WithValue(context.Background(), ctxkey.ClientRequestID, "client-stable-123")
ctx = context.WithValue(ctx, ctxkey.RequestID, "req-local-ignored")
err := svc.RecordUsage(ctx, &RecordUsageInput{
Result: &ForwardResult{
RequestID: "upstream-volatile-456",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 506},
User: &User{ID: 606},
Account: &Account{ID: 706},
})
require.NoError(t, err)
require.NotNil(t, billingRepo.lastCmd)
require.Equal(t, "client:client-stable-123", billingRepo.lastCmd.RequestID)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "client:client-stable-123", usageRepo.lastLog.RequestID)
}
func TestGatewayServiceRecordUsage_GeneratesRequestIDWhenAllSourcesMissing(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 507},
User: &User{ID: 607},
Account: &Account{ID: 707},
})
require.NoError(t, err)
require.NotNil(t, billingRepo.lastCmd)
require.True(t, strings.HasPrefix(billingRepo.lastCmd.RequestID, "generated:"))
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, billingRepo.lastCmd.RequestID, usageRepo.lastLog.RequestID)
}
func TestGatewayServiceRecordUsage_DroppedUsageLogDoesNotSyncFallback(t *testing.T) {
usageRepo := &openAIRecordUsageBestEffortLogRepoStub{
bestEffortErr: MarkUsageLogCreateDropped(errors.New("usage log best-effort queue full")),
}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_drop_usage_log",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 508},
User: &User{ID: 608},
Account: &Account{ID: 708},
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.bestEffortCalls)
require.Equal(t, 0, usageRepo.createCalls)
}
func TestGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{err: context.DeadlineExceeded}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo)
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_billing_fail",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 505},
User: &User{ID: 605},
Account: &Account{ID: 705},
})
require.Error(t, err)
require.Equal(t, 1, billingRepo.calls)
require.Equal(t, 0, usageRepo.calls)
}
This diff is collapsed.
......@@ -181,7 +181,8 @@ func TestHandleStreamingResponse_EmptyStream(t *testing.T) {
result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
_ = pr.Close()
require.NoError(t, err)
require.Error(t, err)
require.Contains(t, err.Error(), "missing terminal event")
require.NotNil(t, result)
}
......
......@@ -3,39 +3,68 @@ package service
import (
"context"
"errors"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
type openAIRecordUsageLogRepoStub struct {
UsageLogRepository
inserted bool
err error
calls int
lastLog *UsageLog
inserted bool
err error
calls int
lastLog *UsageLog
lastCtxErr error
}
func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) {
s.calls++
s.lastLog = log
s.lastCtxErr = ctx.Err()
return s.inserted, s.err
}
type openAIRecordUsageBillingRepoStub struct {
UsageBillingRepository
result *UsageBillingApplyResult
err error
calls int
lastCmd *UsageBillingCommand
lastCtxErr error
}
func (s *openAIRecordUsageBillingRepoStub) Apply(ctx context.Context, cmd *UsageBillingCommand) (*UsageBillingApplyResult, error) {
s.calls++
s.lastCmd = cmd
s.lastCtxErr = ctx.Err()
if s.err != nil {
return nil, s.err
}
if s.result != nil {
return s.result, nil
}
return &UsageBillingApplyResult{Applied: true}, nil
}
type openAIRecordUsageUserRepoStub struct {
UserRepository
deductCalls int
deductErr error
lastAmount float64
lastCtxErr error
}
func (s *openAIRecordUsageUserRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error {
s.deductCalls++
s.lastAmount = amount
s.lastCtxErr = ctx.Err()
return s.deductErr
}
......@@ -44,29 +73,35 @@ type openAIRecordUsageSubRepoStub struct {
incrementCalls int
incrementErr error
lastCtxErr error
}
func (s *openAIRecordUsageSubRepoStub) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
s.incrementCalls++
s.lastCtxErr = ctx.Err()
return s.incrementErr
}
type openAIRecordUsageAPIKeyQuotaStub struct {
quotaCalls int
rateLimitCalls int
err error
lastAmount float64
quotaCalls int
rateLimitCalls int
err error
lastAmount float64
lastQuotaCtxErr error
lastRateLimitCtxErr error
}
func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error {
s.quotaCalls++
s.lastAmount = cost
s.lastQuotaCtxErr = ctx.Err()
return s.err
}
func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error {
s.rateLimitCalls++
s.lastAmount = cost
s.lastRateLimitCtxErr = ctx.Err()
return s.err
}
......@@ -93,23 +128,38 @@ func i64p(v int64) *int64 {
func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService {
cfg := &config.Config{}
cfg.Default.RateMultiplier = 1.1
svc := NewOpenAIGatewayService(
nil,
usageRepo,
nil,
userRepo,
subRepo,
rateRepo,
nil,
cfg,
nil,
nil,
NewBillingService(cfg, nil),
nil,
&BillingCacheService{},
nil,
&DeferredService{},
nil,
)
svc.userGroupRateResolver = newUserGroupRateResolver(
rateRepo,
nil,
resolveUserGroupRateCacheTTL(cfg),
nil,
"service.openai_gateway.test",
)
return svc
}
return &OpenAIGatewayService{
usageLogRepo: usageRepo,
userRepo: userRepo,
userSubRepo: subRepo,
cfg: cfg,
billingService: NewBillingService(cfg, nil),
billingCacheService: &BillingCacheService{},
deferredService: &DeferredService{},
userGroupRateResolver: newUserGroupRateResolver(
rateRepo,
nil,
resolveUserGroupRateCacheTTL(cfg),
nil,
"service.openai_gateway.test",
),
}
func newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo UsageLogRepository, billingRepo UsageBillingRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService {
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo)
svc.usageBillingRepo = billingRepo
return svc
}
func expectedOpenAICost(t *testing.T, svc *OpenAIGatewayService, model string, usage OpenAIUsage, multiplier float64) *CostBreakdown {
......@@ -252,9 +302,10 @@ func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateWhenResolver
func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: false}}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
......@@ -272,9 +323,311 @@ func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testin
})
require.NoError(t, err)
require.Equal(t, 1, billingRepo.calls)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 0, userRepo.deductCalls)
require.Equal(t, 0, subRepo.incrementCalls)
}
func TestOpenAIGatewayServiceRecordUsage_DuplicateBillingKeySkipsBillingWithRepo(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: false}}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_duplicate_billing_key",
Usage: OpenAIUsage{
InputTokens: 8,
OutputTokens: 4,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 10045,
Quota: 100,
},
User: &User{ID: 20045},
Account: &Account{ID: 30045},
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, billingRepo.calls)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 0, userRepo.deductCalls)
require.Equal(t, 0, subRepo.incrementCalls)
require.Equal(t, 0, quotaSvc.quotaCalls)
}
func TestOpenAIGatewayServiceRecordUsage_BillsWhenUsageLogCreateReturnsError(t *testing.T) {
usage := OpenAIUsage{InputTokens: 8, OutputTokens: 4}
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: errors.New("usage log batch state uncertain")}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_usage_log_error",
Usage: usage,
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{ID: 10041},
User: &User{ID: 20041},
Account: &Account{ID: 30041},
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 1, userRepo.deductCalls)
require.Equal(t, 0, subRepo.incrementCalls)
}
func TestOpenAIGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_not_persisted",
Usage: OpenAIUsage{
InputTokens: 8,
OutputTokens: 4,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 10043,
Quota: 100,
},
User: &User{ID: 20043},
Account: &Account{ID: 30043},
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 1, userRepo.deductCalls)
require.Equal(t, 0, subRepo.incrementCalls)
require.Equal(t, 1, quotaSvc.quotaCalls)
}
func TestOpenAIGatewayServiceRecordUsage_BillingUsesDetachedContext(t *testing.T) {
usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2}
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
reqCtx, cancel := context.WithCancel(context.Background())
cancel()
err := svc.RecordUsage(reqCtx, &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_detached_billing_ctx",
Usage: usage,
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 10042,
Quota: 100,
},
User: &User{ID: 20042},
Account: &Account{ID: 30042},
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, userRepo.deductCalls)
require.NoError(t, userRepo.lastCtxErr)
require.Equal(t, 1, quotaSvc.quotaCalls)
require.NoError(t, quotaSvc.lastQuotaCtxErr)
}
func TestOpenAIGatewayServiceRecordUsage_BillingRepoUsesDetachedContext(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
reqCtx, cancel := context.WithCancel(context.Background())
cancel()
err := svc.RecordUsage(reqCtx, &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_detached_billing_repo_ctx",
Usage: OpenAIUsage{
InputTokens: 8,
OutputTokens: 4,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{ID: 10046},
User: &User{ID: 20046},
Account: &Account{ID: 30046},
})
require.NoError(t, err)
require.Equal(t, 1, billingRepo.calls)
require.NoError(t, billingRepo.lastCtxErr)
require.Equal(t, 1, usageRepo.calls)
require.NoError(t, usageRepo.lastCtxErr)
}
func TestOpenAIGatewayServiceRecordUsage_BillingFingerprintIncludesRequestPayloadHash(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil)
payloadHash := HashUsageRequestPayload([]byte(`{"model":"gpt-5","input":"hello"}`))
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "openai_payload_hash",
Usage: OpenAIUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "gpt-5",
Duration: time.Second,
},
APIKey: &APIKey{ID: 501, Quota: 100},
User: &User{ID: 601},
Account: &Account{ID: 701},
RequestPayloadHash: payloadHash,
})
require.NoError(t, err)
require.NotNil(t, billingRepo.lastCmd)
require.Equal(t, payloadHash, billingRepo.lastCmd.RequestPayloadHash)
}
func TestOpenAIGatewayServiceRecordUsage_UsesFallbackRequestIDForBillingAndUsageLog(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
ctx := context.WithValue(context.Background(), ctxkey.RequestID, "req-local-fallback")
err := svc.RecordUsage(ctx, &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "",
Usage: OpenAIUsage{
InputTokens: 8,
OutputTokens: 4,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{ID: 10047},
User: &User{ID: 20047},
Account: &Account{ID: 30047},
})
require.NoError(t, err)
require.NotNil(t, billingRepo.lastCmd)
require.Equal(t, "local:req-local-fallback", billingRepo.lastCmd.RequestID)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "local:req-local-fallback", usageRepo.lastLog.RequestID)
}
func TestOpenAIGatewayServiceRecordUsage_PrefersClientRequestIDOverUpstreamRequestID(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
ctx := context.WithValue(context.Background(), ctxkey.ClientRequestID, "openai-client-stable-123")
err := svc.RecordUsage(ctx, &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "upstream-openai-volatile-456",
Usage: OpenAIUsage{
InputTokens: 8,
OutputTokens: 4,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{ID: 10049},
User: &User{ID: 20049},
Account: &Account{ID: 30049},
})
require.NoError(t, err)
require.NotNil(t, billingRepo.lastCmd)
require.Equal(t, "client:openai-client-stable-123", billingRepo.lastCmd.RequestID)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "client:openai-client-stable-123", usageRepo.lastLog.RequestID)
}
func TestOpenAIGatewayServiceRecordUsage_GeneratesRequestIDWhenAllSourcesMissing(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "",
Usage: OpenAIUsage{
InputTokens: 8,
OutputTokens: 4,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{ID: 10050},
User: &User{ID: 20050},
Account: &Account{ID: 30050},
})
require.NoError(t, err)
require.NotNil(t, billingRepo.lastCmd)
require.True(t, strings.HasPrefix(billingRepo.lastCmd.RequestID, "generated:"))
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, billingRepo.lastCmd.RequestID, usageRepo.lastLog.RequestID)
}
func TestOpenAIGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{err: errors.New("billing tx failed")}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_billing_fail",
Usage: OpenAIUsage{
InputTokens: 8,
OutputTokens: 4,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{ID: 10048},
User: &User{ID: 20048},
Account: &Account{ID: 30048},
})
require.Error(t, err)
require.Equal(t, 1, billingRepo.calls)
require.Equal(t, 0, usageRepo.calls)
}
func TestOpenAIGatewayServiceRecordUsage_UpdatesAPIKeyQuotaWhenConfigured(t *testing.T) {
......
......@@ -301,6 +301,7 @@ var defaultOpenAICodexSnapshotPersistThrottle = newAccountWriteThrottle(openAICo
type OpenAIGatewayService struct {
accountRepo AccountRepository
usageLogRepo UsageLogRepository
usageBillingRepo UsageBillingRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
cache GatewayCache
......@@ -338,6 +339,7 @@ type OpenAIGatewayService struct {
func NewOpenAIGatewayService(
accountRepo AccountRepository,
usageLogRepo UsageLogRepository,
usageBillingRepo UsageBillingRepository,
userRepo UserRepository,
userSubRepo UserSubscriptionRepository,
userGroupRateRepo UserGroupRateRepository,
......@@ -355,6 +357,7 @@ func NewOpenAIGatewayService(
svc := &OpenAIGatewayService{
accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
usageBillingRepo: usageBillingRepo,
userRepo: userRepo,
userSubRepo: userSubRepo,
cache: cache,
......@@ -2073,7 +2076,9 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
// Build upstream request
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
releaseUpstreamCtx()
if err != nil {
return nil, err
}
......@@ -2265,7 +2270,9 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
return nil, err
}
upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(ctx, c, account, body, token)
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token)
releaseUpstreamCtx()
if err != nil {
return nil, err
}
......@@ -2602,6 +2609,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
var firstTokenMs *int
clientDisconnected := false
sawDone := false
sawTerminalEvent := false
upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id"))
scanner := bufio.NewScanner(resp.Body)
......@@ -2621,6 +2629,9 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
if trimmedData == "[DONE]" {
sawDone = true
}
if openAIStreamEventIsTerminal(trimmedData) {
sawTerminalEvent = true
}
if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
......@@ -2638,19 +2649,14 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
}
}
if err := scanner.Err(); err != nil {
if clientDisconnected {
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, err)
if sawTerminalEvent {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if clientDisconnected {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err)
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
logger.LegacyPrintf("service.openai_gateway",
"[OpenAI passthrough] 流读取被取消,可能发生断流: account=%d request_id=%s err=%v ctx_err=%v",
account.ID,
upstreamRequestID,
err,
ctx.Err(),
)
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err)
}
if errors.Is(err, bufio.ErrTooLong) {
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
......@@ -2664,12 +2670,13 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
)
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
}
if !clientDisconnected && !sawDone && ctx.Err() == nil {
if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil {
logger.FromContext(ctx).With(
zap.String("component", "service.openai_gateway"),
zap.Int64("account_id", account.ID),
zap.String("upstream_request_id", upstreamRequestID),
).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流")
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream usage incomplete: missing terminal event")
}
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
......@@ -3203,6 +3210,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
// 否则下游 SDK(例如 OpenCode)会因为类型校验失败而报错。
errorEventSent := false
clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage
sawTerminalEvent := false
sendErrorEvent := func(reason string) {
if errorEventSent || clientDisconnected {
return
......@@ -3233,22 +3241,27 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage")
}
}
if !sawTerminalEvent {
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
}
return resultWithUsage(), nil
}
handleScanErr := func(scanErr error) (*openaiStreamingResult, error, bool) {
if scanErr == nil {
return nil, nil, false
}
if sawTerminalEvent {
logger.LegacyPrintf("service.openai_gateway", "Upstream scan ended after terminal event: %v", scanErr)
return resultWithUsage(), nil, true
}
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
if errors.Is(scanErr, context.Canceled) || errors.Is(scanErr, context.DeadlineExceeded) {
logger.LegacyPrintf("service.openai_gateway", "Context canceled during streaming, returning collected usage")
return resultWithUsage(), nil, true
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", scanErr), true
}
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
if clientDisconnected {
logger.LegacyPrintf("service.openai_gateway", "Upstream read error after client disconnect: %v, returning collected usage", scanErr)
return resultWithUsage(), nil, true
return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true
}
if errors.Is(scanErr, bufio.ErrTooLong) {
logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, scanErr)
......@@ -3271,6 +3284,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
}
dataBytes := []byte(data)
if openAIStreamEventIsTerminal(data) {
sawTerminalEvent = true
}
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected {
......@@ -3387,8 +3403,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
continue
}
if clientDisconnected {
logger.LegacyPrintf("service.openai_gateway", "Upstream timeout after client disconnect, returning collected usage")
return resultWithUsage(), nil
return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout")
}
logger.LegacyPrintf("service.openai_gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
// 处理流超时,可能标记账户为临时不可调度或错误状态
......@@ -3486,11 +3501,12 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag
if usage == nil || len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) {
return
}
// 选择性解析:仅在数据中包含 completed 事件标识时才进入字段提取。
if len(data) < 80 || !bytes.Contains(data, []byte(`"response.completed"`)) {
// 选择性解析:仅在数据中包含终止事件标识时才进入字段提取。
if len(data) < 72 {
return
}
if gjson.GetBytes(data, "type").String() != "response.completed" {
eventType := gjson.GetBytes(data, "type").String()
if eventType != "response.completed" && eventType != "response.done" {
return
}
......@@ -3843,14 +3859,15 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
// OpenAIRecordUsageInput input for recording usage
type OpenAIRecordUsageInput struct {
Result *OpenAIForwardResult
APIKey *APIKey
User *User
Account *Account
Subscription *UserSubscription
UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址
APIKeyService APIKeyQuotaUpdater
Result *OpenAIForwardResult
APIKey *APIKey
User *User
Account *Account
Subscription *UserSubscription
UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址
RequestPayloadHash string
APIKeyService APIKeyQuotaUpdater
}
// RecordUsage records usage and deducts balance
......@@ -3916,11 +3933,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
// Create usage log
durationMs := int(result.Duration.Milliseconds())
accountRateMultiplier := account.BillingRateMultiplier()
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
usageLog := &UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
RequestID: requestID,
Model: billingModel,
ServiceTier: result.ServiceTier,
ReasoningEffort: result.ReasoningEffort,
......@@ -3961,29 +3979,32 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
usageLog.SubscriptionID = &subscription.ID
}
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway")
logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
shouldBill := inserted || err != nil
if shouldBill {
postUsageBilling(ctx, &postUsageBillingParams{
billingErr := func() error {
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
Cost: cost,
User: user,
APIKey: apiKey,
Account: account,
Subscription: subscription,
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
IsSubscriptionBill: isSubscriptionBilling,
AccountRateMultiplier: accountRateMultiplier,
APIKeyService: input.APIKeyService,
}, s.billingDeps())
} else {
s.deferredService.ScheduleLastUsedUpdate(account.ID)
}, s.billingDeps(), s.usageBillingRepo)
return err
}()
if billingErr != nil {
return billingErr
}
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway")
return nil
}
......
......@@ -392,6 +392,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil,
nil,
nil,
nil,
cfg,
nil,
nil,
......
This diff is collapsed.
This diff is collapsed.
-- 窄表账务幂等键:将“是否已扣费”从 usage_logs 解耦出来
-- 幂等执行:可重复运行
CREATE TABLE IF NOT EXISTS usage_billing_dedup (
id BIGSERIAL PRIMARY KEY,
request_id VARCHAR(255) NOT NULL,
api_key_id BIGINT NOT NULL,
request_fingerprint VARCHAR(64) NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_usage_billing_dedup_request_api_key
ON usage_billing_dedup (request_id, api_key_id);
-- usage_billing_dedup 是按时间追加写入的幂等窄表。
-- 使用 BRIN 支撑按 created_at 的批量保留期清理,尽量降低写放大。
-- 使用 CONCURRENTLY 避免在热表上长时间阻塞写入。
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_usage_billing_dedup_created_at_brin
ON usage_billing_dedup
USING BRIN (created_at);
-- 冷归档旧账务幂等键,缩小热表索引与清理范围,同时不丢失长期去重能力。
CREATE TABLE IF NOT EXISTS usage_billing_dedup_archive (
request_id VARCHAR(255) NOT NULL,
api_key_id BIGINT NOT NULL,
request_fingerprint VARCHAR(64) NOT NULL,
created_at TIMESTAMPTZ NOT NULL,
archived_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
PRIMARY KEY (request_id, api_key_id)
);
This diff is collapsed.
File mode changed from 100755 to 100644
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment