Commit eb385457 authored by erio's avatar erio
Browse files

fix(channel): 全平台渠道映射覆盖 + 公共函数抽取 + 死代码清理

- 4个缺失handler入口添加渠道映射+限制检查(ChatCompletions/Responses/Gemini)
- 模型限制错误信息优化,区分"模型不可用"和"无账号"
- OpenAI RecordUsage RequestedModel 改用 OriginalModel
- ResolveChannelMappingAndRestrict/ReplaceModelInBody 抽取到 ChannelService 消除跨service重复
- validateNoDuplicateModels 按 platform:model 去重
- 删除 Channel.ResolveMappedModel 死代码和 CalculateCostWithChannel Deprecated方法
- 移除冗余nil检查,抽取 validatePricingBillingMode 公共校验
parent 4ea8b4cb
package admin package admin
import ( import (
"errors"
"strconv" "strconv"
"strings" "strings"
...@@ -224,6 +225,18 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe ...@@ -224,6 +225,18 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
return result return result
} }
// validatePricingBillingMode 校验按次/图片计费模式必须配置 PerRequestPrice 或 Intervals
func validatePricingBillingMode(pricing []service.ChannelModelPricing) error {
for _, p := range pricing {
if p.BillingMode == service.BillingModePerRequest || p.BillingMode == service.BillingModeImage {
if p.PerRequestPrice == nil && len(p.Intervals) == 0 {
return errors.New("Per-request price or intervals required for per_request/image billing mode")
}
}
}
return nil
}
// --- Handlers --- // --- Handlers ---
// List handles listing channels with pagination // List handles listing channels with pagination
...@@ -277,14 +290,10 @@ func (h *ChannelHandler) Create(c *gin.Context) { ...@@ -277,14 +290,10 @@ func (h *ChannelHandler) Create(c *gin.Context) {
} }
pricing := pricingRequestToService(req.ModelPricing) pricing := pricingRequestToService(req.ModelPricing)
for _, p := range pricing { if err := validatePricingBillingMode(pricing); err != nil {
if p.BillingMode == service.BillingModePerRequest || p.BillingMode == service.BillingModeImage { response.BadRequest(c, err.Error())
if p.PerRequestPrice == nil && len(p.Intervals) == 0 {
response.BadRequest(c, "Per-request price or intervals required for per_request/image billing mode")
return return
} }
}
}
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{ channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
Name: req.Name, Name: req.Name,
...@@ -329,14 +338,10 @@ func (h *ChannelHandler) Update(c *gin.Context) { ...@@ -329,14 +338,10 @@ func (h *ChannelHandler) Update(c *gin.Context) {
} }
if req.ModelPricing != nil { if req.ModelPricing != nil {
pricing := pricingRequestToService(*req.ModelPricing) pricing := pricingRequestToService(*req.ModelPricing)
for _, p := range pricing { if err := validatePricingBillingMode(pricing); err != nil {
if p.BillingMode == service.BillingModePerRequest || p.BillingMode == service.BillingModeImage { response.BadRequest(c, err.Error())
if p.PerRequestPrice == nil && len(p.Intervals) == 0 {
response.BadRequest(c, "Per-request price or intervals required for per_request/image billing mode")
return return
} }
}
}
input.ModelPricing = &pricing input.ModelPricing = &pricing
} }
......
...@@ -161,7 +161,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -161,7 +161,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 解析渠道级模型映射 + 限制检查 // 解析渠道级模型映射 + 限制检查
channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
if restricted { if restricted {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts") h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key")
return return
} }
......
...@@ -80,6 +80,13 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -80,6 +80,13 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body) setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射 + 限制检查
channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
if restricted {
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key")
return
}
// Claude Code only restriction // Claude Code only restriction
if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly { if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly {
h.chatCompletionsErrorResponse(c, http.StatusForbidden, "permission_error", h.chatCompletionsErrorResponse(c, http.StatusForbidden, "permission_error",
...@@ -203,7 +210,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -203,7 +210,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
// 5. Forward request // 5. Forward request
writerSizeBeforeForward := c.Writer.Size() writerSizeBeforeForward := c.Writer.Size()
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, parsedReq) forwardBody := body
if channelMapping.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq)
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
...@@ -255,6 +266,10 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -255,6 +266,10 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelID: channelMapping.ChannelID,
OriginalModel: reqModel,
BillingModelSource: channelMapping.BillingModelSource,
ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel),
}); err != nil { }); err != nil {
reqLog.Error("gateway.cc.record_usage_failed", reqLog.Error("gateway.cc.record_usage_failed",
zap.Int64("account_id", account.ID), zap.Int64("account_id", account.ID),
......
...@@ -80,6 +80,13 @@ func (h *GatewayHandler) Responses(c *gin.Context) { ...@@ -80,6 +80,13 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body) setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射 + 限制检查
channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
if restricted {
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key")
return
}
// Claude Code only restriction: // Claude Code only restriction:
// /v1/responses is never a Claude Code endpoint. // /v1/responses is never a Claude Code endpoint.
// When claude_code_only is enabled, this endpoint is rejected. // When claude_code_only is enabled, this endpoint is rejected.
...@@ -208,7 +215,11 @@ func (h *GatewayHandler) Responses(c *gin.Context) { ...@@ -208,7 +215,11 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
// 5. Forward request // 5. Forward request
writerSizeBeforeForward := c.Writer.Size() writerSizeBeforeForward := c.Writer.Size()
result, err := h.gatewayService.ForwardAsResponses(c.Request.Context(), c, account, body, parsedReq) forwardBody := body
if channelMapping.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
result, err := h.gatewayService.ForwardAsResponses(c.Request.Context(), c, account, forwardBody, parsedReq)
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
...@@ -261,6 +272,10 @@ func (h *GatewayHandler) Responses(c *gin.Context) { ...@@ -261,6 +272,10 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelID: channelMapping.ChannelID,
OriginalModel: reqModel,
BillingModelSource: channelMapping.BillingModelSource,
ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel),
}); err != nil { }); err != nil {
reqLog.Error("gateway.responses.record_usage_failed", reqLog.Error("gateway.responses.record_usage_failed",
zap.Int64("account_id", account.ID), zap.Int64("account_id", account.ID),
......
...@@ -184,6 +184,17 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -184,6 +184,17 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
setOpsRequestContext(c, modelName, stream, body) setOpsRequestContext(c, modelName, stream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false))) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
// 解析渠道级模型映射 + 限制检查
channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName)
if restricted {
googleError(c, http.StatusServiceUnavailable, "The requested model is not available for this API key")
return
}
reqModel := modelName // 保存映射前的原始模型名
if channelMapping.Mapped {
modelName = channelMapping.MappedModel
}
// Get subscription (may be nil) // Get subscription (may be nil)
subscription, _ := middleware.GetSubscriptionFromContext(c) subscription, _ := middleware.GetSubscriptionFromContext(c)
...@@ -523,6 +534,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -523,6 +534,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
LongContextMultiplier: 2.0, // 超出部分双倍计费 LongContextMultiplier: 2.0, // 超出部分双倍计费
ForceCacheBilling: fs.ForceCacheBilling, ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelID: channelMapping.ChannelID,
OriginalModel: reqModel,
BillingModelSource: channelMapping.BillingModelSource,
ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel),
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
zap.String("component", "handler.gemini_v1beta.models"), zap.String("component", "handler.gemini_v1beta.models"),
......
...@@ -79,6 +79,13 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -79,6 +79,13 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body) setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射 + 限制检查
channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
if restricted {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key")
return
}
if h.errorPassthroughService != nil { if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService) service.BindErrorPassthroughService(c, h.errorPassthroughService)
} }
...@@ -183,7 +190,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -183,7 +190,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
forwardStart := time.Now() forwardStart := time.Now()
defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_chat_completions_fallback_model")) defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_chat_completions_fallback_model"))
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel) forwardBody := body
if channelMapping.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel)
forwardDurationMs := time.Since(forwardStart).Milliseconds() forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
...@@ -267,6 +278,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -267,6 +278,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
UserAgent: userAgent, UserAgent: userAgent,
IPAddress: clientIP, IPAddress: clientIP,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelID: channelMapping.ChannelID,
OriginalModel: reqModel,
BillingModelSource: channelMapping.BillingModelSource,
ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel),
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
zap.String("component", "handler.openai_gateway.chat_completions"), zap.String("component", "handler.openai_gateway.chat_completions"),
......
...@@ -188,7 +188,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -188,7 +188,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// 解析渠道级模型映射 + 限制检查 // 解析渠道级模型映射 + 限制检查
channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
if restricted { if restricted {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts") h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key")
return return
} }
...@@ -568,7 +568,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { ...@@ -568,7 +568,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
// 解析渠道级模型映射 + 限制检查 // 解析渠道级模型映射 + 限制检查
channelMappingMsg, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) channelMappingMsg, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
if restricted { if restricted {
h.anthropicErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts") h.anthropicErrorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key")
return return
} }
......
...@@ -402,12 +402,6 @@ func (s *BillingService) GetModelPricingWithChannel(model string, channelPricing ...@@ -402,12 +402,6 @@ func (s *BillingService) GetModelPricingWithChannel(model string, channelPricing
return pricing, nil return pricing, nil
} }
// CalculateCostWithChannel 使用渠道定价计算费用
// Deprecated: 使用 CalculateCostUnified 代替
func (s *BillingService) CalculateCostWithChannel(model string, tokens UsageTokens, rateMultiplier float64, channelPricing *ChannelModelPricing) (*CostBreakdown, error) {
return s.calculateCostInternal(model, tokens, rateMultiplier, "", channelPricing)
}
// --- 统一计费入口 --- // --- 统一计费入口 ---
// CostInput 统一计费输入 // CostInput 统一计费输入
......
...@@ -82,38 +82,6 @@ type PricingInterval struct { ...@@ -82,38 +82,6 @@ type PricingInterval struct {
UpdatedAt time.Time UpdatedAt time.Time
} }
// ResolveMappedModel 解析渠道级模型映射,返回映射后的模型名。
// platform 指定查找哪个平台的映射规则。
// 支持通配符(如 "claude-*" → "claude-sonnet-4")。
// 如果没有匹配的映射规则,返回原始模型名。
func (c *Channel) ResolveMappedModel(platform, requestedModel string) string {
if len(c.ModelMapping) == 0 {
return requestedModel
}
platformMapping, ok := c.ModelMapping[platform]
if !ok || len(platformMapping) == 0 {
return requestedModel
}
lower := strings.ToLower(requestedModel)
// 精确匹配优先
for src, dst := range platformMapping {
if strings.ToLower(src) == lower {
return dst
}
}
// 通配符匹配
for src, dst := range platformMapping {
srcLower := strings.ToLower(src)
if strings.HasSuffix(srcLower, "*") {
prefix := strings.TrimSuffix(srcLower, "*")
if strings.HasPrefix(lower, prefix) {
return dst
}
}
}
return requestedModel
}
// IsActive 判断渠道是否启用 // IsActive 判断渠道是否启用
func (c *Channel) IsActive() bool { func (c *Channel) IsActive() bool {
return c.Status == StatusActive return c.Status == StatusActive
......
...@@ -11,6 +11,8 @@ import ( ...@@ -11,6 +11,8 @@ import (
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"golang.org/x/sync/singleflight" "golang.org/x/sync/singleflight"
) )
...@@ -379,6 +381,34 @@ func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, m ...@@ -379,6 +381,34 @@ func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, m
return true return true
} }
// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制(组合方法)。
// 返回映射结果和是否被限制。groupID 为 nil 时跳过。
func (s *ChannelService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
var mapping ChannelMappingResult
mapping.MappedModel = model
if groupID == nil {
return mapping, false
}
mapping = s.ResolveChannelMapping(ctx, *groupID, model)
restricted := s.IsModelRestricted(ctx, *groupID, mapping.MappedModel)
return mapping, restricted
}
// ReplaceModelInBody 替换请求体 JSON 中的 model 字段。
func ReplaceModelInBody(body []byte, newModel string) []byte {
if len(body) == 0 {
return body
}
if current := gjson.GetBytes(body, "model"); current.Exists() && current.String() == newModel {
return body
}
newBody, err := sjson.SetBytes(body, "model", newModel)
if err != nil {
return body
}
return newBody
}
// --- CRUD --- // --- CRUD ---
// Create 创建渠道 // Create 创建渠道
...@@ -539,16 +569,16 @@ func (s *ChannelService) List(ctx context.Context, params pagination.PaginationP ...@@ -539,16 +569,16 @@ func (s *ChannelService) List(ctx context.Context, params pagination.PaginationP
return s.repo.List(ctx, params, status, search) return s.repo.List(ctx, params, status, search)
} }
// validateNoDuplicateModels 检查定价列表中是否有重复模型 // validateNoDuplicateModels 检查定价列表中是否有重复模型(同一平台下不允许重复)
func validateNoDuplicateModels(pricingList []ChannelModelPricing) error { func validateNoDuplicateModels(pricingList []ChannelModelPricing) error {
seen := make(map[string]bool) seen := make(map[string]bool)
for _, p := range pricingList { for _, p := range pricingList {
for _, model := range p.Models { for _, model := range p.Models {
lower := strings.ToLower(model) key := p.Platform + ":" + strings.ToLower(model)
if seen[lower] { if seen[key] {
return infraerrors.BadRequest("DUPLICATE_MODEL", fmt.Sprintf("model '%s' appears in multiple pricing entries", model)) return infraerrors.BadRequest("DUPLICATE_MODEL", fmt.Sprintf("model '%s' appears in multiple pricing entries for platform '%s'", model, p.Platform))
} }
seen[lower] = true seen[key] = true
} }
} }
return nil return nil
......
...@@ -872,17 +872,7 @@ type anthropicMetadataPayload struct { ...@@ -872,17 +872,7 @@ type anthropicMetadataPayload struct {
// replaceModelInBody 替换请求体中的model字段 // replaceModelInBody 替换请求体中的model字段
// 优先使用定点修改,尽量保持客户端原始字段顺序。 // 优先使用定点修改,尽量保持客户端原始字段顺序。
func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte { func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte {
if len(body) == 0 { return ReplaceModelInBody(body, newModel)
return body
}
if current := gjson.GetBytes(body, "model"); current.Exists() && current.String() == newModel {
return body
}
newBody, err := sjson.SetBytes(body, "model", newModel)
if err != nil {
return body
}
return newBody
} }
type claudeOAuthNormalizeOptions struct { type claudeOAuthNormalizeOptions struct {
...@@ -7794,11 +7784,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -7794,11 +7784,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
} }
var err error var err error
if s.resolver != nil && apiKey.Group != nil { if s.resolver != nil && apiKey.Group != nil {
var groupID *int64
if apiKey.Group != nil {
gid := apiKey.Group.ID gid := apiKey.Group.ID
groupID = &gid groupID := &gid
}
cost, err = s.billingService.CalculateCostUnified(CostInput{ cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx, Ctx: ctx,
Model: billingModel, Model: billingModel,
...@@ -8184,7 +8171,7 @@ func (s *GatewayService) ResolveChannelMapping(ctx context.Context, groupID int6 ...@@ -8184,7 +8171,7 @@ func (s *GatewayService) ResolveChannelMapping(ctx context.Context, groupID int6
// ReplaceModelInBody 替换请求体中的模型名(导出供 handler 使用) // ReplaceModelInBody 替换请求体中的模型名(导出供 handler 使用)
func (s *GatewayService) ReplaceModelInBody(body []byte, newModel string) []byte { func (s *GatewayService) ReplaceModelInBody(body []byte, newModel string) []byte {
return s.replaceModelInBody(body, newModel) return ReplaceModelInBody(body, newModel)
} }
// IsModelRestricted 检查模型是否被渠道限制 // IsModelRestricted 检查模型是否被渠道限制
...@@ -8198,14 +8185,10 @@ func (s *GatewayService) IsModelRestricted(ctx context.Context, groupID int64, m ...@@ -8198,14 +8185,10 @@ func (s *GatewayService) IsModelRestricted(ctx context.Context, groupID int64, m
// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制。 // ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制。
// 返回映射结果和是否被限制。 // 返回映射结果和是否被限制。
func (s *GatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) { func (s *GatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
var mapping ChannelMappingResult if s.channelService == nil {
mapping.MappedModel = model return ChannelMappingResult{MappedModel: model}, false
if groupID == nil {
return mapping, false
} }
mapping = s.ResolveChannelMapping(ctx, *groupID, model) return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model)
restricted := s.IsModelRestricted(ctx, *groupID, mapping.MappedModel)
return mapping, restricted
} }
// ForwardCountTokens 转发 count_tokens 请求到上游 API // ForwardCountTokens 转发 count_tokens 请求到上游 API
......
...@@ -416,29 +416,15 @@ func (s *OpenAIGatewayService) IsModelRestricted(ctx context.Context, groupID in ...@@ -416,29 +416,15 @@ func (s *OpenAIGatewayService) IsModelRestricted(ctx context.Context, groupID in
// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制。 // ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制。
// 返回映射结果和是否被限制。 // 返回映射结果和是否被限制。
func (s *OpenAIGatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) { func (s *OpenAIGatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
var mapping ChannelMappingResult if s.channelService == nil {
mapping.MappedModel = model return ChannelMappingResult{MappedModel: model}, false
if groupID == nil {
return mapping, false
} }
mapping = s.ResolveChannelMapping(ctx, *groupID, model) return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model)
restricted := s.IsModelRestricted(ctx, *groupID, mapping.MappedModel)
return mapping, restricted
} }
// ReplaceModelInBody 替换请求体中的 JSON model 字段(通用 gjson/sjson 实现)。 // ReplaceModelInBody 替换请求体中的 JSON model 字段(通用 gjson/sjson 实现)。
func (s *OpenAIGatewayService) ReplaceModelInBody(body []byte, newModel string) []byte { func (s *OpenAIGatewayService) ReplaceModelInBody(body []byte, newModel string) []byte {
if len(body) == 0 { return ReplaceModelInBody(body, newModel)
return body
}
if current := gjson.GetBytes(body, "model"); current.Exists() && current.String() == newModel {
return body
}
newBody, err := sjson.SetBytes(body, "model", newModel)
if err != nil {
return body
}
return newBody
} }
func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle { func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle {
...@@ -4249,13 +4235,20 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec ...@@ -4249,13 +4235,20 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
durationMs := int(result.Duration.Milliseconds()) durationMs := int(result.Duration.Milliseconds())
accountRateMultiplier := account.BillingRateMultiplier() accountRateMultiplier := account.BillingRateMultiplier()
requestID := resolveUsageBillingRequestID(ctx, result.RequestID) requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
// 确定 RequestedModel(渠道映射前的原始模型)
requestedModel := result.Model
if input.OriginalModel != "" {
requestedModel = input.OriginalModel
}
usageLog := &UsageLog{ usageLog := &UsageLog{
UserID: user.ID, UserID: user.ID,
APIKeyID: apiKey.ID, APIKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
RequestID: requestID, RequestID: requestID,
Model: result.Model, Model: result.Model,
RequestedModel: result.Model, RequestedModel: requestedModel,
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
ServiceTier: result.ServiceTier, ServiceTier: result.ServiceTier,
ReasoningEffort: result.ReasoningEffort, ReasoningEffort: result.ReasoningEffort,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment