Commit 632035aa authored by erio's avatar erio
Browse files

feat(billing): 网关计费迁移到 CalculateCostUnified + 模型限制错误统一

- GatewayService/OpenAIGatewayService 注入 ModelPricingResolver
- RecordUsage 从旧路径迁移到 CalculateCostUnified(支持 per_request/image 模式)
- 无渠道时自动回退旧路径,保持原有行为
- 长上下文双倍计费仅在无渠道定价时生效
- CostBreakdown 新增 BillingMode 字段,使用日志记录实际计费模式
- 模型限制错误改为与"无可用账号"相同的 503 响应
parent a51e0047
......@@ -178,10 +178,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
digestSessionStore := service.NewDigestSessionStore()
channelService := service.NewChannelService(channelRepository, apiKeyAuthCacheInvalidator)
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
_ = modelPricingResolver // Phase 4: 已注册,后续 Gateway 迁移时使用
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver)
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
......
......@@ -164,14 +164,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
channelMapping = h.gatewayService.ResolveChannelMapping(c.Request.Context(), *apiKey.GroupID, reqModel)
}
// 渠道模型限制检查
// 渠道模型限制检查:使用原始请求模型名,因为定价列表中注册的是用户请求的模型名
if apiKey.GroupID != nil {
checkModel := reqModel
if channelMapping.Mapped {
checkModel = channelMapping.MappedModel
}
if h.gatewayService.IsModelRestricted(c.Request.Context(), *apiKey.GroupID, checkModel) {
h.errorResponse(c, http.StatusForbidden, "invalid_request_error", "Model not available in current channel: "+reqModel)
if h.gatewayService.IsModelRestricted(c.Request.Context(), *apiKey.GroupID, reqModel) {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
return
}
}
......
......@@ -162,6 +162,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
nil, // settingService
nil, // tlsFPProfileService
nil, // channelService
nil, // resolver
)
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
......
......@@ -2224,7 +2224,8 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
return service.NewGatewayService(
accountRepo, nil, nil, nil, nil, nil, nil, nil, nil,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
nil, nil,
)
}
......
......@@ -466,6 +466,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
nil, // settingService
nil, // tlsFPProfileService
nil, // channelService
nil, // resolver
)
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}
......
......@@ -104,6 +104,7 @@ type CostBreakdown struct {
CacheReadCost float64
TotalCost float64
ActualCost float64 // 应用倍率后的实际费用
BillingMode string // 计费模式("token"/"per_request"/"image"),由 CalculateCostUnified 填充
}
// BillingService 计费服务
......@@ -439,12 +440,21 @@ func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown,
input.RateMultiplier = 1.0
}
var breakdown *CostBreakdown
var err error
switch resolved.Mode {
case BillingModePerRequest, BillingModeImage:
return s.calculatePerRequestCost(resolved, input)
breakdown, err = s.calculatePerRequestCost(resolved, input)
default: // BillingModeToken
return s.calculateTokenCost(resolved, input)
breakdown, err = s.calculateTokenCost(resolved, input)
}
if err == nil && breakdown != nil {
breakdown.BillingMode = string(resolved.Mode)
if breakdown.BillingMode == "" {
breakdown.BillingMode = string(BillingModeToken)
}
}
return breakdown, err
}
// calculateTokenCost 按 token 区间计费
......
......@@ -42,6 +42,7 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
nil,
nil,
nil,
nil,
)
}
......
......@@ -569,6 +569,7 @@ type GatewayService struct {
debugModelRouting atomic.Bool
debugClaudeMimic atomic.Bool
channelService *ChannelService
resolver *ModelPricingResolver
debugGatewayBodyFile atomic.Pointer[os.File] // non-nil when SUB2API_DEBUG_GATEWAY_BODY is set
tlsFPProfileService *TLSFingerprintProfileService
}
......@@ -599,6 +600,7 @@ func NewGatewayService(
settingService *SettingService,
tlsFPProfileService *TLSFingerprintProfileService,
channelService *ChannelService,
resolver *ModelPricingResolver,
) *GatewayService {
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
modelsListTTL := resolveModelsListCacheTTL(cfg)
......@@ -632,6 +634,7 @@ func NewGatewayService(
responseHeaderFilter: compileResponseHeaderFilter(cfg),
tlsFPProfileService: tlsFPProfileService,
channelService: channelService,
resolver: resolver,
}
svc.userGroupRateResolver = newUserGroupRateResolver(
userGroupRateRepo,
......@@ -7790,13 +7793,21 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
}
var err error
// 渠道定价覆盖
var chPricing *ChannelModelPricing
if s.channelService != nil && apiKey.Group != nil {
chPricing = s.channelService.GetChannelModelPricing(ctx, apiKey.Group.ID, billingModel)
}
if chPricing != nil {
cost, err = s.billingService.CalculateCostWithChannel(billingModel, tokens, multiplier, chPricing)
if s.resolver != nil && apiKey.Group != nil {
var groupID *int64
if apiKey.Group != nil {
gid := apiKey.Group.ID
groupID = &gid
}
cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
Model: billingModel,
GroupID: groupID,
Tokens: tokens,
RequestCount: 1,
RateMultiplier: multiplier,
Resolver: s.resolver,
})
} else {
cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier)
}
......@@ -7868,6 +7879,9 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
if result.ImageCount > 0 {
billingMode := "image"
usageLog.BillingMode = &billingMode
} else if cost != nil && cost.BillingMode != "" {
billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode
} else {
billingMode := "token"
usageLog.BillingMode = &billingMode
......@@ -8016,14 +8030,30 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
}
var err error
// 渠道定价覆盖
var chPricing2 *ChannelModelPricing
if s.channelService != nil && apiKey.Group != nil {
chPricing2 = s.channelService.GetChannelModelPricing(ctx, apiKey.Group.ID, billingModel)
// 优先尝试 Resolver + CalculateCostUnified(仅在有渠道定价时使用)
useUnified := false
if s.resolver != nil && apiKey.Group != nil {
gid := apiKey.Group.ID
resolved := s.resolver.Resolve(ctx, PricingInput{
Model: billingModel,
GroupID: &gid,
})
if resolved.Source == "channel" {
// 有渠道定价,渠道区间已包含上下文分层
cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
Model: billingModel,
GroupID: &gid,
Tokens: tokens,
RequestCount: 1,
RateMultiplier: multiplier,
Resolver: s.resolver,
})
useUnified = true
}
}
if chPricing2 != nil {
cost, err = s.billingService.CalculateCostWithChannel(billingModel, tokens, multiplier, chPricing2)
} else {
if !useUnified {
// 无渠道定价,保持原有长上下文双倍计费逻辑(如 Gemini 200K 阈值)
cost, err = s.billingService.CalculateCostWithLongContext(billingModel, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
}
if err != nil {
......@@ -8088,6 +8118,9 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
if result.ImageCount > 0 {
billingMode := "image"
usageLog.BillingMode = &billingMode
} else if cost != nil && cost.BillingMode != "" {
billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode
} else {
billingMode := "token"
usageLog.BillingMode = &billingMode
......
......@@ -145,6 +145,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
nil,
&DeferredService{},
nil,
nil,
)
svc.userGroupRateResolver = newUserGroupRateResolver(
rateRepo,
......
......@@ -322,6 +322,7 @@ type OpenAIGatewayService struct {
openAITokenProvider *OpenAITokenProvider
toolCorrector *CodexToolCorrector
openaiWSResolver OpenAIWSProtocolResolver
resolver *ModelPricingResolver
openaiWSPoolOnce sync.Once
openaiWSStateStoreOnce sync.Once
......@@ -357,6 +358,7 @@ func NewOpenAIGatewayService(
httpUpstream HTTPUpstream,
deferredService *DeferredService,
openAITokenProvider *OpenAITokenProvider,
resolver *ModelPricingResolver,
) *OpenAIGatewayService {
svc := &OpenAIGatewayService{
accountRepo: accountRepo,
......@@ -384,6 +386,7 @@ func NewOpenAIGatewayService(
openAITokenProvider: openAITokenProvider,
toolCorrector: NewCodexToolCorrector(),
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
resolver: resolver,
responseHeaderFilter: compileResponseHeaderFilter(cfg),
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
}
......@@ -4152,12 +4155,28 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier)
}
var cost *CostBreakdown
var err error
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
serviceTier := ""
if result.ServiceTier != nil {
serviceTier = strings.TrimSpace(*result.ServiceTier)
}
cost, err := s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
if s.resolver != nil && apiKey.Group != nil {
gid := apiKey.Group.ID
cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
Model: billingModel,
GroupID: &gid,
Tokens: tokens,
RequestCount: 1,
RateMultiplier: multiplier,
ServiceTier: serviceTier,
Resolver: s.resolver,
})
} else {
cost, err = s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier)
}
if err != nil {
cost = &CostBreakdown{ActualCost: 0}
}
......@@ -4204,8 +4223,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
FirstTokenMs: result.FirstTokenMs,
CreatedAt: time.Now(),
}
// 设置计费模式(OpenAI 网关都是 token 计费)
{
// 设置计费模式
if cost != nil && cost.BillingMode != "" {
billingMode := cost.BillingMode
usageLog.BillingMode = &billingMode
} else {
billingMode := "token"
usageLog.BillingMode = &billingMode
}
......
......@@ -615,6 +615,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil,
nil,
nil,
nil,
)
decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment