"frontend/git@web.lueluesay.top:chenxi/sub2api.git" did not exist on "d1b684b7823e2c1e838505cf8d9b77b409921879"
Commit 611fd884 authored by ius's avatar ius
Browse files

feat: decouple billing correctness from usage log batching

parent c9debc50
...@@ -124,6 +124,10 @@ func (s *dashboardAggregationRepoStub) CleanupUsageLogs(ctx context.Context, cut ...@@ -124,6 +124,10 @@ func (s *dashboardAggregationRepoStub) CleanupUsageLogs(ctx context.Context, cut
return nil return nil
} }
func (s *dashboardAggregationRepoStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
return nil
}
func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error { func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
return nil return nil
} }
......
...@@ -136,16 +136,18 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd ...@@ -136,16 +136,18 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd
}, },
} }
svc := &GatewayService{ cfg := &config.Config{
cfg: &config.Config{ Gateway: config.GatewayConfig{
Gateway: config.GatewayConfig{ MaxLineSize: defaultMaxLineSize,
MaxLineSize: defaultMaxLineSize,
},
}, },
httpUpstream: upstream, }
rateLimitService: &RateLimitService{}, svc := &GatewayService{
deferredService: &DeferredService{}, cfg: cfg,
billingCacheService: nil, responseHeaderFilter: compileResponseHeaderFilter(cfg),
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
deferredService: &DeferredService{},
billingCacheService: nil,
} }
account := &Account{ account := &Account{
...@@ -221,14 +223,16 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo ...@@ -221,14 +223,16 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
}, },
} }
svc := &GatewayService{ cfg := &config.Config{
cfg: &config.Config{ Gateway: config.GatewayConfig{
Gateway: config.GatewayConfig{ MaxLineSize: defaultMaxLineSize,
MaxLineSize: defaultMaxLineSize,
},
}, },
httpUpstream: upstream, }
rateLimitService: &RateLimitService{}, svc := &GatewayService{
cfg: cfg,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
} }
account := &Account{ account := &Account{
...@@ -727,6 +731,39 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAf ...@@ -727,6 +731,39 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAf
require.Equal(t, 5, result.usage.OutputTokens) require.Equal(t, 5, result.usage.OutputTokens)
} }
func TestGatewayService_AnthropicAPIKeyPassthrough_MissingTerminalEventReturnsError(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
svc := &GatewayService{
cfg: &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
},
rateLimitService: &RateLimitService{},
}
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
Body: io.NopCloser(strings.NewReader(strings.Join([]string{
`data: {"type":"message_start","message":{"usage":{"input_tokens":11}}}`,
"",
`data: {"type":"message_delta","usage":{"output_tokens":5}}`,
"",
}, "\n"))),
}
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "claude-3-7-sonnet-20250219")
require.Error(t, err)
require.Contains(t, err.Error(), "missing terminal event")
require.NotNil(t, result)
}
func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuccess(t *testing.T) { func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuccess(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
...@@ -1074,7 +1111,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingTimeoutAfterClientDi ...@@ -1074,7 +1111,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingTimeoutAfterClientDi
_ = pr.Close() _ = pr.Close()
<-done <-done
require.NoError(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "stream usage incomplete after timeout")
require.NotNil(t, result) require.NotNil(t, result)
require.True(t, result.clientDisconnect) require.True(t, result.clientDisconnect)
require.Equal(t, 9, result.usage.InputTokens) require.Equal(t, 9, result.usage.InputTokens)
...@@ -1103,7 +1141,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingContextCanceled(t *t ...@@ -1103,7 +1141,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingContextCanceled(t *t
} }
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 3}, time.Now(), "claude-3-7-sonnet-20250219") result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 3}, time.Now(), "claude-3-7-sonnet-20250219")
require.NoError(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "stream usage incomplete")
require.NotNil(t, result) require.NotNil(t, result)
require.True(t, result.clientDisconnect) require.True(t, result.clientDisconnect)
} }
...@@ -1133,7 +1172,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingUpstreamReadErrorAft ...@@ -1133,7 +1172,8 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingUpstreamReadErrorAft
} }
result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 4}, time.Now(), "claude-3-7-sonnet-20250219") result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 4}, time.Now(), "claude-3-7-sonnet-20250219")
require.NoError(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "stream usage incomplete after disconnect")
require.NotNil(t, result) require.NotNil(t, result)
require.True(t, result.clientDisconnect) require.True(t, result.clientDisconnect)
require.Equal(t, 8, result.usage.InputTokens) require.Equal(t, 8, result.usage.InputTokens)
......
//go:build unit
package service
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService {
cfg := &config.Config{}
cfg.Default.RateMultiplier = 1.1
return NewGatewayService(
nil,
nil,
usageRepo,
nil,
userRepo,
subRepo,
nil,
nil,
cfg,
nil,
nil,
NewBillingService(cfg, nil),
nil,
&BillingCacheService{},
nil,
nil,
&DeferredService{},
nil,
nil,
nil,
nil,
nil,
)
}
func newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo UsageLogRepository, billingRepo UsageBillingRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *GatewayService {
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
svc.usageBillingRepo = billingRepo
return svc
}
func TestGatewayServiceRecordUsage_BillingUsesDetachedContext(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
reqCtx, cancel := context.WithCancel(context.Background())
cancel()
err := svc.RecordUsage(reqCtx, &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_detached_ctx",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 501,
Quota: 100,
},
User: &User{ID: 601},
Account: &Account{ID: 701},
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 1, userRepo.deductCalls)
require.NoError(t, userRepo.lastCtxErr)
require.Equal(t, 1, quotaSvc.quotaCalls)
require.NoError(t, quotaSvc.lastQuotaCtxErr)
}
func TestGatewayServiceRecordUsage_BillingFingerprintIncludesRequestPayloadHash(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
payloadHash := HashUsageRequestPayload([]byte(`{"messages":[{"role":"user","content":"hello"}]}`))
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_payload_hash",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 501, Quota: 100},
User: &User{ID: 601},
Account: &Account{ID: 701},
RequestPayloadHash: payloadHash,
})
require.NoError(t, err)
require.NotNil(t, billingRepo.lastCmd)
require.Equal(t, payloadHash, billingRepo.lastCmd.RequestPayloadHash)
}
func TestGatewayServiceRecordUsage_BillingFingerprintFallsBackToContextRequestID(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
ctx := context.WithValue(context.Background(), ctxkey.RequestID, "req-local-123")
err := svc.RecordUsage(ctx, &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_payload_fallback",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 501, Quota: 100},
User: &User{ID: 601},
Account: &Account{ID: 701},
})
require.NoError(t, err)
require.NotNil(t, billingRepo.lastCmd)
require.Equal(t, "local:req-local-123", billingRepo.lastCmd.RequestPayloadHash)
}
func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_not_persisted",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 503,
Quota: 100,
},
User: &User{ID: 603},
Account: &Account{ID: 703},
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 1, userRepo.deductCalls)
require.Equal(t, 1, quotaSvc.quotaCalls)
}
func TestGatewayServiceRecordUsageWithLongContext_BillingUsesDetachedContext(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
reqCtx, cancel := context.WithCancel(context.Background())
cancel()
err := svc.RecordUsageWithLongContext(reqCtx, &RecordUsageLongContextInput{
Result: &ForwardResult{
RequestID: "gateway_long_context_detached_ctx",
Usage: ClaudeUsage{
InputTokens: 12,
OutputTokens: 8,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 502,
Quota: 100,
},
User: &User{ID: 602},
Account: &Account{ID: 702},
LongContextThreshold: 200000,
LongContextMultiplier: 2,
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 1, userRepo.deductCalls)
require.NoError(t, userRepo.lastCtxErr)
require.Equal(t, 1, quotaSvc.quotaCalls)
require.NoError(t, quotaSvc.lastQuotaCtxErr)
}
func TestGatewayServiceRecordUsage_UsesFallbackRequestIDForUsageLog(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newGatewayRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
ctx := context.WithValue(context.Background(), ctxkey.RequestID, "gateway-local-fallback")
err := svc.RecordUsage(ctx, &RecordUsageInput{
Result: &ForwardResult{
RequestID: "",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 504},
User: &User{ID: 604},
Account: &Account{ID: 704},
})
require.NoError(t, err)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "local:gateway-local-fallback", usageRepo.lastLog.RequestID)
}
func TestGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{err: context.DeadlineExceeded}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newGatewayRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo)
err := svc.RecordUsage(context.Background(), &RecordUsageInput{
Result: &ForwardResult{
RequestID: "gateway_billing_fail",
Usage: ClaudeUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "claude-sonnet-4",
Duration: time.Second,
},
APIKey: &APIKey{ID: 505},
User: &User{ID: 605},
Account: &Account{ID: 705},
})
require.Error(t, err)
require.Equal(t, 1, billingRepo.calls)
require.Equal(t, 0, usageRepo.calls)
}
...@@ -50,6 +50,7 @@ const ( ...@@ -50,6 +50,7 @@ const (
defaultUserGroupRateCacheTTL = 30 * time.Second defaultUserGroupRateCacheTTL = 30 * time.Second
defaultModelsListCacheTTL = 15 * time.Second defaultModelsListCacheTTL = 15 * time.Second
postUsageBillingTimeout = 15 * time.Second
) )
const ( const (
...@@ -106,6 +107,52 @@ func GatewayModelsListCacheStats() (cacheHit, cacheMiss, store int64) { ...@@ -106,6 +107,52 @@ func GatewayModelsListCacheStats() (cacheHit, cacheMiss, store int64) {
return modelsListCacheHitTotal.Load(), modelsListCacheMissTotal.Load(), modelsListCacheStoreTotal.Load() return modelsListCacheHitTotal.Load(), modelsListCacheMissTotal.Load(), modelsListCacheStoreTotal.Load()
} }
func claudeUsageHasAnyTokens(usage *ClaudeUsage) bool {
return usage != nil && (usage.InputTokens > 0 ||
usage.OutputTokens > 0 ||
usage.CacheCreationInputTokens > 0 ||
usage.CacheReadInputTokens > 0 ||
usage.CacheCreation5mTokens > 0 ||
usage.CacheCreation1hTokens > 0)
}
func openAIUsageHasAnyTokens(usage *OpenAIUsage) bool {
return usage != nil && (usage.InputTokens > 0 ||
usage.OutputTokens > 0 ||
usage.CacheCreationInputTokens > 0 ||
usage.CacheReadInputTokens > 0)
}
func openAIStreamEventIsTerminal(data string) bool {
trimmed := strings.TrimSpace(data)
if trimmed == "" {
return false
}
if trimmed == "[DONE]" {
return true
}
switch gjson.Get(trimmed, "type").String() {
case "response.completed", "response.done", "response.failed":
return true
default:
return false
}
}
func anthropicStreamEventIsTerminal(eventName, data string) bool {
if strings.EqualFold(strings.TrimSpace(eventName), "message_stop") {
return true
}
trimmed := strings.TrimSpace(data)
if trimmed == "" {
return false
}
if trimmed == "[DONE]" {
return true
}
return gjson.Get(trimmed, "type").String() == "message_stop"
}
func cloneStringSlice(src []string) []string { func cloneStringSlice(src []string) []string {
if len(src) == 0 { if len(src) == 0 {
return nil return nil
...@@ -504,6 +551,7 @@ type GatewayService struct { ...@@ -504,6 +551,7 @@ type GatewayService struct {
accountRepo AccountRepository accountRepo AccountRepository
groupRepo GroupRepository groupRepo GroupRepository
usageLogRepo UsageLogRepository usageLogRepo UsageLogRepository
usageBillingRepo UsageBillingRepository
userRepo UserRepository userRepo UserRepository
userSubRepo UserSubscriptionRepository userSubRepo UserSubscriptionRepository
userGroupRateRepo UserGroupRateRepository userGroupRateRepo UserGroupRateRepository
...@@ -537,6 +585,7 @@ func NewGatewayService( ...@@ -537,6 +585,7 @@ func NewGatewayService(
accountRepo AccountRepository, accountRepo AccountRepository,
groupRepo GroupRepository, groupRepo GroupRepository,
usageLogRepo UsageLogRepository, usageLogRepo UsageLogRepository,
usageBillingRepo UsageBillingRepository,
userRepo UserRepository, userRepo UserRepository,
userSubRepo UserSubscriptionRepository, userSubRepo UserSubscriptionRepository,
userGroupRateRepo UserGroupRateRepository, userGroupRateRepo UserGroupRateRepository,
...@@ -563,6 +612,7 @@ func NewGatewayService( ...@@ -563,6 +612,7 @@ func NewGatewayService(
accountRepo: accountRepo, accountRepo: accountRepo,
groupRepo: groupRepo, groupRepo: groupRepo,
usageLogRepo: usageLogRepo, usageLogRepo: usageLogRepo,
usageBillingRepo: usageBillingRepo,
userRepo: userRepo, userRepo: userRepo,
userSubRepo: userSubRepo, userSubRepo: userSubRepo,
userGroupRateRepo: userGroupRateRepo, userGroupRateRepo: userGroupRateRepo,
...@@ -4049,7 +4099,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -4049,7 +4099,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
retryStart := time.Now() retryStart := time.Now()
for attempt := 1; attempt <= maxRetryAttempts; attempt++ { for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取) // 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
releaseUpstreamCtx()
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -4127,7 +4179,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -4127,7 +4179,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// also downgrade tool_use/tool_result blocks to text. // also downgrade tool_use/tool_result blocks to text.
filteredBody := FilterThinkingBlocksForRetry(body) filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) retryCtx, releaseRetryCtx := detachStreamUpstreamContext(ctx, reqStream)
retryReq, buildErr := s.buildUpstreamRequest(retryCtx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
releaseRetryCtx()
if buildErr == nil { if buildErr == nil {
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr == nil { if retryErr == nil {
...@@ -4159,7 +4213,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -4159,7 +4213,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed { if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed {
logger.LegacyPrintf("service.gateway", "Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID) logger.LegacyPrintf("service.gateway", "Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID)
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body) filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) retryCtx2, releaseRetryCtx2 := detachStreamUpstreamContext(ctx, reqStream)
retryReq2, buildErr2 := s.buildUpstreamRequest(retryCtx2, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
releaseRetryCtx2()
if buildErr2 == nil { if buildErr2 == nil {
retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr2 == nil { if retryErr2 == nil {
...@@ -4226,7 +4282,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -4226,7 +4282,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
rectifiedBody, applied := RectifyThinkingBudget(body) rectifiedBody, applied := RectifyThinkingBudget(body)
if applied && time.Since(retryStart) < maxRetryElapsed { if applied && time.Since(retryStart) < maxRetryElapsed {
logger.LegacyPrintf("service.gateway", "Account %d: detected budget_tokens constraint error, retrying with rectified budget (budget_tokens=%d, max_tokens=%d)", account.ID, BudgetRectifyBudgetTokens, BudgetRectifyMaxTokens) logger.LegacyPrintf("service.gateway", "Account %d: detected budget_tokens constraint error, retrying with rectified budget (budget_tokens=%d, max_tokens=%d)", account.ID, BudgetRectifyBudgetTokens, BudgetRectifyMaxTokens)
budgetRetryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, rectifiedBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) budgetRetryCtx, releaseBudgetRetryCtx := detachStreamUpstreamContext(ctx, reqStream)
budgetRetryReq, buildErr := s.buildUpstreamRequest(budgetRetryCtx, c, account, rectifiedBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
releaseBudgetRetryCtx()
if buildErr == nil { if buildErr == nil {
budgetRetryResp, retryErr := s.httpUpstream.DoWithTLS(budgetRetryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) budgetRetryResp, retryErr := s.httpUpstream.DoWithTLS(budgetRetryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr == nil { if retryErr == nil {
...@@ -4498,7 +4556,9 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( ...@@ -4498,7 +4556,9 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
var resp *http.Response var resp *http.Response
retryStart := time.Now() retryStart := time.Now()
for attempt := 1; attempt <= maxRetryAttempts; attempt++ { for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(ctx, c, account, body, token) upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(upstreamCtx, c, account, body, token)
releaseUpstreamCtx()
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -4774,6 +4834,7 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( ...@@ -4774,6 +4834,7 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
usage := &ClaudeUsage{} usage := &ClaudeUsage{}
var firstTokenMs *int var firstTokenMs *int
clientDisconnected := false clientDisconnected := false
sawTerminalEvent := false
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize maxLineSize := defaultMaxLineSize
...@@ -4836,17 +4897,20 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( ...@@ -4836,17 +4897,20 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
// 兜底补刷,确保最后一个未以空行结尾的事件也能及时送达客户端。 // 兜底补刷,确保最后一个未以空行结尾的事件也能及时送达客户端。
flusher.Flush() flusher.Flush()
} }
if !sawTerminalEvent {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, fmt.Errorf("stream usage incomplete: missing terminal event")
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
} }
if ev.err != nil { if ev.err != nil {
if sawTerminalEvent {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
}
if clientDisconnected { if clientDisconnected {
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, ev.err) return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after disconnect: %w", ev.err)
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
} }
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] 流读取被取消: account=%d request_id=%s err=%v ctx_err=%v", return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete: %w", ev.err)
account.ID, resp.Header.Get("x-request-id"), ev.err, ctx.Err())
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
} }
if errors.Is(ev.err, bufio.ErrTooLong) { if errors.Is(ev.err, bufio.ErrTooLong) {
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
...@@ -4858,11 +4922,19 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( ...@@ -4858,11 +4922,19 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
line := ev.line line := ev.line
if data, ok := extractAnthropicSSEDataLine(line); ok { if data, ok := extractAnthropicSSEDataLine(line); ok {
trimmed := strings.TrimSpace(data) trimmed := strings.TrimSpace(data)
if anthropicStreamEventIsTerminal("", trimmed) {
sawTerminalEvent = true
}
if firstTokenMs == nil && trimmed != "" && trimmed != "[DONE]" { if firstTokenMs == nil && trimmed != "" && trimmed != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds()) ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms firstTokenMs = &ms
} }
s.parseSSEUsagePassthrough(data, usage) s.parseSSEUsagePassthrough(data, usage)
} else {
trimmed := strings.TrimSpace(line)
if strings.HasPrefix(trimmed, "event:") && anthropicStreamEventIsTerminal(strings.TrimSpace(strings.TrimPrefix(trimmed, "event:")), "") {
sawTerminalEvent = true
}
} }
if !clientDisconnected { if !clientDisconnected {
...@@ -4884,8 +4956,7 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( ...@@ -4884,8 +4956,7 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
continue continue
} }
if clientDisconnected { if clientDisconnected {
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream timeout after client disconnect: account=%d model=%s", account.ID, model) return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after timeout")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
} }
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval) logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval)
if s.rateLimitService != nil { if s.rateLimitService != nil {
...@@ -6011,6 +6082,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http ...@@ -6011,6 +6082,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
needModelReplace := originalModel != mappedModel needModelReplace := originalModel != mappedModel
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
sawTerminalEvent := false
pendingEventLines := make([]string, 0, 4) pendingEventLines := make([]string, 0, 4)
...@@ -6041,6 +6113,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http ...@@ -6041,6 +6113,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
} }
if dataLine == "[DONE]" { if dataLine == "[DONE]" {
sawTerminalEvent = true
block := "" block := ""
if eventName != "" { if eventName != "" {
block = "event: " + eventName + "\n" block = "event: " + eventName + "\n"
...@@ -6107,6 +6180,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http ...@@ -6107,6 +6180,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
} }
usagePatch := s.extractSSEUsagePatch(event) usagePatch := s.extractSSEUsagePatch(event)
if anthropicStreamEventIsTerminal(eventName, dataLine) {
sawTerminalEvent = true
}
if !eventChanged { if !eventChanged {
block := "" block := ""
if eventName != "" { if eventName != "" {
...@@ -6140,18 +6216,22 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http ...@@ -6140,18 +6216,22 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
case ev, ok := <-events: case ev, ok := <-events:
if !ok { if !ok {
// 上游完成,返回结果 // 上游完成,返回结果
if !sawTerminalEvent {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, fmt.Errorf("stream usage incomplete: missing terminal event")
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
} }
if ev.err != nil { if ev.err != nil {
if sawTerminalEvent {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
}
// 检测 context 取消(客户端断开会导致 context 取消,进而影响上游读取) // 检测 context 取消(客户端断开会导致 context 取消,进而影响上游读取)
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
logger.LegacyPrintf("service.gateway", "Context canceled during streaming, returning collected usage") return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete: %w", ev.err)
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
} }
// 客户端已通过写入失败检测到断开,上游也出错了,返回已收集的 usage // 客户端已通过写入失败检测到断开,上游也出错了,返回已收集的 usage
if clientDisconnected { if clientDisconnected {
logger.LegacyPrintf("service.gateway", "Upstream read error after client disconnect: %v, returning collected usage", ev.err) return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after disconnect: %w", ev.err)
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
} }
// 客户端未断开,正常的错误处理 // 客户端未断开,正常的错误处理
if errors.Is(ev.err, bufio.ErrTooLong) { if errors.Is(ev.err, bufio.ErrTooLong) {
...@@ -6209,9 +6289,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http ...@@ -6209,9 +6289,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
continue continue
} }
if clientDisconnected { if clientDisconnected {
// 客户端已断开,上游也超时了,返回已收集的 usage return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after timeout")
logger.LegacyPrintf("service.gateway", "Upstream timeout after client disconnect, returning collected usage")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
} }
logger.LegacyPrintf("service.gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) logger.LegacyPrintf("service.gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
// 处理流超时,可能标记账户为临时不可调度或错误状态 // 处理流超时,可能标记账户为临时不可调度或错误状态
...@@ -6557,15 +6635,16 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, ...@@ -6557,15 +6635,16 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID,
// RecordUsageInput 记录使用量的输入参数 // RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct { type RecordUsageInput struct {
Result *ForwardResult Result *ForwardResult
APIKey *APIKey APIKey *APIKey
User *User User *User
Account *Account Account *Account
Subscription *UserSubscription // 可选:订阅信息 Subscription *UserSubscription // 可选:订阅信息
UserAgent string // 请求的 User-Agent UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址 IPAddress string // 请求的客户端 IP 地址
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险
APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额 ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额
} }
// APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage // APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage
...@@ -6574,6 +6653,14 @@ type APIKeyQuotaUpdater interface { ...@@ -6574,6 +6653,14 @@ type APIKeyQuotaUpdater interface {
UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error
} }
type apiKeyAuthCacheInvalidator interface {
InvalidateAuthCacheByKey(ctx context.Context, key string)
}
type usageLogBestEffortWriter interface {
CreateBestEffort(ctx context.Context, log *UsageLog) error
}
// postUsageBillingParams 统一扣费所需的参数 // postUsageBillingParams 统一扣费所需的参数
type postUsageBillingParams struct { type postUsageBillingParams struct {
Cost *CostBreakdown Cost *CostBreakdown
...@@ -6581,6 +6668,7 @@ type postUsageBillingParams struct { ...@@ -6581,6 +6668,7 @@ type postUsageBillingParams struct {
APIKey *APIKey APIKey *APIKey
Account *Account Account *Account
Subscription *UserSubscription Subscription *UserSubscription
RequestPayloadHash string
IsSubscriptionBill bool IsSubscriptionBill bool
AccountRateMultiplier float64 AccountRateMultiplier float64
APIKeyService APIKeyQuotaUpdater APIKeyService APIKeyQuotaUpdater
...@@ -6592,19 +6680,22 @@ type postUsageBillingParams struct { ...@@ -6592,19 +6680,22 @@ type postUsageBillingParams struct {
// - API Key 限速用量更新 // - API Key 限速用量更新
// - 账号配额用量更新(账号口径:TotalCost × 账号计费倍率) // - 账号配额用量更新(账号口径:TotalCost × 账号计费倍率)
func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) { func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) {
billingCtx, cancel := detachedBillingContext(ctx)
defer cancel()
cost := p.Cost cost := p.Cost
// 1. 订阅 / 余额扣费 // 1. 订阅 / 余额扣费
if p.IsSubscriptionBill { if p.IsSubscriptionBill {
if cost.TotalCost > 0 { if cost.TotalCost > 0 {
if err := deps.userSubRepo.IncrementUsage(ctx, p.Subscription.ID, cost.TotalCost); err != nil { if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil {
slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err) slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err)
} }
deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, cost.TotalCost) deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, cost.TotalCost)
} }
} else { } else {
if cost.ActualCost > 0 { if cost.ActualCost > 0 {
if err := deps.userRepo.DeductBalance(ctx, p.User.ID, cost.ActualCost); err != nil { if err := deps.userRepo.DeductBalance(billingCtx, p.User.ID, cost.ActualCost); err != nil {
slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err) slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err)
} }
deps.billingCacheService.QueueDeductBalance(p.User.ID, cost.ActualCost) deps.billingCacheService.QueueDeductBalance(p.User.ID, cost.ActualCost)
...@@ -6613,31 +6704,187 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill ...@@ -6613,31 +6704,187 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
// 2. API Key 配额 // 2. API Key 配额
if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil { if cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil {
if err := p.APIKeyService.UpdateQuotaUsed(ctx, p.APIKey.ID, cost.ActualCost); err != nil { if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err) slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err)
} }
} }
// 3. API Key 限速用量 // 3. API Key 限速用量
if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil { if cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
if err := p.APIKeyService.UpdateRateLimitUsage(ctx, p.APIKey.ID, cost.ActualCost); err != nil { if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err) slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err)
} }
deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, cost.ActualCost)
} }
// 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率) // 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率)
if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() { if cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() {
accountCost := cost.TotalCost * p.AccountRateMultiplier accountCost := cost.TotalCost * p.AccountRateMultiplier
if err := deps.accountRepo.IncrementQuotaUsed(ctx, p.Account.ID, accountCost); err != nil { if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil {
slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err) slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err)
} }
} }
// 5. 更新账号最近使用时间 finalizePostUsageBilling(p, deps)
}
func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) string {
if requestID := strings.TrimSpace(upstreamRequestID); requestID != "" {
return requestID
}
if ctx != nil {
if clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" {
return "client:" + strings.TrimSpace(clientRequestID)
}
if requestID, _ := ctx.Value(ctxkey.RequestID).(string); strings.TrimSpace(requestID) != "" {
return "local:" + strings.TrimSpace(requestID)
}
}
return ""
}
func resolveUsageBillingPayloadFingerprint(ctx context.Context, requestPayloadHash string) string {
if payloadHash := strings.TrimSpace(requestPayloadHash); payloadHash != "" {
return payloadHash
}
if ctx != nil {
if clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" {
return "client:" + strings.TrimSpace(clientRequestID)
}
if requestID, _ := ctx.Value(ctxkey.RequestID).(string); strings.TrimSpace(requestID) != "" {
return "local:" + strings.TrimSpace(requestID)
}
}
return ""
}
func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsageBillingParams) *UsageBillingCommand {
if p == nil || p.Cost == nil || p.APIKey == nil || p.User == nil || p.Account == nil {
return nil
}
cmd := &UsageBillingCommand{
RequestID: requestID,
APIKeyID: p.APIKey.ID,
UserID: p.User.ID,
AccountID: p.Account.ID,
AccountType: p.Account.Type,
RequestPayloadHash: strings.TrimSpace(p.RequestPayloadHash),
}
if usageLog != nil {
cmd.Model = usageLog.Model
cmd.BillingType = usageLog.BillingType
cmd.InputTokens = usageLog.InputTokens
cmd.OutputTokens = usageLog.OutputTokens
cmd.CacheCreationTokens = usageLog.CacheCreationTokens
cmd.CacheReadTokens = usageLog.CacheReadTokens
cmd.ImageCount = usageLog.ImageCount
if usageLog.MediaType != nil {
cmd.MediaType = *usageLog.MediaType
}
if usageLog.ServiceTier != nil {
cmd.ServiceTier = *usageLog.ServiceTier
}
if usageLog.ReasoningEffort != nil {
cmd.ReasoningEffort = *usageLog.ReasoningEffort
}
if usageLog.SubscriptionID != nil {
cmd.SubscriptionID = usageLog.SubscriptionID
}
}
if p.IsSubscriptionBill && p.Subscription != nil && p.Cost.TotalCost > 0 {
cmd.SubscriptionID = &p.Subscription.ID
cmd.SubscriptionCost = p.Cost.TotalCost
} else if p.Cost.ActualCost > 0 {
cmd.BalanceCost = p.Cost.ActualCost
}
if p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil {
cmd.APIKeyQuotaCost = p.Cost.ActualCost
}
if p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil {
cmd.APIKeyRateLimitCost = p.Cost.ActualCost
}
if p.Cost.TotalCost > 0 && p.Account.Type == AccountTypeAPIKey && p.Account.HasAnyQuotaLimit() {
cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier
}
cmd.Normalize()
return cmd
}
func applyUsageBilling(ctx context.Context, requestID string, usageLog *UsageLog, p *postUsageBillingParams, deps *billingDeps, repo UsageBillingRepository) (bool, error) {
if p == nil || deps == nil {
return false, nil
}
cmd := buildUsageBillingCommand(requestID, usageLog, p)
if cmd == nil || cmd.RequestID == "" || repo == nil {
postUsageBilling(ctx, p, deps)
return true, nil
}
billingCtx, cancel := detachedBillingContext(ctx)
defer cancel()
result, err := repo.Apply(billingCtx, cmd)
if err != nil {
return false, err
}
if result == nil || !result.Applied {
deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)
return false, nil
}
if result.APIKeyQuotaExhausted {
if invalidator, ok := p.APIKeyService.(apiKeyAuthCacheInvalidator); ok && p.APIKey != nil && p.APIKey.Key != "" {
invalidator.InvalidateAuthCacheByKey(billingCtx, p.APIKey.Key)
}
}
finalizePostUsageBilling(p, deps)
return true, nil
}
func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) {
if p == nil || p.Cost == nil || deps == nil {
return
}
if p.IsSubscriptionBill {
if p.Cost.TotalCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil {
deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.TotalCost)
}
} else if p.Cost.ActualCost > 0 && p.User != nil {
deps.billingCacheService.QueueDeductBalance(p.User.ID, p.Cost.ActualCost)
}
if p.Cost.ActualCost > 0 && p.APIKey != nil && p.APIKey.HasRateLimits() {
deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, p.Cost.ActualCost)
}
deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID) deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)
} }
func detachedBillingContext(ctx context.Context) (context.Context, context.CancelFunc) {
base := context.Background()
if ctx != nil {
base = context.WithoutCancel(ctx)
}
return context.WithTimeout(base, postUsageBillingTimeout)
}
func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Context, context.CancelFunc) {
if !stream {
return ctx, func() {}
}
if ctx == nil {
return context.Background(), func() {}
}
return context.WithoutCancel(ctx), func() {}
}
// billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供) // billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供)
type billingDeps struct { type billingDeps struct {
accountRepo AccountRepository accountRepo AccountRepository
...@@ -6657,6 +6904,28 @@ func (s *GatewayService) billingDeps() *billingDeps { ...@@ -6657,6 +6904,28 @@ func (s *GatewayService) billingDeps() *billingDeps {
} }
} }
func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usageLog *UsageLog, logKey string) {
if repo == nil || usageLog == nil {
return
}
usageCtx, cancel := detachedBillingContext(ctx)
defer cancel()
if writer, ok := repo.(usageLogBestEffortWriter); ok {
if err := writer.CreateBestEffort(usageCtx, usageLog); err != nil {
logger.LegacyPrintf(logKey, "Create usage log failed: %v", err)
if _, syncErr := repo.Create(usageCtx, usageLog); syncErr != nil {
logger.LegacyPrintf(logKey, "Create usage log sync fallback failed: %v", syncErr)
}
}
return
}
if _, err := repo.Create(usageCtx, usageLog); err != nil {
logger.LegacyPrintf(logKey, "Create usage log failed: %v", err)
}
}
// RecordUsage 记录使用量并扣费(或更新订阅用量) // RecordUsage 记录使用量并扣费(或更新订阅用量)
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error { func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
result := input.Result result := input.Result
...@@ -6758,11 +7027,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -6758,11 +7027,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
mediaType = &result.MediaType mediaType = &result.MediaType
} }
accountRateMultiplier := account.BillingRateMultiplier() accountRateMultiplier := account.BillingRateMultiplier()
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
usageLog := &UsageLog{ usageLog := &UsageLog{
UserID: user.ID, UserID: user.ID,
APIKeyID: apiKey.ID, APIKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
RequestID: result.RequestID, RequestID: requestID,
Model: result.Model, Model: result.Model,
InputTokens: result.Usage.InputTokens, InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens, OutputTokens: result.Usage.OutputTokens,
...@@ -6807,33 +7077,32 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -6807,33 +7077,32 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
usageLog.SubscriptionID = &subscription.ID usageLog.SubscriptionID = &subscription.ID
} }
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
if err != nil {
logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err)
}
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
s.deferredService.ScheduleLastUsedUpdate(account.ID) s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil return nil
} }
shouldBill := inserted || err != nil billingErr := func() error {
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
if shouldBill {
postUsageBilling(ctx, &postUsageBillingParams{
Cost: cost, Cost: cost,
User: user, User: user,
APIKey: apiKey, APIKey: apiKey,
Account: account, Account: account,
Subscription: subscription, Subscription: subscription,
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
IsSubscriptionBill: isSubscriptionBilling, IsSubscriptionBill: isSubscriptionBilling,
AccountRateMultiplier: accountRateMultiplier, AccountRateMultiplier: accountRateMultiplier,
APIKeyService: input.APIKeyService, APIKeyService: input.APIKeyService,
}, s.billingDeps()) }, s.billingDeps(), s.usageBillingRepo)
} else { return err
s.deferredService.ScheduleLastUsedUpdate(account.ID) }()
if billingErr != nil {
return billingErr
} }
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
return nil return nil
} }
...@@ -6844,13 +7113,14 @@ type RecordUsageLongContextInput struct { ...@@ -6844,13 +7113,14 @@ type RecordUsageLongContextInput struct {
APIKey *APIKey APIKey *APIKey
User *User User *User
Account *Account Account *Account
Subscription *UserSubscription // 可选:订阅信息 Subscription *UserSubscription // 可选:订阅信息
UserAgent string // 请求的 User-Agent UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址 IPAddress string // 请求的客户端 IP 地址
LongContextThreshold int // 长上下文阈值(如 200000) RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险
LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0) LongContextThreshold int // 长上下文阈值(如 200000)
ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0)
APIKeyService *APIKeyService // API Key 配额服务(可选) ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换)
APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选)
} }
// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini) // RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini)
...@@ -6933,11 +7203,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -6933,11 +7203,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
imageSize = &result.ImageSize imageSize = &result.ImageSize
} }
accountRateMultiplier := account.BillingRateMultiplier() accountRateMultiplier := account.BillingRateMultiplier()
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
usageLog := &UsageLog{ usageLog := &UsageLog{
UserID: user.ID, UserID: user.ID,
APIKeyID: apiKey.ID, APIKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
RequestID: result.RequestID, RequestID: requestID,
Model: result.Model, Model: result.Model,
InputTokens: result.Usage.InputTokens, InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens, OutputTokens: result.Usage.OutputTokens,
...@@ -6981,33 +7252,32 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -6981,33 +7252,32 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
usageLog.SubscriptionID = &subscription.ID usageLog.SubscriptionID = &subscription.ID
} }
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
if err != nil {
logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err)
}
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
s.deferredService.ScheduleLastUsedUpdate(account.ID) s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil return nil
} }
shouldBill := inserted || err != nil billingErr := func() error {
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
if shouldBill {
postUsageBilling(ctx, &postUsageBillingParams{
Cost: cost, Cost: cost,
User: user, User: user,
APIKey: apiKey, APIKey: apiKey,
Account: account, Account: account,
Subscription: subscription, Subscription: subscription,
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
IsSubscriptionBill: isSubscriptionBilling, IsSubscriptionBill: isSubscriptionBilling,
AccountRateMultiplier: accountRateMultiplier, AccountRateMultiplier: accountRateMultiplier,
APIKeyService: input.APIKeyService, APIKeyService: input.APIKeyService,
}, s.billingDeps()) }, s.billingDeps(), s.usageBillingRepo)
} else { return err
s.deferredService.ScheduleLastUsedUpdate(account.ID) }()
if billingErr != nil {
return billingErr
} }
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
return nil return nil
} }
......
...@@ -181,7 +181,8 @@ func TestHandleStreamingResponse_EmptyStream(t *testing.T) { ...@@ -181,7 +181,8 @@ func TestHandleStreamingResponse_EmptyStream(t *testing.T) {
result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false) result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false)
_ = pr.Close() _ = pr.Close()
require.NoError(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "missing terminal event")
require.NotNil(t, result) require.NotNil(t, result)
} }
......
...@@ -7,35 +7,63 @@ import ( ...@@ -7,35 +7,63 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
type openAIRecordUsageLogRepoStub struct { type openAIRecordUsageLogRepoStub struct {
UsageLogRepository UsageLogRepository
inserted bool inserted bool
err error err error
calls int calls int
lastLog *UsageLog lastLog *UsageLog
lastCtxErr error
} }
func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) { func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) {
s.calls++ s.calls++
s.lastLog = log s.lastLog = log
s.lastCtxErr = ctx.Err()
return s.inserted, s.err return s.inserted, s.err
} }
type openAIRecordUsageBillingRepoStub struct {
UsageBillingRepository
result *UsageBillingApplyResult
err error
calls int
lastCmd *UsageBillingCommand
lastCtxErr error
}
func (s *openAIRecordUsageBillingRepoStub) Apply(ctx context.Context, cmd *UsageBillingCommand) (*UsageBillingApplyResult, error) {
s.calls++
s.lastCmd = cmd
s.lastCtxErr = ctx.Err()
if s.err != nil {
return nil, s.err
}
if s.result != nil {
return s.result, nil
}
return &UsageBillingApplyResult{Applied: true}, nil
}
type openAIRecordUsageUserRepoStub struct { type openAIRecordUsageUserRepoStub struct {
UserRepository UserRepository
deductCalls int deductCalls int
deductErr error deductErr error
lastAmount float64 lastAmount float64
lastCtxErr error
} }
func (s *openAIRecordUsageUserRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error { func (s *openAIRecordUsageUserRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error {
s.deductCalls++ s.deductCalls++
s.lastAmount = amount s.lastAmount = amount
s.lastCtxErr = ctx.Err()
return s.deductErr return s.deductErr
} }
...@@ -44,29 +72,35 @@ type openAIRecordUsageSubRepoStub struct { ...@@ -44,29 +72,35 @@ type openAIRecordUsageSubRepoStub struct {
incrementCalls int incrementCalls int
incrementErr error incrementErr error
lastCtxErr error
} }
func (s *openAIRecordUsageSubRepoStub) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { func (s *openAIRecordUsageSubRepoStub) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
s.incrementCalls++ s.incrementCalls++
s.lastCtxErr = ctx.Err()
return s.incrementErr return s.incrementErr
} }
type openAIRecordUsageAPIKeyQuotaStub struct { type openAIRecordUsageAPIKeyQuotaStub struct {
quotaCalls int quotaCalls int
rateLimitCalls int rateLimitCalls int
err error err error
lastAmount float64 lastAmount float64
lastQuotaCtxErr error
lastRateLimitCtxErr error
} }
func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error { func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error {
s.quotaCalls++ s.quotaCalls++
s.lastAmount = cost s.lastAmount = cost
s.lastQuotaCtxErr = ctx.Err()
return s.err return s.err
} }
func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error { func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error {
s.rateLimitCalls++ s.rateLimitCalls++
s.lastAmount = cost s.lastAmount = cost
s.lastRateLimitCtxErr = ctx.Err()
return s.err return s.err
} }
...@@ -93,23 +127,38 @@ func i64p(v int64) *int64 { ...@@ -93,23 +127,38 @@ func i64p(v int64) *int64 {
func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService { func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService {
cfg := &config.Config{} cfg := &config.Config{}
cfg.Default.RateMultiplier = 1.1 cfg.Default.RateMultiplier = 1.1
svc := NewOpenAIGatewayService(
nil,
usageRepo,
nil,
userRepo,
subRepo,
rateRepo,
nil,
cfg,
nil,
nil,
NewBillingService(cfg, nil),
nil,
&BillingCacheService{},
nil,
&DeferredService{},
nil,
)
svc.userGroupRateResolver = newUserGroupRateResolver(
rateRepo,
nil,
resolveUserGroupRateCacheTTL(cfg),
nil,
"service.openai_gateway.test",
)
return svc
}
return &OpenAIGatewayService{ func newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo UsageLogRepository, billingRepo UsageBillingRepository, userRepo UserRepository, subRepo UserSubscriptionRepository, rateRepo UserGroupRateRepository) *OpenAIGatewayService {
usageLogRepo: usageRepo, svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, rateRepo)
userRepo: userRepo, svc.usageBillingRepo = billingRepo
userSubRepo: subRepo, return svc
cfg: cfg,
billingService: NewBillingService(cfg, nil),
billingCacheService: &BillingCacheService{},
deferredService: &DeferredService{},
userGroupRateResolver: newUserGroupRateResolver(
rateRepo,
nil,
resolveUserGroupRateCacheTTL(cfg),
nil,
"service.openai_gateway.test",
),
}
} }
func expectedOpenAICost(t *testing.T, svc *OpenAIGatewayService, model string, usage OpenAIUsage, multiplier float64) *CostBreakdown { func expectedOpenAICost(t *testing.T, svc *OpenAIGatewayService, model string, usage OpenAIUsage, multiplier float64) *CostBreakdown {
...@@ -252,9 +301,10 @@ func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateWhenResolver ...@@ -252,9 +301,10 @@ func TestOpenAIGatewayServiceRecordUsage_FallsBackToGroupDefaultRateWhenResolver
func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testing.T) { func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false} usageRepo := &openAIRecordUsageLogRepoStub{inserted: false}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: false}}
userRepo := &openAIRecordUsageUserRepoStub{} userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{} subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{ Result: &OpenAIForwardResult{
...@@ -272,11 +322,254 @@ func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testin ...@@ -272,11 +322,254 @@ func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testin
}) })
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, billingRepo.calls)
require.Equal(t, 1, usageRepo.calls) require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 0, userRepo.deductCalls) require.Equal(t, 0, userRepo.deductCalls)
require.Equal(t, 0, subRepo.incrementCalls) require.Equal(t, 0, subRepo.incrementCalls)
} }
func TestOpenAIGatewayServiceRecordUsage_DuplicateBillingKeySkipsBillingWithRepo(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: false}}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_duplicate_billing_key",
Usage: OpenAIUsage{
InputTokens: 8,
OutputTokens: 4,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 10045,
Quota: 100,
},
User: &User{ID: 20045},
Account: &Account{ID: 30045},
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, billingRepo.calls)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 0, userRepo.deductCalls)
require.Equal(t, 0, subRepo.incrementCalls)
require.Equal(t, 0, quotaSvc.quotaCalls)
}
func TestOpenAIGatewayServiceRecordUsage_BillsWhenUsageLogCreateReturnsError(t *testing.T) {
usage := OpenAIUsage{InputTokens: 8, OutputTokens: 4}
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: errors.New("usage log batch state uncertain")}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_usage_log_error",
Usage: usage,
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{ID: 10041},
User: &User{ID: 20041},
Account: &Account{ID: 30041},
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 1, userRepo.deductCalls)
require.Equal(t, 0, subRepo.incrementCalls)
}
func TestOpenAIGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_not_persisted",
Usage: OpenAIUsage{
InputTokens: 8,
OutputTokens: 4,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 10043,
Quota: 100,
},
User: &User{ID: 20043},
Account: &Account{ID: 30043},
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, usageRepo.calls)
require.Equal(t, 1, userRepo.deductCalls)
require.Equal(t, 0, subRepo.incrementCalls)
require.Equal(t, 1, quotaSvc.quotaCalls)
}
func TestOpenAIGatewayServiceRecordUsage_BillingUsesDetachedContext(t *testing.T) {
usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2}
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: context.DeadlineExceeded}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil)
reqCtx, cancel := context.WithCancel(context.Background())
cancel()
err := svc.RecordUsage(reqCtx, &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_detached_billing_ctx",
Usage: usage,
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{
ID: 10042,
Quota: 100,
},
User: &User{ID: 20042},
Account: &Account{ID: 30042},
APIKeyService: quotaSvc,
})
require.NoError(t, err)
require.Equal(t, 1, userRepo.deductCalls)
require.NoError(t, userRepo.lastCtxErr)
require.Equal(t, 1, quotaSvc.quotaCalls)
require.NoError(t, quotaSvc.lastQuotaCtxErr)
}
func TestOpenAIGatewayServiceRecordUsage_BillingRepoUsesDetachedContext(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
reqCtx, cancel := context.WithCancel(context.Background())
cancel()
err := svc.RecordUsage(reqCtx, &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_detached_billing_repo_ctx",
Usage: OpenAIUsage{
InputTokens: 8,
OutputTokens: 4,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{ID: 10046},
User: &User{ID: 20046},
Account: &Account{ID: 30046},
})
require.NoError(t, err)
require.Equal(t, 1, billingRepo.calls)
require.NoError(t, billingRepo.lastCtxErr)
require.Equal(t, 1, usageRepo.calls)
require.NoError(t, usageRepo.lastCtxErr)
}
func TestOpenAIGatewayServiceRecordUsage_BillingFingerprintIncludesRequestPayloadHash(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil)
payloadHash := HashUsageRequestPayload([]byte(`{"model":"gpt-5","input":"hello"}`))
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "openai_payload_hash",
Usage: OpenAIUsage{
InputTokens: 10,
OutputTokens: 6,
},
Model: "gpt-5",
Duration: time.Second,
},
APIKey: &APIKey{ID: 501, Quota: 100},
User: &User{ID: 601},
Account: &Account{ID: 701},
RequestPayloadHash: payloadHash,
})
require.NoError(t, err)
require.NotNil(t, billingRepo.lastCmd)
require.Equal(t, payloadHash, billingRepo.lastCmd.RequestPayloadHash)
}
func TestOpenAIGatewayServiceRecordUsage_UsesFallbackRequestIDForBillingAndUsageLog(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{result: &UsageBillingApplyResult{Applied: true}}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
ctx := context.WithValue(context.Background(), ctxkey.RequestID, "req-local-fallback")
err := svc.RecordUsage(ctx, &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "",
Usage: OpenAIUsage{
InputTokens: 8,
OutputTokens: 4,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{ID: 10047},
User: &User{ID: 20047},
Account: &Account{ID: 30047},
})
require.NoError(t, err)
require.NotNil(t, billingRepo.lastCmd)
require.Equal(t, "local:req-local-fallback", billingRepo.lastCmd.RequestID)
require.NotNil(t, usageRepo.lastLog)
require.Equal(t, "local:req-local-fallback", usageRepo.lastLog.RequestID)
}
func TestOpenAIGatewayServiceRecordUsage_BillingErrorSkipsUsageLogWrite(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{}
billingRepo := &openAIRecordUsageBillingRepoStub{err: errors.New("billing tx failed")}
userRepo := &openAIRecordUsageUserRepoStub{}
subRepo := &openAIRecordUsageSubRepoStub{}
svc := newOpenAIRecordUsageServiceWithBillingRepoForTest(usageRepo, billingRepo, userRepo, subRepo, nil)
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
Result: &OpenAIForwardResult{
RequestID: "resp_billing_fail",
Usage: OpenAIUsage{
InputTokens: 8,
OutputTokens: 4,
},
Model: "gpt-5.1",
Duration: time.Second,
},
APIKey: &APIKey{ID: 10048},
User: &User{ID: 20048},
Account: &Account{ID: 30048},
})
require.Error(t, err)
require.Equal(t, 1, billingRepo.calls)
require.Equal(t, 0, usageRepo.calls)
}
func TestOpenAIGatewayServiceRecordUsage_UpdatesAPIKeyQuotaWhenConfigured(t *testing.T) { func TestOpenAIGatewayServiceRecordUsage_UpdatesAPIKeyQuotaWhenConfigured(t *testing.T) {
usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2} usage := OpenAIUsage{InputTokens: 10, OutputTokens: 6, CacheReadInputTokens: 2}
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
......
...@@ -259,6 +259,7 @@ type openAIWSRetryMetrics struct { ...@@ -259,6 +259,7 @@ type openAIWSRetryMetrics struct {
type OpenAIGatewayService struct { type OpenAIGatewayService struct {
accountRepo AccountRepository accountRepo AccountRepository
usageLogRepo UsageLogRepository usageLogRepo UsageLogRepository
usageBillingRepo UsageBillingRepository
userRepo UserRepository userRepo UserRepository
userSubRepo UserSubscriptionRepository userSubRepo UserSubscriptionRepository
cache GatewayCache cache GatewayCache
...@@ -295,6 +296,7 @@ type OpenAIGatewayService struct { ...@@ -295,6 +296,7 @@ type OpenAIGatewayService struct {
func NewOpenAIGatewayService( func NewOpenAIGatewayService(
accountRepo AccountRepository, accountRepo AccountRepository,
usageLogRepo UsageLogRepository, usageLogRepo UsageLogRepository,
usageBillingRepo UsageBillingRepository,
userRepo UserRepository, userRepo UserRepository,
userSubRepo UserSubscriptionRepository, userSubRepo UserSubscriptionRepository,
userGroupRateRepo UserGroupRateRepository, userGroupRateRepo UserGroupRateRepository,
...@@ -312,6 +314,7 @@ func NewOpenAIGatewayService( ...@@ -312,6 +314,7 @@ func NewOpenAIGatewayService(
svc := &OpenAIGatewayService{ svc := &OpenAIGatewayService{
accountRepo: accountRepo, accountRepo: accountRepo,
usageLogRepo: usageLogRepo, usageLogRepo: usageLogRepo,
usageBillingRepo: usageBillingRepo,
userRepo: userRepo, userRepo: userRepo,
userSubRepo: userSubRepo, userSubRepo: userSubRepo,
cache: cache, cache: cache,
...@@ -2014,7 +2017,9 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco ...@@ -2014,7 +2017,9 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
} }
// Build upstream request // Build upstream request
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI) upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
releaseUpstreamCtx()
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -2206,7 +2211,9 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( ...@@ -2206,7 +2211,9 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
return nil, err return nil, err
} }
upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(ctx, c, account, body, token) upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token)
releaseUpstreamCtx()
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -2543,6 +2550,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( ...@@ -2543,6 +2550,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
var firstTokenMs *int var firstTokenMs *int
clientDisconnected := false clientDisconnected := false
sawDone := false sawDone := false
sawTerminalEvent := false
upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id")) upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id"))
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
...@@ -2562,6 +2570,9 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( ...@@ -2562,6 +2570,9 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
if trimmedData == "[DONE]" { if trimmedData == "[DONE]" {
sawDone = true sawDone = true
} }
if openAIStreamEventIsTerminal(trimmedData) {
sawTerminalEvent = true
}
if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" { if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds()) ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms firstTokenMs = &ms
...@@ -2579,19 +2590,14 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( ...@@ -2579,19 +2590,14 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
} }
} }
if err := scanner.Err(); err != nil { if err := scanner.Err(); err != nil {
if clientDisconnected { if sawTerminalEvent {
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, err)
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
} }
if clientDisconnected {
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err)
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
logger.LegacyPrintf("service.openai_gateway", return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err)
"[OpenAI passthrough] 流读取被取消,可能发生断流: account=%d request_id=%s err=%v ctx_err=%v",
account.ID,
upstreamRequestID,
err,
ctx.Err(),
)
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
} }
if errors.Is(err, bufio.ErrTooLong) { if errors.Is(err, bufio.ErrTooLong) {
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err) logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
...@@ -2605,12 +2611,13 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( ...@@ -2605,12 +2611,13 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
) )
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err) return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
} }
if !clientDisconnected && !sawDone && ctx.Err() == nil { if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil {
logger.FromContext(ctx).With( logger.FromContext(ctx).With(
zap.String("component", "service.openai_gateway"), zap.String("component", "service.openai_gateway"),
zap.Int64("account_id", account.ID), zap.Int64("account_id", account.ID),
zap.String("upstream_request_id", upstreamRequestID), zap.String("upstream_request_id", upstreamRequestID),
).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流") ).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流")
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream usage incomplete: missing terminal event")
} }
return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
...@@ -3030,6 +3037,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp ...@@ -3030,6 +3037,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
// 否则下游 SDK(例如 OpenCode)会因为类型校验失败而报错。 // 否则下游 SDK(例如 OpenCode)会因为类型校验失败而报错。
errorEventSent := false errorEventSent := false
clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage
sawTerminalEvent := false
sendErrorEvent := func(reason string) { sendErrorEvent := func(reason string) {
if errorEventSent || clientDisconnected { if errorEventSent || clientDisconnected {
return return
...@@ -3060,22 +3068,27 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp ...@@ -3060,22 +3068,27 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage") logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage")
} }
} }
if !sawTerminalEvent {
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
}
return resultWithUsage(), nil return resultWithUsage(), nil
} }
handleScanErr := func(scanErr error) (*openaiStreamingResult, error, bool) { handleScanErr := func(scanErr error) (*openaiStreamingResult, error, bool) {
if scanErr == nil { if scanErr == nil {
return nil, nil, false return nil, nil, false
} }
if sawTerminalEvent {
logger.LegacyPrintf("service.openai_gateway", "Upstream scan ended after terminal event: %v", scanErr)
return resultWithUsage(), nil, true
}
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。 // 客户端断开/取消请求时,上游读取往往会返回 context canceled。
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。 // /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
if errors.Is(scanErr, context.Canceled) || errors.Is(scanErr, context.DeadlineExceeded) { if errors.Is(scanErr, context.Canceled) || errors.Is(scanErr, context.DeadlineExceeded) {
logger.LegacyPrintf("service.openai_gateway", "Context canceled during streaming, returning collected usage") return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", scanErr), true
return resultWithUsage(), nil, true
} }
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage // 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
if clientDisconnected { if clientDisconnected {
logger.LegacyPrintf("service.openai_gateway", "Upstream read error after client disconnect: %v, returning collected usage", scanErr) return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", scanErr), true
return resultWithUsage(), nil, true
} }
if errors.Is(scanErr, bufio.ErrTooLong) { if errors.Is(scanErr, bufio.ErrTooLong) {
logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, scanErr) logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, scanErr)
...@@ -3098,6 +3111,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp ...@@ -3098,6 +3111,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
} }
dataBytes := []byte(data) dataBytes := []byte(data)
if openAIStreamEventIsTerminal(data) {
sawTerminalEvent = true
}
// Correct Codex tool calls if needed (apply_patch -> edit, etc.) // Correct Codex tool calls if needed (apply_patch -> edit, etc.)
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected { if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected {
...@@ -3214,8 +3230,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp ...@@ -3214,8 +3230,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
continue continue
} }
if clientDisconnected { if clientDisconnected {
logger.LegacyPrintf("service.openai_gateway", "Upstream timeout after client disconnect, returning collected usage") return resultWithUsage(), fmt.Errorf("stream usage incomplete after timeout")
return resultWithUsage(), nil
} }
logger.LegacyPrintf("service.openai_gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) logger.LegacyPrintf("service.openai_gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
// 处理流超时,可能标记账户为临时不可调度或错误状态 // 处理流超时,可能标记账户为临时不可调度或错误状态
...@@ -3313,11 +3328,12 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag ...@@ -3313,11 +3328,12 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag
if usage == nil || len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) { if usage == nil || len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) {
return return
} }
// 选择性解析:仅在数据中包含 completed 事件标识时才进入字段提取。 // 选择性解析:仅在数据中包含终止事件标识时才进入字段提取。
if len(data) < 80 || !bytes.Contains(data, []byte(`"response.completed"`)) { if len(data) < 72 {
return return
} }
if gjson.GetBytes(data, "type").String() != "response.completed" { eventType := gjson.GetBytes(data, "type").String()
if eventType != "response.completed" && eventType != "response.done" {
return return
} }
...@@ -3670,14 +3686,15 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel ...@@ -3670,14 +3686,15 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
// OpenAIRecordUsageInput input for recording usage // OpenAIRecordUsageInput input for recording usage
type OpenAIRecordUsageInput struct { type OpenAIRecordUsageInput struct {
Result *OpenAIForwardResult Result *OpenAIForwardResult
APIKey *APIKey APIKey *APIKey
User *User User *User
Account *Account Account *Account
Subscription *UserSubscription Subscription *UserSubscription
UserAgent string // 请求的 User-Agent UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址 IPAddress string // 请求的客户端 IP 地址
APIKeyService APIKeyQuotaUpdater RequestPayloadHash string
APIKeyService APIKeyQuotaUpdater
} }
// RecordUsage records usage and deducts balance // RecordUsage records usage and deducts balance
...@@ -3743,11 +3760,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec ...@@ -3743,11 +3760,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
// Create usage log // Create usage log
durationMs := int(result.Duration.Milliseconds()) durationMs := int(result.Duration.Milliseconds())
accountRateMultiplier := account.BillingRateMultiplier() accountRateMultiplier := account.BillingRateMultiplier()
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
usageLog := &UsageLog{ usageLog := &UsageLog{
UserID: user.ID, UserID: user.ID,
APIKeyID: apiKey.ID, APIKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
RequestID: result.RequestID, RequestID: requestID,
Model: billingModel, Model: billingModel,
ServiceTier: result.ServiceTier, ServiceTier: result.ServiceTier,
ReasoningEffort: result.ReasoningEffort, ReasoningEffort: result.ReasoningEffort,
...@@ -3788,29 +3806,32 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec ...@@ -3788,29 +3806,32 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
usageLog.SubscriptionID = &subscription.ID usageLog.SubscriptionID = &subscription.ID
} }
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway")
logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
s.deferredService.ScheduleLastUsedUpdate(account.ID) s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil return nil
} }
shouldBill := inserted || err != nil billingErr := func() error {
_, err := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
if shouldBill {
postUsageBilling(ctx, &postUsageBillingParams{
Cost: cost, Cost: cost,
User: user, User: user,
APIKey: apiKey, APIKey: apiKey,
Account: account, Account: account,
Subscription: subscription, Subscription: subscription,
RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
IsSubscriptionBill: isSubscriptionBilling, IsSubscriptionBill: isSubscriptionBilling,
AccountRateMultiplier: accountRateMultiplier, AccountRateMultiplier: accountRateMultiplier,
APIKeyService: input.APIKeyService, APIKeyService: input.APIKeyService,
}, s.billingDeps()) }, s.billingDeps(), s.usageBillingRepo)
} else { return err
s.deferredService.ScheduleLastUsedUpdate(account.ID) }()
if billingErr != nil {
return billingErr
} }
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway")
return nil return nil
} }
......
...@@ -916,7 +916,7 @@ func TestOpenAIStreamingTimeout(t *testing.T) { ...@@ -916,7 +916,7 @@ func TestOpenAIStreamingTimeout(t *testing.T) {
} }
} }
func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) { func TestOpenAIStreamingContextCanceledReturnsIncompleteErrorWithoutInjectingErrorEvent(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
cfg := &config.Config{ cfg := &config.Config{
Gateway: config.GatewayConfig{ Gateway: config.GatewayConfig{
...@@ -940,8 +940,8 @@ func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) { ...@@ -940,8 +940,8 @@ func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) {
} }
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model") _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
if err != nil { if err == nil || !strings.Contains(err.Error(), "stream usage incomplete") {
t.Fatalf("expected nil error, got %v", err) t.Fatalf("expected incomplete stream error, got %v", err)
} }
if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "stream_read_error") { if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "stream_read_error") {
t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String()) t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
...@@ -993,6 +993,107 @@ func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) { ...@@ -993,6 +993,107 @@ func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
} }
} }
func TestOpenAIStreamingMissingTerminalEventReturnsIncompleteError(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: pr,
Header: http.Header{},
}
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
}()
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
_ = pr.Close()
if err == nil || !strings.Contains(err.Error(), "missing terminal event") {
t.Fatalf("expected missing terminal event error, got %v", err)
}
}
func TestOpenAIStreamingPassthroughMissingTerminalEventReturnsIncompleteError(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: pr,
Header: http.Header{},
}
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
}()
_, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now())
_ = pr.Close()
if err == nil || !strings.Contains(err.Error(), "missing terminal event") {
t.Fatalf("expected missing terminal event error, got %v", err)
}
}
func TestOpenAIStreamingPassthroughResponseDoneWithoutDoneMarkerStillSucceeds(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: pr,
Header: http.Header{},
}
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"type\":\"response.done\",\"response\":{\"usage\":{\"input_tokens\":2,\"output_tokens\":3,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
}()
result, err := svc.handleStreamingResponsePassthrough(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now())
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.usage)
require.Equal(t, 2, result.usage.InputTokens)
require.Equal(t, 3, result.usage.OutputTokens)
require.Equal(t, 1, result.usage.CacheReadInputTokens)
}
func TestOpenAIStreamingTooLong(t *testing.T) { func TestOpenAIStreamingTooLong(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
cfg := &config.Config{ cfg := &config.Config{
...@@ -1124,7 +1225,7 @@ func TestOpenAIStreamingHeadersOverride(t *testing.T) { ...@@ -1124,7 +1225,7 @@ func TestOpenAIStreamingHeadersOverride(t *testing.T) {
go func() { go func() {
defer func() { _ = pw.Close() }() defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {}\n\n")) _, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{}}\n\n"))
}() }()
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model") _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
...@@ -1674,6 +1775,12 @@ func TestParseSSEUsage_SelectiveParsing(t *testing.T) { ...@@ -1674,6 +1775,12 @@ func TestParseSSEUsage_SelectiveParsing(t *testing.T) {
require.Equal(t, 3, usage.InputTokens) require.Equal(t, 3, usage.InputTokens)
require.Equal(t, 5, usage.OutputTokens) require.Equal(t, 5, usage.OutputTokens)
require.Equal(t, 2, usage.CacheReadInputTokens) require.Equal(t, 2, usage.CacheReadInputTokens)
// done 事件同样可能携带最终 usage
svc.parseSSEUsage(`{"type":"response.done","response":{"usage":{"input_tokens":13,"output_tokens":15,"input_tokens_details":{"cached_tokens":4}}}}`, usage)
require.Equal(t, 13, usage.InputTokens)
require.Equal(t, 15, usage.OutputTokens)
require.Equal(t, 4, usage.CacheReadInputTokens)
} }
func TestExtractCodexFinalResponse_SampleReplay(t *testing.T) { func TestExtractCodexFinalResponse_SampleReplay(t *testing.T) {
......
...@@ -392,6 +392,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) { ...@@ -392,6 +392,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil, nil,
nil, nil,
nil, nil,
nil,
cfg, cfg,
nil, nil,
nil, nil,
......
package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"strings"
)
var ErrUsageBillingRequestIDRequired = errors.New("usage billing request_id is required")
var ErrUsageBillingRequestConflict = errors.New("usage billing request fingerprint conflict")
// UsageBillingCommand describes one billable request that must be applied at most once.
type UsageBillingCommand struct {
RequestID string
APIKeyID int64
RequestFingerprint string
RequestPayloadHash string
UserID int64
AccountID int64
SubscriptionID *int64
AccountType string
Model string
ServiceTier string
ReasoningEffort string
BillingType int8
InputTokens int
OutputTokens int
CacheCreationTokens int
CacheReadTokens int
ImageCount int
MediaType string
BalanceCost float64
SubscriptionCost float64
APIKeyQuotaCost float64
APIKeyRateLimitCost float64
AccountQuotaCost float64
}
func (c *UsageBillingCommand) Normalize() {
if c == nil {
return
}
c.RequestID = strings.TrimSpace(c.RequestID)
if strings.TrimSpace(c.RequestFingerprint) == "" {
c.RequestFingerprint = buildUsageBillingFingerprint(c)
}
}
func buildUsageBillingFingerprint(c *UsageBillingCommand) string {
if c == nil {
return ""
}
raw := fmt.Sprintf(
"%d|%d|%d|%s|%s|%s|%s|%d|%d|%d|%d|%d|%d|%s|%d|%0.10f|%0.10f|%0.10f|%0.10f|%0.10f",
c.UserID,
c.AccountID,
c.APIKeyID,
strings.TrimSpace(c.AccountType),
strings.TrimSpace(c.Model),
strings.TrimSpace(c.ServiceTier),
strings.TrimSpace(c.ReasoningEffort),
c.BillingType,
c.InputTokens,
c.OutputTokens,
c.CacheCreationTokens,
c.CacheReadTokens,
c.ImageCount,
strings.TrimSpace(c.MediaType),
valueOrZero(c.SubscriptionID),
c.BalanceCost,
c.SubscriptionCost,
c.APIKeyQuotaCost,
c.APIKeyRateLimitCost,
c.AccountQuotaCost,
)
if payloadHash := strings.TrimSpace(c.RequestPayloadHash); payloadHash != "" {
raw += "|" + payloadHash
}
sum := sha256.Sum256([]byte(raw))
return hex.EncodeToString(sum[:])
}
func HashUsageRequestPayload(payload []byte) string {
if len(payload) == 0 {
return ""
}
sum := sha256.Sum256(payload)
return hex.EncodeToString(sum[:])
}
func valueOrZero(v *int64) int64 {
if v == nil {
return 0
}
return *v
}
type UsageBillingApplyResult struct {
Applied bool
APIKeyQuotaExhausted bool
}
type UsageBillingRepository interface {
Apply(ctx context.Context, cmd *UsageBillingCommand) (*UsageBillingApplyResult, error)
}
...@@ -56,7 +56,8 @@ type cleanupRepoStub struct { ...@@ -56,7 +56,8 @@ type cleanupRepoStub struct {
} }
type dashboardRepoStub struct { type dashboardRepoStub struct {
recomputeErr error recomputeErr error
recomputeCalls int
} }
func (s *dashboardRepoStub) AggregateRange(ctx context.Context, start, end time.Time) error { func (s *dashboardRepoStub) AggregateRange(ctx context.Context, start, end time.Time) error {
...@@ -64,6 +65,7 @@ func (s *dashboardRepoStub) AggregateRange(ctx context.Context, start, end time. ...@@ -64,6 +65,7 @@ func (s *dashboardRepoStub) AggregateRange(ctx context.Context, start, end time.
} }
func (s *dashboardRepoStub) RecomputeRange(ctx context.Context, start, end time.Time) error { func (s *dashboardRepoStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
s.recomputeCalls++
return s.recomputeErr return s.recomputeErr
} }
...@@ -83,6 +85,10 @@ func (s *dashboardRepoStub) CleanupUsageLogs(ctx context.Context, cutoff time.Ti ...@@ -83,6 +85,10 @@ func (s *dashboardRepoStub) CleanupUsageLogs(ctx context.Context, cutoff time.Ti
return nil return nil
} }
func (s *dashboardRepoStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
return nil
}
func (s *dashboardRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error { func (s *dashboardRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
return nil return nil
} }
...@@ -550,13 +556,14 @@ func TestUsageCleanupServiceExecuteTaskMarkFailedUpdateError(t *testing.T) { ...@@ -550,13 +556,14 @@ func TestUsageCleanupServiceExecuteTaskMarkFailedUpdateError(t *testing.T) {
} }
func TestUsageCleanupServiceExecuteTaskDashboardRecomputeError(t *testing.T) { func TestUsageCleanupServiceExecuteTaskDashboardRecomputeError(t *testing.T) {
dashboardRepo := &dashboardRepoStub{recomputeErr: errors.New("recompute failed")}
repo := &cleanupRepoStub{ repo := &cleanupRepoStub{
deleteQueue: []cleanupDeleteResponse{ deleteQueue: []cleanupDeleteResponse{
{deleted: 0}, {deleted: 0},
}, },
} }
dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{ dashboard := NewDashboardAggregationService(dashboardRepo, nil, &config.Config{
DashboardAgg: config.DashboardAggregationConfig{Enabled: false}, DashboardAgg: config.DashboardAggregationConfig{Enabled: true},
}) })
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}} cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
svc := NewUsageCleanupService(repo, nil, dashboard, cfg) svc := NewUsageCleanupService(repo, nil, dashboard, cfg)
...@@ -573,15 +580,17 @@ func TestUsageCleanupServiceExecuteTaskDashboardRecomputeError(t *testing.T) { ...@@ -573,15 +580,17 @@ func TestUsageCleanupServiceExecuteTaskDashboardRecomputeError(t *testing.T) {
repo.mu.Lock() repo.mu.Lock()
defer repo.mu.Unlock() defer repo.mu.Unlock()
require.Len(t, repo.markSucceeded, 1) require.Len(t, repo.markSucceeded, 1)
require.Eventually(t, func() bool { return dashboardRepo.recomputeCalls == 1 }, time.Second, 10*time.Millisecond)
} }
func TestUsageCleanupServiceExecuteTaskDashboardRecomputeSuccess(t *testing.T) { func TestUsageCleanupServiceExecuteTaskDashboardRecomputeSuccess(t *testing.T) {
dashboardRepo := &dashboardRepoStub{}
repo := &cleanupRepoStub{ repo := &cleanupRepoStub{
deleteQueue: []cleanupDeleteResponse{ deleteQueue: []cleanupDeleteResponse{
{deleted: 0}, {deleted: 0},
}, },
} }
dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{ dashboard := NewDashboardAggregationService(dashboardRepo, nil, &config.Config{
DashboardAgg: config.DashboardAggregationConfig{Enabled: true}, DashboardAgg: config.DashboardAggregationConfig{Enabled: true},
}) })
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}} cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
...@@ -599,6 +608,7 @@ func TestUsageCleanupServiceExecuteTaskDashboardRecomputeSuccess(t *testing.T) { ...@@ -599,6 +608,7 @@ func TestUsageCleanupServiceExecuteTaskDashboardRecomputeSuccess(t *testing.T) {
repo.mu.Lock() repo.mu.Lock()
defer repo.mu.Unlock() defer repo.mu.Unlock()
require.Len(t, repo.markSucceeded, 1) require.Len(t, repo.markSucceeded, 1)
require.Eventually(t, func() bool { return dashboardRepo.recomputeCalls == 1 }, time.Second, 10*time.Millisecond)
} }
func TestUsageCleanupServiceExecuteTaskCanceled(t *testing.T) { func TestUsageCleanupServiceExecuteTaskCanceled(t *testing.T) {
......
package service
import "errors"
type usageLogCreateDisposition int
const (
usageLogCreateDispositionUnknown usageLogCreateDisposition = iota
usageLogCreateDispositionNotPersisted
)
type UsageLogCreateError struct {
err error
disposition usageLogCreateDisposition
}
func (e *UsageLogCreateError) Error() string {
if e == nil || e.err == nil {
return "usage log create error"
}
return e.err.Error()
}
func (e *UsageLogCreateError) Unwrap() error {
if e == nil {
return nil
}
return e.err
}
func MarkUsageLogCreateNotPersisted(err error) error {
if err == nil {
return nil
}
return &UsageLogCreateError{
err: err,
disposition: usageLogCreateDispositionNotPersisted,
}
}
func IsUsageLogCreateNotPersisted(err error) bool {
if err == nil {
return false
}
var target *UsageLogCreateError
if !errors.As(err, &target) {
return false
}
return target.disposition == usageLogCreateDispositionNotPersisted
}
func ShouldBillAfterUsageLogCreate(inserted bool, err error) bool {
if inserted {
return true
}
if err == nil {
return false
}
return !IsUsageLogCreateNotPersisted(err)
}
-- 窄表账务幂等键:将“是否已扣费”从 usage_logs 解耦出来
-- 幂等执行:可重复运行
CREATE TABLE IF NOT EXISTS usage_billing_dedup (
id BIGSERIAL PRIMARY KEY,
request_id VARCHAR(255) NOT NULL,
api_key_id BIGINT NOT NULL,
request_fingerprint VARCHAR(64) NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_usage_billing_dedup_request_api_key
ON usage_billing_dedup (request_id, api_key_id);
-- usage_billing_dedup 是按时间追加写入的幂等窄表。
-- 使用 BRIN 支撑按 created_at 的批量保留期清理,尽量降低写放大。
-- 使用 CONCURRENTLY 避免在热表上长时间阻塞写入。
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_usage_billing_dedup_created_at_brin
ON usage_billing_dedup
USING BRIN (created_at);
-- 冷归档旧账务幂等键,缩小热表索引与清理范围,同时不丢失长期去重能力。
CREATE TABLE IF NOT EXISTS usage_billing_dedup_archive (
request_id VARCHAR(255) NOT NULL,
api_key_id BIGINT NOT NULL,
request_fingerprint VARCHAR(64) NOT NULL,
created_at TIMESTAMPTZ NOT NULL,
archived_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
PRIMARY KEY (request_id, api_key_id)
);
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment