Commit 8d03c52e authored by erio's avatar erio
Browse files

feat(channel): 通配符定价匹配 + OpenAI BillingModelSource + 按次价格校验 + 用户端计费模式展示

- 定价查找支持通配符(suffix *),最长前缀优先匹配
- 模型限制(restrict_models)同样支持通配符匹配
- OpenAI 网关接入渠道映射/BillingModelSource/模型限制
- 按次/图片计费模式创建时强制要求价格或层级(前后端)
- 用户使用记录列表增加计费模式 badge 列
parent 0fbc9a44
...@@ -180,7 +180,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -180,7 +180,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService) modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver) gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver)
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI) openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
......
...@@ -276,11 +276,21 @@ func (h *ChannelHandler) Create(c *gin.Context) { ...@@ -276,11 +276,21 @@ func (h *ChannelHandler) Create(c *gin.Context) {
return return
} }
pricing := pricingRequestToService(req.ModelPricing)
for _, p := range pricing {
if p.BillingMode == service.BillingModePerRequest || p.BillingMode == service.BillingModeImage {
if p.PerRequestPrice == nil && len(p.Intervals) == 0 {
response.BadRequest(c, "Per-request price or intervals required for per_request/image billing mode")
return
}
}
}
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{ channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
Name: req.Name, Name: req.Name,
Description: req.Description, Description: req.Description,
GroupIDs: req.GroupIDs, GroupIDs: req.GroupIDs,
ModelPricing: pricingRequestToService(req.ModelPricing), ModelPricing: pricing,
ModelMapping: req.ModelMapping, ModelMapping: req.ModelMapping,
BillingModelSource: req.BillingModelSource, BillingModelSource: req.BillingModelSource,
RestrictModels: req.RestrictModels, RestrictModels: req.RestrictModels,
...@@ -319,6 +329,14 @@ func (h *ChannelHandler) Update(c *gin.Context) { ...@@ -319,6 +329,14 @@ func (h *ChannelHandler) Update(c *gin.Context) {
} }
if req.ModelPricing != nil { if req.ModelPricing != nil {
pricing := pricingRequestToService(*req.ModelPricing) pricing := pricingRequestToService(*req.ModelPricing)
for _, p := range pricing {
if p.BillingMode == service.BillingModePerRequest || p.BillingMode == service.BillingModeImage {
if p.PerRequestPrice == nil && len(p.Intervals) == 0 {
response.BadRequest(c, "Per-request price or intervals required for per_request/image billing mode")
return
}
}
}
input.ModelPricing = &pricing input.ModelPricing = &pricing
} }
......
...@@ -185,6 +185,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -185,6 +185,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body) setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射
var channelMapping service.ChannelMappingResult
if apiKey.GroupID != nil {
channelMapping = h.gatewayService.ResolveChannelMapping(c.Request.Context(), *apiKey.GroupID, reqModel)
}
// 渠道模型限制检查
if apiKey.GroupID != nil {
if h.gatewayService.IsModelRestricted(c.Request.Context(), *apiKey.GroupID, reqModel) {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
return
}
}
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。 // 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
if !h.validateFunctionCallOutputRequest(c, body, reqLog) { if !h.validateFunctionCallOutputRequest(c, body, reqLog) {
return return
...@@ -379,6 +393,21 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -379,6 +393,21 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelID: channelMapping.ChannelID,
OriginalModel: reqModel,
BillingModelSource: channelMapping.BillingModelSource,
ModelMappingChain: func() string {
if !channelMapping.Mapped {
if result.UpstreamModel != "" && result.UpstreamModel != result.Model {
return reqModel + "→" + result.UpstreamModel
}
return ""
}
if result.UpstreamModel != "" && result.UpstreamModel != channelMapping.MappedModel {
return reqModel + "→" + channelMapping.MappedModel + "→" + result.UpstreamModel
}
return reqModel + "→" + channelMapping.MappedModel
}(),
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
zap.String("component", "handler.openai_gateway.responses"), zap.String("component", "handler.openai_gateway.responses"),
...@@ -549,6 +578,20 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { ...@@ -549,6 +578,20 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body) setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射
var channelMappingMsg service.ChannelMappingResult
if apiKey.GroupID != nil {
channelMappingMsg = h.gatewayService.ResolveChannelMapping(c.Request.Context(), *apiKey.GroupID, reqModel)
}
// 渠道模型限制检查
if apiKey.GroupID != nil {
if h.gatewayService.IsModelRestricted(c.Request.Context(), *apiKey.GroupID, reqModel) {
h.anthropicErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
return
}
}
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。 // 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
if h.errorPassthroughService != nil { if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService) service.BindErrorPassthroughService(c, h.errorPassthroughService)
...@@ -759,6 +802,21 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { ...@@ -759,6 +802,21 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelID: channelMappingMsg.ChannelID,
OriginalModel: reqModel,
BillingModelSource: channelMappingMsg.BillingModelSource,
ModelMappingChain: func() string {
if !channelMappingMsg.Mapped {
if result.UpstreamModel != "" && result.UpstreamModel != result.Model {
return reqModel + "→" + result.UpstreamModel
}
return ""
}
if result.UpstreamModel != "" && result.UpstreamModel != channelMappingMsg.MappedModel {
return reqModel + "→" + channelMappingMsg.MappedModel + "→" + result.UpstreamModel
}
return reqModel + "→" + channelMappingMsg.MappedModel
}(),
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
zap.String("component", "handler.openai_gateway.messages"), zap.String("component", "handler.openai_gateway.messages"),
...@@ -1101,6 +1159,20 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { ...@@ -1101,6 +1159,20 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
setOpsRequestContext(c, reqModel, true, firstMessage) setOpsRequestContext(c, reqModel, true, firstMessage)
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2)) setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
// 解析渠道级模型映射
var channelMappingWS service.ChannelMappingResult
if apiKey.GroupID != nil {
channelMappingWS = h.gatewayService.ResolveChannelMapping(ctx, *apiKey.GroupID, reqModel)
}
// 渠道模型限制检查
if apiKey.GroupID != nil {
if h.gatewayService.IsModelRestricted(ctx, *apiKey.GroupID, reqModel) {
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "model not allowed")
return
}
}
var currentUserRelease func() var currentUserRelease func()
var currentAccountRelease func() var currentAccountRelease func()
releaseTurnSlots := func() { releaseTurnSlots := func() {
...@@ -1259,6 +1331,21 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { ...@@ -1259,6 +1331,21 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage), RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelID: channelMappingWS.ChannelID,
OriginalModel: reqModel,
BillingModelSource: channelMappingWS.BillingModelSource,
ModelMappingChain: func() string {
if !channelMappingWS.Mapped {
if result.UpstreamModel != "" && result.UpstreamModel != result.Model {
return reqModel + "→" + result.UpstreamModel
}
return ""
}
if result.UpstreamModel != "" && result.UpstreamModel != channelMappingWS.MappedModel {
return reqModel + "→" + channelMappingWS.MappedModel + "→" + result.UpstreamModel
}
return reqModel + "→" + channelMappingWS.MappedModel
}(),
}); err != nil { }); err != nil {
reqLog.Error("openai.websocket_record_usage_failed", reqLog.Error("openai.websocket_record_usage_failed",
zap.Int64("account_id", account.ID), zap.Int64("account_id", account.ID),
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"log/slog" "log/slog"
"sort"
"strings" "strings"
"sync/atomic" "sync/atomic"
"time" "time"
...@@ -57,13 +58,26 @@ type channelModelKey struct { ...@@ -57,13 +58,26 @@ type channelModelKey struct {
model string // lowercase model string // lowercase
} }
// channelGroupPlatformKey is the cache key for wildcard pricing lookups:
// one bucket of wildcard entries per (group, platform) pair.
type channelGroupPlatformKey struct {
	groupID  int64
	platform string
}

// wildcardPricingEntry is a single wildcard pricing rule: a model-name
// prefix (the configured pattern lowercased with the trailing "*"
// stripped) together with the pricing applied to any matching model.
type wildcardPricingEntry struct {
	prefix  string               // lowercase model prefix, "*" suffix removed
	pricing *ChannelModelPricing // pricing to use when the prefix matches
}
// channelCache 渠道缓存快照(扁平化哈希结构,热路径 O(1) 查找) // channelCache 渠道缓存快照(扁平化哈希结构,热路径 O(1) 查找)
type channelCache struct { type channelCache struct {
// 热路径查找 // 热路径查找
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价 pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价
mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标 wildcardByGroupPlatform map[channelGroupPlatformKey][]*wildcardPricingEntry // (groupID, platform) → 通配符定价(前缀长度降序)
channelByGroupID map[int64]*Channel // groupID → 渠道 mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标
groupPlatform map[int64]string // groupID → platform channelByGroupID map[int64]*Channel // groupID → 渠道
groupPlatform map[int64]string // groupID → platform
// 冷路径(CRUD 操作) // 冷路径(CRUD 操作)
byID map[int64]*Channel byID map[int64]*Channel
...@@ -137,12 +151,13 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) ...@@ -137,12 +151,13 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
// error-TTL:失败时存入短 TTL 空缓存,防止紧密重试 // error-TTL:失败时存入短 TTL 空缓存,防止紧密重试
slog.Warn("failed to build channel cache", "error", err) slog.Warn("failed to build channel cache", "error", err)
errorCache := &channelCache{ errorCache := &channelCache{
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing), pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
mappingByGroupModel: make(map[channelModelKey]string), wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
channelByGroupID: make(map[int64]*Channel), mappingByGroupModel: make(map[channelModelKey]string),
groupPlatform: make(map[int64]string), channelByGroupID: make(map[int64]*Channel),
byID: make(map[int64]*Channel), groupPlatform: make(map[int64]string),
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL byID: make(map[int64]*Channel),
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL
} }
s.cache.Store(errorCache) s.cache.Store(errorCache)
return nil, fmt.Errorf("list all channels: %w", err) return nil, fmt.Errorf("list all channels: %w", err)
...@@ -163,12 +178,13 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) ...@@ -163,12 +178,13 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
} }
cache := &channelCache{ cache := &channelCache{
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing), pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
mappingByGroupModel: make(map[channelModelKey]string), wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
channelByGroupID: make(map[int64]*Channel), mappingByGroupModel: make(map[channelModelKey]string),
groupPlatform: groupPlatforms, channelByGroupID: make(map[int64]*Channel),
byID: make(map[int64]*Channel, len(channels)), groupPlatform: groupPlatforms,
loadedAt: time.Now(), byID: make(map[int64]*Channel, len(channels)),
loadedAt: time.Now(),
} }
for i := range channels { for i := range channels {
...@@ -187,8 +203,18 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) ...@@ -187,8 +203,18 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
continue // 跳过非本平台的定价 continue // 跳过非本平台的定价
} }
for _, model := range pricing.Models { for _, model := range pricing.Models {
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(model)} if strings.HasSuffix(model, "*") {
cache.pricingByGroupModel[key] = pricing // 通配符模型 → 存入 wildcardByGroupPlatform
prefix := strings.ToLower(strings.TrimSuffix(model, "*"))
gpKey := channelGroupPlatformKey{groupID: gid, platform: platform}
cache.wildcardByGroupPlatform[gpKey] = append(cache.wildcardByGroupPlatform[gpKey], &wildcardPricingEntry{
prefix: prefix,
pricing: pricing,
})
} else {
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(model)}
cache.pricingByGroupModel[key] = pricing
}
} }
} }
...@@ -202,6 +228,14 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) ...@@ -202,6 +228,14 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
} }
} }
// 通配符条目按前缀长度降序排列(最长前缀优先匹配)
for gpKey, entries := range cache.wildcardByGroupPlatform {
sort.Slice(entries, func(i, j int) bool {
return len(entries[i].prefix) > len(entries[j].prefix)
})
cache.wildcardByGroupPlatform[gpKey] = entries
}
s.cache.Store(cache) s.cache.Store(cache)
return cache, nil return cache, nil
} }
...@@ -212,6 +246,18 @@ func (s *ChannelService) invalidateCache() { ...@@ -212,6 +246,18 @@ func (s *ChannelService) invalidateCache() {
s.cacheSF.Forget("channel_cache") s.cacheSF.Forget("channel_cache")
} }
// matchWildcard looks up a wildcard pricing entry for the given model.
// Entries for each (group, platform) bucket are pre-sorted by descending
// prefix length at cache build time, so the first prefix hit is the
// longest-prefix match. Returns nil when no wildcard rule matches.
func (c *channelCache) matchWildcard(groupID int64, platform, modelLower string) *ChannelModelPricing {
	key := channelGroupPlatformKey{groupID: groupID, platform: platform}
	for _, entry := range c.wildcardByGroupPlatform[key] {
		if strings.HasPrefix(modelLower, entry.prefix) {
			return entry.pricing
		}
	}
	return nil
}
// GetChannelForGroup 获取分组关联的渠道(热路径 O(1)) // GetChannelForGroup 获取分组关联的渠道(热路径 O(1))
func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) (*Channel, error) { func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) (*Channel, error) {
cache, err := s.loadCache(ctx) cache, err := s.loadCache(ctx)
...@@ -245,7 +291,11 @@ func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int ...@@ -245,7 +291,11 @@ func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int
key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)} key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)}
pricing, ok := cache.pricingByGroupModel[key] pricing, ok := cache.pricingByGroupModel[key]
if !ok { if !ok {
return nil // 精确查找失败,尝试通配符匹配
pricing = cache.matchWildcard(groupID, platform, strings.ToLower(model))
if pricing == nil {
return nil
}
} }
cp := pricing.Clone() cp := pricing.Clone()
...@@ -302,7 +352,14 @@ func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, m ...@@ -302,7 +352,14 @@ func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, m
platform := cache.groupPlatform[groupID] platform := cache.groupPlatform[groupID]
key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)} key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)}
_, exists := cache.pricingByGroupModel[key] _, exists := cache.pricingByGroupModel[key]
return !exists if exists {
return false
}
// 精确查找失败,尝试通配符匹配
if cache.matchWildcard(groupID, platform, strings.ToLower(model)) != nil {
return false
}
return true
} }
// --- CRUD --- // --- CRUD ---
......
...@@ -146,6 +146,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U ...@@ -146,6 +146,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
&DeferredService{}, &DeferredService{},
nil, nil,
nil, nil,
nil,
) )
svc.userGroupRateResolver = newUserGroupRateResolver( svc.userGroupRateResolver = newUserGroupRateResolver(
rateRepo, rateRepo,
......
...@@ -323,6 +323,7 @@ type OpenAIGatewayService struct { ...@@ -323,6 +323,7 @@ type OpenAIGatewayService struct {
toolCorrector *CodexToolCorrector toolCorrector *CodexToolCorrector
openaiWSResolver OpenAIWSProtocolResolver openaiWSResolver OpenAIWSProtocolResolver
resolver *ModelPricingResolver resolver *ModelPricingResolver
channelService *ChannelService
openaiWSPoolOnce sync.Once openaiWSPoolOnce sync.Once
openaiWSStateStoreOnce sync.Once openaiWSStateStoreOnce sync.Once
...@@ -359,6 +360,7 @@ func NewOpenAIGatewayService( ...@@ -359,6 +360,7 @@ func NewOpenAIGatewayService(
deferredService *DeferredService, deferredService *DeferredService,
openAITokenProvider *OpenAITokenProvider, openAITokenProvider *OpenAITokenProvider,
resolver *ModelPricingResolver, resolver *ModelPricingResolver,
channelService *ChannelService,
) *OpenAIGatewayService { ) *OpenAIGatewayService {
svc := &OpenAIGatewayService{ svc := &OpenAIGatewayService{
accountRepo: accountRepo, accountRepo: accountRepo,
...@@ -387,6 +389,7 @@ func NewOpenAIGatewayService( ...@@ -387,6 +389,7 @@ func NewOpenAIGatewayService(
toolCorrector: NewCodexToolCorrector(), toolCorrector: NewCodexToolCorrector(),
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
resolver: resolver, resolver: resolver,
channelService: channelService,
responseHeaderFilter: compileResponseHeaderFilter(cfg), responseHeaderFilter: compileResponseHeaderFilter(cfg),
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval), codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
} }
...@@ -394,6 +397,22 @@ func NewOpenAIGatewayService( ...@@ -394,6 +397,22 @@ func NewOpenAIGatewayService(
return svc return svc
} }
// ResolveChannelMapping resolves the channel-level model mapping by
// delegating to ChannelService. When no channel service is wired in,
// it returns an identity mapping (the requested model, unmapped).
func (s *OpenAIGatewayService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult {
	if svc := s.channelService; svc != nil {
		return svc.ResolveChannelMapping(ctx, groupID, model)
	}
	return ChannelMappingResult{MappedModel: model}
}
// IsModelRestricted reports whether the model is blocked by the channel's
// restriction list, delegating to ChannelService. Without a channel
// service, nothing is considered restricted.
func (s *OpenAIGatewayService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
	if svc := s.channelService; svc != nil {
		return svc.IsModelRestricted(ctx, groupID, model)
	}
	return false
}
func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle { func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle {
if s != nil && s.codexSnapshotThrottle != nil { if s != nil && s.codexSnapshotThrottle != nil {
return s.codexSnapshotThrottle return s.codexSnapshotThrottle
...@@ -4113,6 +4132,10 @@ type OpenAIRecordUsageInput struct { ...@@ -4113,6 +4132,10 @@ type OpenAIRecordUsageInput struct {
IPAddress string // 请求的客户端 IP 地址 IPAddress string // 请求的客户端 IP 地址
RequestPayloadHash string RequestPayloadHash string
APIKeyService APIKeyQuotaUpdater APIKeyService APIKeyQuotaUpdater
ChannelID int64
OriginalModel string
BillingModelSource string
ModelMappingChain string
} }
// RecordUsage records usage and deducts balance // RecordUsage records usage and deducts balance
...@@ -4158,6 +4181,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec ...@@ -4158,6 +4181,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
var cost *CostBreakdown var cost *CostBreakdown
var err error var err error
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
if result.BillingModel != "" {
billingModel = strings.TrimSpace(result.BillingModel)
}
if input.BillingModelSource == "requested" && input.OriginalModel != "" {
billingModel = input.OriginalModel
}
serviceTier := "" serviceTier := ""
if result.ServiceTier != nil { if result.ServiceTier != nil {
serviceTier = strings.TrimSpace(*result.ServiceTier) serviceTier = strings.TrimSpace(*result.ServiceTier)
...@@ -4223,6 +4252,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec ...@@ -4223,6 +4252,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
FirstTokenMs: result.FirstTokenMs, FirstTokenMs: result.FirstTokenMs,
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
// 设置渠道信息
usageLog.ChannelID = optionalInt64Ptr(input.ChannelID)
usageLog.ModelMappingChain = optionalTrimmedStringPtr(input.ModelMappingChain)
// 设置计费模式 // 设置计费模式
if cost != nil && cost.BillingMode != "" { if cost != nil && cost.BillingMode != "" {
billingMode := cost.BillingMode billingMode := cost.BillingMode
......
...@@ -616,6 +616,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) { ...@@ -616,6 +616,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil, nil,
nil, nil,
nil, nil,
nil,
) )
decision := svc.getOpenAIWSProtocolResolver().Resolve(nil) decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)
......
...@@ -1789,6 +1789,7 @@ export default { ...@@ -1789,6 +1789,7 @@ export default {
noTiersYet: 'No tiers yet. Click add to configure per-request pricing.', noTiersYet: 'No tiers yet. Click add to configure per-request pricing.',
noPricingRules: 'No pricing rules yet. Click "Add" to create one.', noPricingRules: 'No pricing rules yet. Click "Add" to create one.',
perRequestPrice: 'Price per Request', perRequestPrice: 'Price per Request',
perRequestPriceRequired: 'Per-request price or billing tiers required for per-request/image billing mode',
tierLabel: 'Tier', tierLabel: 'Tier',
resolution: 'Resolution', resolution: 'Resolution',
modelMapping: 'Model Mapping', modelMapping: 'Model Mapping',
......
...@@ -1869,6 +1869,7 @@ export default { ...@@ -1869,6 +1869,7 @@ export default {
noTiersYet: '暂无层级,点击添加配置按次计费价格', noTiersYet: '暂无层级,点击添加配置按次计费价格',
noPricingRules: '暂无定价规则,点击"添加"创建', noPricingRules: '暂无定价规则,点击"添加"创建',
perRequestPrice: '单次价格', perRequestPrice: '单次价格',
perRequestPriceRequired: '按次/图片计费模式必须设置默认价格或至少一个计费层级',
tierLabel: '层级', tierLabel: '层级',
resolution: '分辨率', resolution: '分辨率',
modelMapping: '模型映射', modelMapping: '模型映射',
......
...@@ -876,6 +876,19 @@ async function handleSubmit() { ...@@ -876,6 +876,19 @@ async function handleSubmit() {
return return
} }
// 校验 per_request/image 模式必须有价格
for (const section of form.platforms) {
for (const entry of section.model_pricing) {
if (entry.models.length === 0) continue
if ((entry.billing_mode === 'per_request' || entry.billing_mode === 'image') &&
(entry.per_request_price == null || entry.per_request_price === '') &&
(!entry.intervals || entry.intervals.length === 0)) {
appStore.showError(t('admin.channels.perRequestPriceRequired', '按次/图片计费模式必须设置默认价格或至少一个计费层级'))
return
}
}
}
const { group_ids, model_pricing, model_mapping } = formToAPI() const { group_ids, model_pricing, model_mapping } = formToAPI()
console.log('[handleSubmit] model_pricing to send:', JSON.stringify(model_pricing)) console.log('[handleSubmit] model_pricing to send:', JSON.stringify(model_pricing))
......
...@@ -181,6 +181,13 @@ ...@@ -181,6 +181,13 @@
</span> </span>
</template> </template>
<template #cell-billing_mode="{ row }">
<span class="inline-flex items-center rounded px-1.5 py-0.5 text-xs font-medium"
:class="getBillingModeBadgeClass(row.billing_mode)">
{{ getBillingModeLabel(row.billing_mode) }}
</span>
</template>
<template #cell-tokens="{ row }"> <template #cell-tokens="{ row }">
<!-- 图片生成请求 --> <!-- 图片生成请求 -->
<div v-if="row.image_count > 0" class="flex items-center gap-1.5"> <div v-if="row.image_count > 0" class="flex items-center gap-1.5">
...@@ -525,6 +532,7 @@ const columns = computed<Column[]>(() => [ ...@@ -525,6 +532,7 @@ const columns = computed<Column[]>(() => [
{ key: 'reasoning_effort', label: t('usage.reasoningEffort'), sortable: false }, { key: 'reasoning_effort', label: t('usage.reasoningEffort'), sortable: false },
{ key: 'endpoint', label: t('usage.endpoint'), sortable: false }, { key: 'endpoint', label: t('usage.endpoint'), sortable: false },
{ key: 'stream', label: t('usage.type'), sortable: false }, { key: 'stream', label: t('usage.type'), sortable: false },
{ key: 'billing_mode', label: t('admin.usage.billingMode'), sortable: false },
{ key: 'tokens', label: t('usage.tokens'), sortable: false }, { key: 'tokens', label: t('usage.tokens'), sortable: false },
{ key: 'cost', label: t('usage.cost'), sortable: false }, { key: 'cost', label: t('usage.cost'), sortable: false },
{ key: 'first_token', label: t('usage.firstToken'), sortable: false }, { key: 'first_token', label: t('usage.firstToken'), sortable: false },
...@@ -615,6 +623,18 @@ const getRequestTypeBadgeClass = (log: UsageLog): string => { ...@@ -615,6 +623,18 @@ const getRequestTypeBadgeClass = (log: UsageLog): string => {
return 'bg-amber-100 text-amber-800 dark:bg-amber-900 dark:text-amber-200' return 'bg-amber-100 text-amber-800 dark:bg-amber-900 dark:text-amber-200'
} }
// Human-readable i18n label for a usage row's billing mode.
// Any unknown/absent mode falls back to the token-billing label.
const getBillingModeLabel = (mode: string | null | undefined): string => {
  switch (mode) {
    case 'per_request':
      return t('admin.usage.billingModePerRequest')
    case 'image':
      return t('admin.usage.billingModeImage')
    default:
      return t('admin.usage.billingModeToken')
  }
}
// Tailwind badge classes for a usage row's billing mode badge.
// per_request → blue, image → green, anything else (token billing) → gray.
const getBillingModeBadgeClass = (mode: string | null | undefined): string => {
  switch (mode) {
    case 'per_request':
      return 'bg-blue-100 text-blue-800 dark:bg-blue-900/30 dark:text-blue-200'
    case 'image':
      return 'bg-green-100 text-green-800 dark:bg-green-900/30 dark:text-green-200'
    default:
      return 'bg-gray-100 text-gray-800 dark:bg-gray-700 dark:text-gray-300'
  }
}
const getRequestTypeExportText = (log: UsageLog): string => { const getRequestTypeExportText = (log: UsageLog): string => {
const requestType = resolveUsageRequestType(log) const requestType = resolveUsageRequestType(log)
if (requestType === 'ws_v2') return 'WS' if (requestType === 'ws_v2') return 'WS'
...@@ -804,6 +824,7 @@ const exportToCSV = async () => { ...@@ -804,6 +824,7 @@ const exportToCSV = async () => {
'Reasoning Effort', 'Reasoning Effort',
'Inbound Endpoint', 'Inbound Endpoint',
'Type', 'Type',
'Billing Mode',
'Input Tokens', 'Input Tokens',
'Output Tokens', 'Output Tokens',
'Cache Read Tokens', 'Cache Read Tokens',
...@@ -822,6 +843,7 @@ const exportToCSV = async () => { ...@@ -822,6 +843,7 @@ const exportToCSV = async () => {
formatReasoningEffort(log.reasoning_effort), formatReasoningEffort(log.reasoning_effort),
log.inbound_endpoint || '', log.inbound_endpoint || '',
getRequestTypeExportText(log), getRequestTypeExportText(log),
getBillingModeLabel(log.billing_mode),
log.input_tokens, log.input_tokens,
log.output_tokens, log.output_tokens,
log.cache_read_tokens, log.cache_read_tokens,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment