Commit ce41afb7 authored by erio
Browse files

refactor: move channel model restriction from handler to scheduling phase

Move the model pricing restriction check from 8 handler entry points
to the account scheduling phase (SelectAccountForModelWithExclusions /
SelectAccountWithLoadAwareness), aligning restriction with billing:

- requested: check original request model against pricing list
- channel_mapped: check channel-mapped model against pricing list
- upstream: per-account check using account-mapped model

Handler layer now only resolves channel mapping (no restriction).
Scheduling layer performs pre-check for requested/channel_mapped,
and per-account filtering for upstream billing source.
parent b4a42a64
......@@ -158,12 +158,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
reqStream := parsedReq.Stream
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
// 解析渠道级模型映射 + 限制检查
channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
if restricted {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key")
return
}
// 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
......
......@@ -81,11 +81,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// Resolve the channel-level model mapping only; the model restriction check now runs in the scheduling phase.
channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
if restricted {
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key")
return
}
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
// Claude Code only restriction
if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly {
......
......@@ -81,11 +81,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// Resolve the channel-level model mapping only; the model restriction check now runs in the scheduling phase.
channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
if restricted {
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key")
return
}
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
// Claude Code only restriction:
// /v1/responses is never a Claude Code endpoint.
......
......@@ -185,11 +185,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
// Resolve the channel-level model mapping only; the model restriction check now runs in the scheduling phase.
channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName)
if restricted {
googleError(c, http.StatusServiceUnavailable, "The requested model is not available for this API key")
return
}
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName)
reqModel := modelName // 保存映射前的原始模型名
if channelMapping.Mapped {
modelName = channelMapping.MappedModel
......
......@@ -80,11 +80,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// Resolve the channel-level model mapping only; the model restriction check now runs in the scheduling phase.
channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
if restricted {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key")
return
}
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
......
......@@ -185,12 +185,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射 + 限制检查
channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
if restricted {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key")
return
}
// 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
if !h.validateFunctionCallOutputRequest(c, body, reqLog) {
......@@ -562,12 +558,8 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射 + 限制检查
channelMappingMsg, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
if restricted {
h.anthropicErrorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key")
return
}
// 解析渠道级模型映射
channelMappingMsg, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
if h.errorPassthroughService != nil {
......@@ -1128,11 +1120,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
// Resolve the channel-level model mapping only; the model restriction check now runs in the scheduling phase.
channelMappingWS, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel)
if restricted {
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "model not allowed")
return
}
channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel)
var currentUserRelease func()
var currentAccountRelease func()
......
......@@ -436,8 +436,9 @@ func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, m
return checkRestricted(lk, groupID, model)
}
// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制(组合方法)。
// 返回映射结果和是否被限制。groupID 为 nil 时跳过。
// ResolveChannelMappingAndRestrict 解析渠道映射。
// 返回映射结果。模型限制检查已移至调度阶段(GatewayService.checkChannelPricingRestriction),
// restricted 始终返回 false,保留签名兼容性。
func (s *ChannelService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
if groupID == nil {
return ChannelMappingResult{MappedModel: model}, false
......@@ -446,10 +447,7 @@ func (s *ChannelService) ResolveChannelMappingAndRestrict(ctx context.Context, g
if lk == nil {
return ChannelMappingResult{MappedModel: model}, false
}
// 先用原始模型检查定价列表限制,再做映射
restricted := checkRestricted(lk, *groupID, model)
mapping := resolveMapping(lk, *groupID, model)
return mapping, restricted
return resolveMapping(lk, *groupID, model), false
}
// resolveMapping 基于已查找的渠道信息解析模型映射
......
......@@ -1068,6 +1068,8 @@ func TestIsModelRestricted_CaseInsensitive(t *testing.T) {
}
// --- 4.5 ResolveChannelMappingAndRestrict ---
// 注意:模型限制检查已移至调度阶段(GatewayService.checkChannelPricingRestriction),
// ResolveChannelMappingAndRestrict 仅做映射,restricted 始终为 false。
func TestResolveChannelMappingAndRestrict_NilGroupID(t *testing.T) {
repo := &mockChannelRepository{
......@@ -1083,7 +1085,7 @@ func TestResolveChannelMappingAndRestrict_NilGroupID(t *testing.T) {
require.Equal(t, "claude-opus-4", mapping.MappedModel)
}
func TestResolveChannelMappingAndRestrict_ModelInPricing_WithMapping(t *testing.T) {
func TestResolveChannelMappingAndRestrict_WithMapping(t *testing.T) {
ch := Channel{
ID: 1,
Status: StatusActive,
......@@ -1103,41 +1105,12 @@ func TestResolveChannelMappingAndRestrict_ModelInPricing_WithMapping(t *testing.
gid := int64(10)
mapping, restricted := svc.ResolveChannelMappingAndRestrict(context.Background(), &gid, "claude-sonnet-4")
require.False(t, restricted) // model IS in pricing
require.False(t, restricted) // restricted 始终为 false,限制检查在调度阶段
require.True(t, mapping.Mapped)
require.Equal(t, "claude-sonnet-4-20250514", mapping.MappedModel)
}
// TestResolveChannelMappingAndRestrict_ModelNotInPricing_WithMapping exercises a
// channel with RestrictModels enabled, a pricing list containing only
// "claude-sonnet-4", and a wildcard "*" mapping, then resolves "unknown-model".
//
// NOTE(review): this test asserts restricted == true, whereas the section note
// for 4.5 states that ResolveChannelMappingAndRestrict now always returns
// restricted == false (the restriction check moved to the scheduling phase).
// Presumably this is the pre-refactor version of the test and is being removed
// by this change — confirm before keeping it.
func TestResolveChannelMappingAndRestrict_ModelNotInPricing_WithMapping(t *testing.T) {
// CRITICAL: this test verifies that restriction checks the ORIGINAL model
// against pricing BEFORE applying mapping. The model "unknown-model" is NOT
// in pricing, so even though the wildcard mapping "*" matches it, it should
// still be restricted.
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-sonnet-4"}},
},
ModelMapping: map[string]map[string]string{
"anthropic": {
"*": "catch-all-target",
},
},
}
repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"})
svc := newTestChannelService(repo)
gid := int64(10)
mapping, restricted := svc.ResolveChannelMappingAndRestrict(context.Background(), &gid, "unknown-model")
require.True(t, restricted) // model NOT in pricing, even though mapping exists
require.True(t, mapping.Mapped)
require.Equal(t, "catch-all-target", mapping.MappedModel)
}
func TestResolveChannelMappingAndRestrict_ModelNotInPricing_NoMapping(t *testing.T) {
func TestResolveChannelMappingAndRestrict_NoMapping(t *testing.T) {
ch := Channel{
ID: 1,
Status: StatusActive,
......@@ -1152,7 +1125,7 @@ func TestResolveChannelMappingAndRestrict_ModelNotInPricing_NoMapping(t *testing
gid := int64(10)
mapping, restricted := svc.ResolveChannelMappingAndRestrict(context.Background(), &gid, "unknown-model")
require.True(t, restricted) // model NOT in pricing
require.False(t, restricted) // restricted 始终为 false,限制检查在调度阶段
require.False(t, mapping.Mapped)
require.Equal(t, "unknown-model", mapping.MappedModel)
}
......
......@@ -1178,6 +1178,11 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 渠道定价限制预检查(requested / channel_mapped 基准)
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
}
// 优先检查 context 中的强制平台(/antigravity 路由)
var platform string
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
......@@ -1208,8 +1213,15 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
}
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
// metadataUserID: 已废弃参数,会话限制现在统一使用 sessionHash
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) {
// 调度流程文档见 docs/ACCOUNT_SCHEDULING_FLOW.md 。
// metadataUserID: 用于客户端亲和调度,从中提取客户端 ID
// sub2apiUserID: 系统用户 ID,用于二维亲和调度
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string, sub2apiUserID int64) (*AccountSelectionResult, error) {
// 渠道定价限制预检查(requested / channel_mapped 基准)
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
}
// 调试日志:记录调度入口参数
excludedIDsList := make([]int64, 0, len(excludedIDs))
for id := range excludedIDs {
......@@ -2955,6 +2967,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
ctx = s.withRPMPrefetch(ctx, accounts)
// 3. 按优先级+最久未用选择(考虑模型支持)
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
var selected *Account
for i := range accounts {
acc := &accounts[i]
......@@ -2975,6 +2988,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) {
continue
}
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
......@@ -3207,6 +3223,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
ctx = s.withRPMPrefetch(ctx, accounts)
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
var selected *Account
for i := range accounts {
acc := &accounts[i]
......@@ -3231,6 +3248,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) {
continue
}
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
......@@ -8212,6 +8232,67 @@ func (s *GatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, g
return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model)
}
// checkChannelPricingRestriction reports whether requestedModel is blocked by the
// channel pricing list, judged against the model selected by the channel's
// billing source. It is the scheduling-phase pre-check for the "requested" and
// "channel_mapped" sources.
//
// It returns false when groupID is nil, when no channel service is configured,
// when requestedModel is empty, or when billingModelForRestriction yields an
// empty model (the upstream source, which is filtered per account instead).
func (s *GatewayService) checkChannelPricingRestriction(ctx context.Context, groupID *int64, requestedModel string) bool {
	if groupID == nil {
		return false
	}
	if s.channelService == nil || requestedModel == "" {
		return false
	}

	gid := *groupID
	resolved := s.channelService.ResolveChannelMapping(ctx, gid, requestedModel)
	target := billingModelForRestriction(resolved.BillingModelSource, requestedModel, resolved.MappedModel)

	// An empty target means "nothing to pre-check here"; skip the lookup entirely.
	return target != "" && s.channelService.IsModelRestricted(ctx, gid, target)
}
// billingModelForRestriction 根据计费基准确定限制检查使用的模型。
// upstream 返回空(需逐账号检查)。
func billingModelForRestriction(source, requestedModel, channelMappedModel string) string {
switch source {
case BillingModelSourceRequested:
return requestedModel
case BillingModelSourceUpstream:
return ""
default: // channel_mapped
return channelMappedModel
}
}
// isUpstreamModelRestrictedByChannel reports whether the upstream model that
// account would use for requestedModel is blocked by the channel pricing list
// of groupID. The scheduling loops call it only when
// needsUpstreamChannelRestrictionCheck returned true.
//
// It returns false when no channel service is configured or when the account
// resolves requestedModel to an empty upstream model.
func (s *GatewayService) isUpstreamModelRestrictedByChannel(ctx context.Context, groupID int64, account *Account, requestedModel string) bool {
	svc := s.channelService
	if svc == nil {
		return false
	}
	if mapped := resolveAccountUpstreamModel(account, requestedModel); mapped != "" {
		return svc.IsModelRestricted(ctx, groupID, mapped)
	}
	return false
}
// resolveAccountUpstreamModel returns the upstream model name that account maps
// requestedModel to. Antigravity accounts go through mapAntigravityModel; every
// other platform uses the account's own GetMappedModel.
func resolveAccountUpstreamModel(account *Account, requestedModel string) string {
	switch account.Platform {
	case PlatformAntigravity:
		return mapAntigravityModel(account, requestedModel)
	default:
		return account.GetMappedModel(requestedModel)
	}
}
// needsUpstreamChannelRestrictionCheck reports whether the scheduling loop must
// check each account's upstream model against the channel restriction.
//
// It is true only when the group's channel can be loaded, has RestrictModels
// enabled, and bills on BillingModelSourceUpstream. A nil groupID, a missing
// channel service, a lookup error, or a nil channel all yield false.
func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Context, groupID *int64) bool {
	if groupID == nil {
		return false
	}
	if s.channelService == nil {
		return false
	}

	channel, lookupErr := s.channelService.GetChannelForGroup(ctx, *groupID)
	usable := lookupErr == nil && channel != nil
	return usable && channel.RestrictModels && channel.BillingModelSource == BillingModelSourceUpstream
}
// ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment