Commit 632035aa authored by erio's avatar erio
Browse files

feat(billing): 网关计费迁移到 CalculateCostUnified + 模型限制错误统一

- GatewayService/OpenAIGatewayService 注入 ModelPricingResolver
- RecordUsage 从旧路径迁移到 CalculateCostUnified(支持 per_request/image 模式)
- 无渠道时自动回退旧路径,保持原有行为
- 长上下文双倍计费仅在无渠道定价时生效
- CostBreakdown 新增 BillingMode 字段,使用日志记录实际计费模式
- 模型限制错误改为与"无可用账号"相同的 503 响应
parent a51e0047
...@@ -178,10 +178,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -178,10 +178,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
digestSessionStore := service.NewDigestSessionStore() digestSessionStore := service.NewDigestSessionStore()
channelService := service.NewChannelService(channelRepository, apiKeyAuthCacheInvalidator) channelService := service.NewChannelService(channelRepository, apiKeyAuthCacheInvalidator)
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService) modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
_ = modelPricingResolver // Phase 4: 已注册,后续 Gateway 迁移时使用 gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService)
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI) openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
......
...@@ -164,14 +164,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -164,14 +164,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
channelMapping = h.gatewayService.ResolveChannelMapping(c.Request.Context(), *apiKey.GroupID, reqModel) channelMapping = h.gatewayService.ResolveChannelMapping(c.Request.Context(), *apiKey.GroupID, reqModel)
} }
// 渠道模型限制检查 // 渠道模型限制检查:使用原始请求模型名,因为定价列表中注册的是用户请求的模型名
if apiKey.GroupID != nil { if apiKey.GroupID != nil {
checkModel := reqModel if h.gatewayService.IsModelRestricted(c.Request.Context(), *apiKey.GroupID, reqModel) {
if channelMapping.Mapped { h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
checkModel = channelMapping.MappedModel
}
if h.gatewayService.IsModelRestricted(c.Request.Context(), *apiKey.GroupID, checkModel) {
h.errorResponse(c, http.StatusForbidden, "invalid_request_error", "Model not available in current channel: "+reqModel)
return return
} }
} }
......
...@@ -162,6 +162,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi ...@@ -162,6 +162,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
nil, // settingService nil, // settingService
nil, // tlsFPProfileService nil, // tlsFPProfileService
nil, // channelService nil, // channelService
nil, // resolver
) )
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。 // RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
......
...@@ -2224,7 +2224,8 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac ...@@ -2224,7 +2224,8 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService { func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
return service.NewGatewayService( return service.NewGatewayService(
accountRepo, nil, nil, nil, nil, nil, nil, nil, nil, accountRepo, nil, nil, nil, nil, nil, nil, nil, nil,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
nil, nil,
) )
} }
......
...@@ -466,6 +466,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) { ...@@ -466,6 +466,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
nil, // settingService nil, // settingService
nil, // tlsFPProfileService nil, // tlsFPProfileService
nil, // channelService nil, // channelService
nil, // resolver
) )
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}} soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}
......
...@@ -104,6 +104,7 @@ type CostBreakdown struct { ...@@ -104,6 +104,7 @@ type CostBreakdown struct {
CacheReadCost float64 CacheReadCost float64
TotalCost float64 TotalCost float64
ActualCost float64 // 应用倍率后的实际费用 ActualCost float64 // 应用倍率后的实际费用
BillingMode string // 计费模式("token"/"per_request"/"image"),由 CalculateCostUnified 填充
} }
// BillingService 计费服务 // BillingService 计费服务
...@@ -439,12 +440,21 @@ func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown, ...@@ -439,12 +440,21 @@ func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown,
input.RateMultiplier = 1.0 input.RateMultiplier = 1.0
} }
var breakdown *CostBreakdown
var err error
switch resolved.Mode { switch resolved.Mode {
case BillingModePerRequest, BillingModeImage: case BillingModePerRequest, BillingModeImage:
return s.calculatePerRequestCost(resolved, input) breakdown, err = s.calculatePerRequestCost(resolved, input)
default: // BillingModeToken default: // BillingModeToken
return s.calculateTokenCost(resolved, input) breakdown, err = s.calculateTokenCost(resolved, input)
}
if err == nil && breakdown != nil {
breakdown.BillingMode = string(resolved.Mode)
if breakdown.BillingMode == "" {
breakdown.BillingMode = string(BillingModeToken)
}
} }
return breakdown, err
} }
// calculateTokenCost 按 token 区间计费 // calculateTokenCost 按 token 区间计费
......
...@@ -42,6 +42,7 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo ...@@ -42,6 +42,7 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
nil, nil,
nil, nil,
nil, nil,
nil,
) )
} }
......
...@@ -569,6 +569,7 @@ type GatewayService struct { ...@@ -569,6 +569,7 @@ type GatewayService struct {
debugModelRouting atomic.Bool debugModelRouting atomic.Bool
debugClaudeMimic atomic.Bool debugClaudeMimic atomic.Bool
channelService *ChannelService channelService *ChannelService
resolver *ModelPricingResolver
debugGatewayBodyFile atomic.Pointer[os.File] // non-nil when SUB2API_DEBUG_GATEWAY_BODY is set debugGatewayBodyFile atomic.Pointer[os.File] // non-nil when SUB2API_DEBUG_GATEWAY_BODY is set
tlsFPProfileService *TLSFingerprintProfileService tlsFPProfileService *TLSFingerprintProfileService
} }
...@@ -599,6 +600,7 @@ func NewGatewayService( ...@@ -599,6 +600,7 @@ func NewGatewayService(
settingService *SettingService, settingService *SettingService,
tlsFPProfileService *TLSFingerprintProfileService, tlsFPProfileService *TLSFingerprintProfileService,
channelService *ChannelService, channelService *ChannelService,
resolver *ModelPricingResolver,
) *GatewayService { ) *GatewayService {
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg) userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
modelsListTTL := resolveModelsListCacheTTL(cfg) modelsListTTL := resolveModelsListCacheTTL(cfg)
...@@ -632,6 +634,7 @@ func NewGatewayService( ...@@ -632,6 +634,7 @@ func NewGatewayService(
responseHeaderFilter: compileResponseHeaderFilter(cfg), responseHeaderFilter: compileResponseHeaderFilter(cfg),
tlsFPProfileService: tlsFPProfileService, tlsFPProfileService: tlsFPProfileService,
channelService: channelService, channelService: channelService,
resolver: resolver,
} }
svc.userGroupRateResolver = newUserGroupRateResolver( svc.userGroupRateResolver = newUserGroupRateResolver(
userGroupRateRepo, userGroupRateRepo,
...@@ -7790,13 +7793,21 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -7790,13 +7793,21 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
} }
var err error var err error
// 渠道定价覆盖 if s.resolver != nil && apiKey.Group != nil {
var chPricing *ChannelModelPricing var groupID *int64
if s.channelService != nil && apiKey.Group != nil { if apiKey.Group != nil {
chPricing = s.channelService.GetChannelModelPricing(ctx, apiKey.Group.ID, billingModel) gid := apiKey.Group.ID
} groupID = &gid
if chPricing != nil { }
cost, err = s.billingService.CalculateCostWithChannel(billingModel, tokens, multiplier, chPricing) cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
Model: billingModel,
GroupID: groupID,
Tokens: tokens,
RequestCount: 1,
RateMultiplier: multiplier,
Resolver: s.resolver,
})
} else { } else {
cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier) cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
} }
...@@ -7868,6 +7879,9 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -7868,6 +7879,9 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
if result.ImageCount > 0 { if result.ImageCount > 0 {
billingMode := "image" billingMode := "image"
usageLog.BillingMode = &billingMode usageLog.BillingMode = &billingMode
} else if cost != nil && cost.BillingMode != "" {
billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode
} else { } else {
billingMode := "token" billingMode := "token"
usageLog.BillingMode = &billingMode usageLog.BillingMode = &billingMode
...@@ -8016,14 +8030,30 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -8016,14 +8030,30 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
} }
var err error var err error
// 渠道定价覆盖 // 优先尝试 Resolver + CalculateCostUnified(仅在有渠道定价时使用)
var chPricing2 *ChannelModelPricing useUnified := false
if s.channelService != nil && apiKey.Group != nil { if s.resolver != nil && apiKey.Group != nil {
chPricing2 = s.channelService.GetChannelModelPricing(ctx, apiKey.Group.ID, billingModel) gid := apiKey.Group.ID
resolved := s.resolver.Resolve(ctx, PricingInput{
Model: billingModel,
GroupID: &gid,
})
if resolved.Source == "channel" {
// 有渠道定价,渠道区间已包含上下文分层
cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
Model: billingModel,
GroupID: &gid,
Tokens: tokens,
RequestCount: 1,
RateMultiplier: multiplier,
Resolver: s.resolver,
})
useUnified = true
}
} }
if chPricing2 != nil { if !useUnified {
cost, err = s.billingService.CalculateCostWithChannel(billingModel, tokens, multiplier, chPricing2) // 无渠道定价,保持原有长上下文双倍计费逻辑(如 Gemini 200K 阈值)
} else {
cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier) cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
} }
if err != nil { if err != nil {
...@@ -8088,6 +8118,9 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * ...@@ -8088,6 +8118,9 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
if result.ImageCount > 0 { if result.ImageCount > 0 {
billingMode := "image" billingMode := "image"
usageLog.BillingMode = &billingMode usageLog.BillingMode = &billingMode
} else if cost != nil && cost.BillingMode != "" {
billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode
} else { } else {
billingMode := "token" billingMode := "token"
usageLog.BillingMode = &billingMode usageLog.BillingMode = &billingMode
......
...@@ -145,6 +145,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U ...@@ -145,6 +145,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
nil, nil,
&DeferredService{}, &DeferredService{},
nil, nil,
nil,
) )
svc.userGroupRateResolver = newUserGroupRateResolver( svc.userGroupRateResolver = newUserGroupRateResolver(
rateRepo, rateRepo,
......
...@@ -322,6 +322,7 @@ type OpenAIGatewayService struct { ...@@ -322,6 +322,7 @@ type OpenAIGatewayService struct {
openAITokenProvider *OpenAITokenProvider openAITokenProvider *OpenAITokenProvider
toolCorrector *CodexToolCorrector toolCorrector *CodexToolCorrector
openaiWSResolver OpenAIWSProtocolResolver openaiWSResolver OpenAIWSProtocolResolver
resolver *ModelPricingResolver
openaiWSPoolOnce sync.Once openaiWSPoolOnce sync.Once
openaiWSStateStoreOnce sync.Once openaiWSStateStoreOnce sync.Once
...@@ -357,6 +358,7 @@ func NewOpenAIGatewayService( ...@@ -357,6 +358,7 @@ func NewOpenAIGatewayService(
httpUpstream HTTPUpstream, httpUpstream HTTPUpstream,
deferredService *DeferredService, deferredService *DeferredService,
openAITokenProvider *OpenAITokenProvider, openAITokenProvider *OpenAITokenProvider,
resolver *ModelPricingResolver,
) *OpenAIGatewayService { ) *OpenAIGatewayService {
svc := &OpenAIGatewayService{ svc := &OpenAIGatewayService{
accountRepo: accountRepo, accountRepo: accountRepo,
...@@ -384,6 +386,7 @@ func NewOpenAIGatewayService( ...@@ -384,6 +386,7 @@ func NewOpenAIGatewayService(
openAITokenProvider: openAITokenProvider, openAITokenProvider: openAITokenProvider,
toolCorrector: NewCodexToolCorrector(), toolCorrector: NewCodexToolCorrector(),
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
resolver: resolver,
responseHeaderFilter: compileResponseHeaderFilter(cfg), responseHeaderFilter: compileResponseHeaderFilter(cfg),
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval), codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
} }
...@@ -4152,12 +4155,28 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec ...@@ -4152,12 +4155,28 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier) multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier)
} }
var cost *CostBreakdown
var err error
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
serviceTier := "" serviceTier := ""
if result.ServiceTier != nil { if result.ServiceTier != nil {
serviceTier = strings.TrimSpace(*result.ServiceTier) serviceTier = strings.TrimSpace(*result.ServiceTier)
} }
cost, err := s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier) if s.resolver != nil && apiKey.Group != nil {
gid := apiKey.Group.ID
cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
Model: billingModel,
GroupID: &gid,
Tokens: tokens,
RequestCount: 1,
RateMultiplier: multiplier,
ServiceTier: serviceTier,
Resolver: s.resolver,
})
} else {
cost, err = s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
}
if err != nil { if err != nil {
cost = &CostBreakdown{ActualCost: 0} cost = &CostBreakdown{ActualCost: 0}
} }
...@@ -4204,8 +4223,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec ...@@ -4204,8 +4223,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
FirstTokenMs: result.FirstTokenMs, FirstTokenMs: result.FirstTokenMs,
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
// 设置计费模式(OpenAI 网关都是 token 计费) // 设置计费模式
{ if cost != nil && cost.BillingMode != "" {
billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode
} else {
billingMode := "token" billingMode := "token"
usageLog.BillingMode = &billingMode usageLog.BillingMode = &billingMode
} }
......
...@@ -615,6 +615,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) { ...@@ -615,6 +615,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil, nil,
nil, nil,
nil, nil,
nil,
) )
decision := svc.getOpenAIWSProtocolResolver().Resolve(nil) decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment