Commit 2dce4306 authored by erio's avatar erio
Browse files

refactor: move channel model restriction from handler to scheduling phase

Move the model pricing restriction check from 8 handler entry points
to the account scheduling phase (SelectAccountForModelWithExclusions /
SelectAccountWithLoadAwareness), aligning restriction with billing:

- requested: check original request model against pricing list
- channel_mapped: check channel-mapped model against pricing list
- upstream: per-account check using account-mapped model

Handler layer now only resolves channel mapping (no restriction).
Scheduling layer performs pre-check for requested/channel_mapped,
and per-account filtering for upstream billing source.
parent 3de77130
......@@ -80,7 +80,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射
// 解析渠道级模型映射 + 限制检查
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
// Claude Code only restriction
......
......@@ -80,7 +80,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射
// 解析渠道级模型映射 + 限制检查
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
// Claude Code only restriction:
......
......@@ -121,7 +121,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
googleError(c, http.StatusBadGateway, err.Error())
return
}
if shouldFallbackGeminiModel(modelName, res) {
if shouldFallbackGeminiModels(res) {
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
return
}
......@@ -184,7 +184,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
setOpsRequestContext(c, modelName, stream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
// 解析渠道级模型映射
// 解析渠道级模型映射 + 限制检查
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName)
reqModel := modelName // 保存映射前的原始模型名
if channelMapping.Mapped {
......@@ -682,16 +682,6 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
return false
}
func shouldFallbackGeminiModel(modelName string, res *service.UpstreamHTTPResult) bool {
if shouldFallbackGeminiModels(res) {
return true
}
if res == nil || res.StatusCode != http.StatusNotFound {
return false
}
return gemini.HasFallbackModel(modelName)
}
// extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。
// 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。
//
......
......@@ -79,7 +79,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射
// 解析渠道级模型映射 + 限制检查
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
if h.errorPassthroughService != nil {
......
......@@ -47,13 +47,6 @@ func resolveOpenAIForwardDefaultMappedModel(apiKey *service.APIKey, fallbackMode
return strings.TrimSpace(apiKey.Group.DefaultMappedModel)
}
func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedModel string) string {
if apiKey == nil || apiKey.Group == nil {
return ""
}
return strings.TrimSpace(apiKey.Group.ResolveMessagesDispatchModel(requestedModel))
}
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
func NewOpenAIGatewayHandler(
gatewayService *service.OpenAIGatewayService,
......@@ -557,8 +550,6 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
return
}
reqModel := modelResult.String()
routingModel := service.NormalizeOpenAICompatRequestedModel(reqModel)
preferredMappedModel := resolveOpenAIMessagesDispatchMappedModel(apiKey, reqModel)
reqStream := gjson.GetBytes(body, "stream").Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
......@@ -617,20 +608,17 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
failedAccountIDs := make(map[int64]struct{})
sameAccountRetryCount := make(map[int64]int)
var lastFailoverErr *service.UpstreamFailoverError
effectiveMappedModel := preferredMappedModel
for {
currentRoutingModel := routingModel
if effectiveMappedModel != "" {
currentRoutingModel = effectiveMappedModel
}
// 清除上一次迭代的降级模型标记,避免残留影响本次迭代
c.Set("openai_messages_fallback_model", "")
reqLog.Debug("openai_messages.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
c.Request.Context(),
apiKey.GroupID,
"", // no previous_response_id
sessionHash,
currentRoutingModel,
reqModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
)
......@@ -639,7 +627,29 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
zap.Error(err),
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
// 首次调度失败 + 有默认映射模型 → 用默认模型重试
if len(failedAccountIDs) == 0 {
defaultModel := ""
if apiKey.Group != nil {
defaultModel = apiKey.Group.DefaultMappedModel
}
if defaultModel != "" && defaultModel != reqModel {
reqLog.Info("openai_messages.fallback_to_default_model",
zap.String("default_mapped_model", defaultModel),
)
selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
c.Request.Context(),
apiKey.GroupID,
"",
sessionHash,
defaultModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
)
if err == nil && selection != nil {
c.Set("openai_messages_fallback_model", defaultModel)
}
}
if err != nil {
h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
......@@ -671,7 +681,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now()
defaultMappedModel := strings.TrimSpace(effectiveMappedModel)
// Forward 层需要始终拿到 group 默认映射模型,这样未命中账号级映射的
// Claude 兼容模型才不会在后续 Codex 规范化中意外退化到 gpt-5.1。
defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_messages_fallback_model"))
// 应用渠道模型映射到请求体
forwardBody := body
if channelMappingMsg.Mapped {
......@@ -1106,7 +1118,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
setOpsRequestContext(c, reqModel, true, firstMessage)
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
// 解析渠道级模型映射
// 解析渠道级模型映射 + 限制检查
channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel)
var currentUserRelease func()
......
......@@ -12,7 +12,6 @@ import (
"log/slog"
mathrand "math/rand"
"net/http"
"net/url"
"os"
"path/filepath"
"regexp"
......@@ -42,7 +41,8 @@ import (
const (
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL = time.Hour // 粘性会话TTL
stickySessionTTL = time.Hour // 粘性会话TTL
ClientAffinityTTL = 24 * time.Hour // 客户端亲和TTL
defaultMaxLineSize = 500 * 1024 * 1024
// Canonical Claude Code banner. Keep it EXACT (no trailing whitespace/newlines)
// to match real Claude CLI traffic as closely as possible. When we need a visual
......@@ -60,14 +60,28 @@ const (
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
)
// MediaType 媒体类型常量
const (
MediaTypeImage = "image"
MediaTypeVideo = "video"
MediaTypePrompt = "prompt"
)
const (
claudeMaxMessageOverheadTokens = 3
claudeMaxBlockOverheadTokens = 1
claudeMaxUnknownContentTokens = 4
)
// ForceCacheBillingContextKey 强制缓存计费上下文键
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
type forceCacheBillingKeyType struct{}
// accountWithLoad 账号与负载信息的组合,用于负载感知调度
type accountWithLoad struct {
account *Account
loadInfo *AccountLoadInfo
account *Account
loadInfo *AccountLoadInfo
affinityCount int64 // 亲和客户端数量(反向索引),越少越优先
}
var ForceCacheBillingContextKey = forceCacheBillingKeyType{}
......@@ -331,6 +345,10 @@ var (
sseDataRe = regexp.MustCompile(`^data:\s*`)
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
// clientIDFromMetadataRegex 从 metadata.user_id 中提取客户端 ID(64位 hex)
// 格式: user_{64位hex}_account_...
clientIDFromMetadataRegex = regexp.MustCompile(`^user_([a-f0-9]{64})_account_`)
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
// 注意:前缀之间不应存在包含关系,否则会导致冗余匹配
......@@ -348,6 +366,12 @@ var ErrNoAvailableAccounts = errors.New("no available accounts")
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
// ErrAffinityNoSwitch 表示亲和账号不可用且不允许切换到其他账号
var ErrAffinityNoSwitch = errors.New("affinity account unavailable and switching is disabled")
// ErrAffinityLimitExceeded 表示亲和客户端限制已达上限
var ErrAffinityLimitExceeded = errors.New("affinity client limit exceeded")
// allowedHeaders 白名单headers(参考CRS项目)
var allowedHeaders = map[string]bool{
"accept": true,
......@@ -369,8 +393,6 @@ var allowedHeaders = map[string]bool{
"user-agent": true,
"content-type": true,
"accept-encoding": true,
"x-claude-code-session-id": true,
"x-client-request-id": true,
}
// GatewayCache 定义网关服务的缓存操作接口。
......@@ -391,6 +413,39 @@ type GatewayCache interface {
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
// Delete sticky session binding, used to proactively clean up when account becomes unavailable
DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
// GetAffinityAccounts 获取亲和账号列表(按最近使用降序),同时清理过期成员
GetAffinityAccounts(ctx context.Context, groupID int64, userID int64, clientID string, ttl time.Duration) ([]int64, error)
// UpdateAffinity 添加/更新亲和关系(更新 score 为当前时间戳,刷新 key TTL)
UpdateAffinity(ctx context.Context, groupID int64, userID int64, clientID string, accountID int64, ttl time.Duration) error
// GetAccountAffinityCountBatch 批量获取账号的亲和成员数量(惰性清理过期成员)
GetAccountAffinityCountBatch(ctx context.Context, groupID int64, accountIDs []int64, ttl time.Duration) (map[int64]int64, error)
// GetAccountAffinityClientsBatch 批量获取每个账号跨所有分组的亲和成员列表(去重)
// accountGroups: map[accountID][]groupID
// 返回值成员格式为 {userID}/{clientID}
GetAccountAffinityClientsBatch(ctx context.Context, accountGroups map[int64][]int64, ttl time.Duration) (map[int64][]string, error)
// GetAccountAffinityClientsWithScores 获取单个账号跨所有分组的亲和客户端列表(含最后活跃时间)
GetAccountAffinityClientsWithScores(ctx context.Context, accountID int64, groupIDs []int64, ttl time.Duration) ([]AffinityClient, error)
// ClearAccountAffinity 清除指定账号在所有分组的亲和记录(正向+反向索引)
// 用于账号关闭亲和时立即清理旧绑定
ClearAccountAffinity(ctx context.Context, accountID int64, groupIDs []int64) error
// GetAffinityMultiCount 获取账号的多维度亲和计数
// 返回: uniqueUsers, uniqueClients, perUserClients
GetAffinityMultiCount(ctx context.Context, groupID int64, accountID int64, targetUserID int64, ttl time.Duration) (users, clients, perUser int64, err error)
}
// AffinityClient 亲和客户端信息(含用户 ID 和最后活跃时间)
type AffinityClient struct {
UserID int64 `json:"user_id"`
ClientID string `json:"client_id"`
LastActive time.Time `json:"last_active"`
}
// SortAffinityClients 按最后活跃时间降序排序
func SortAffinityClients(clients []AffinityClient) {
sort.Slice(clients, func(i, j int) bool {
return clients[i].LastActive.After(clients[j].LastActive)
})
}
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
......@@ -461,6 +516,20 @@ func shouldClearStickySession(account *Account, requestedModel string) bool {
return false
}
// extractClientIDFromMetadata 从 metadata.user_id 中提取客户端 ID(64位 hex)。
// 格式: user_{64位hex}_account_..._session_...
// 返回空字符串表示无法提取(非 Claude Code/Console 客户端)。
func extractClientIDFromMetadata(metadataUserID string) string {
if metadataUserID == "" {
return ""
}
matches := clientIDFromMetadataRegex.FindStringSubmatch(metadataUserID)
if matches == nil {
return ""
}
return matches[1]
}
type AccountWaitPlan struct {
AccountID int64
MaxConcurrency int
......@@ -504,6 +573,9 @@ type ForwardResult struct {
ImageCount int // 生成的图片数量
ImageSize string // 图片尺寸 "1K", "2K", "4K"
// Sora 媒体字段
MediaType string // image / video / prompt
MediaURL string // 生成后的媒体地址(可选)
}
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
......@@ -1162,6 +1234,11 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 渠道定价限制预检查(requested / channel_mapped 基准)
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
}
// 优先检查 context 中的强制平台(/antigravity 路由)
var platform string
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
......@@ -1180,32 +1257,15 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
platform = PlatformAnthropic
}
// Claude Code 限制可能已将 groupID 解析为 fallback group,
// 渠道限制预检查必须使用解析后的分组。
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
slog.Warn("channel pricing restriction blocked request",
"group_id", derefGroupID(groupID),
"model", requestedModel)
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
}
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
// 注意:强制平台模式不走混合调度
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
account, err := s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
if err != nil {
return nil, err
}
return s.hydrateSelectedAccount(ctx, account)
return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
}
// antigravity 分组、强制平台模式或无分组使用单平台选择
// 注意:强制平台模式也必须遵守分组限制,不再回退到全平台查询
account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
if err != nil {
return nil, err
}
return s.hydrateSelectedAccount(ctx, account)
return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
}
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
......@@ -1213,6 +1273,11 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
// metadataUserID: 用于客户端亲和调度,从中提取客户端 ID
// sub2apiUserID: 系统用户 ID,用于二维亲和调度
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string, sub2apiUserID int64) (*AccountSelectionResult, error) {
// 渠道定价限制预检查(requested / channel_mapped 基准)
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
}
// 调试日志:记录调度入口参数
excludedIDsList := make([]int64, 0, len(excludedIDs))
for id := range excludedIDs {
......@@ -1233,15 +1298,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
ctx = s.withGroupContext(ctx, group)
// Claude Code 限制可能已将 groupID 解析为 fallback group,
// 渠道限制预检查必须使用解析后的分组。
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
slog.Warn("channel pricing restriction blocked request",
"group_id", derefGroupID(groupID),
"model", requestedModel)
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
}
var stickyAccountID int64
if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 {
stickyAccountID = prefetch
......@@ -1251,6 +1307,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
}
// 提取客户端 ID(用于客户端亲和调度)
affinityClientID := extractClientIDFromMetadata(metadataUserID)
affinityUserID := sub2apiUserID
if s.debugModelRoutingEnabled() && requestedModel != "" {
groupPlatform := ""
if group != nil {
......@@ -1272,6 +1332,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if err != nil {
return nil, err
}
if shouldFilterAccountWithoutClientID(account, affinityClientID) {
localExcluded[account.ID] = struct{}{}
continue
}
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err == nil && result.Acquired {
......@@ -1281,7 +1345,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
localExcluded[account.ID] = struct{}{} // 排除此账号
continue // 重新选择
}
return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
// 对于等待计划的情况,也需要先检查会话限制
......@@ -1293,20 +1361,26 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
if waitingCount < cfg.StickySessionMaxWaiting {
return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
})
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
})
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
}
}
......@@ -1323,12 +1397,18 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if err != nil {
return nil, err
}
accounts = filterAccountsWithoutClientID(accounts, affinityClientID)
if len(accounts) == 0 {
return nil, ErrNoAvailableAccounts
}
ctx = s.withWindowCostPrefetch(ctx, accounts)
ctx = s.withRPMPrefetch(ctx, accounts)
// 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用)
accountByID := make(map[int64]*Account, len(accounts))
for i := range accounts {
accountByID[accounts[i].ID] = &accounts[i]
}
isExcluded := func(accountID int64) bool {
if excludedIDs == nil {
return false
......@@ -1336,12 +1416,19 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
_, excluded := excludedIDs[accountID]
return excluded
}
// 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用)
accountByID := make(map[int64]*Account, len(accounts))
for i := range accounts {
accountByID[accounts[i].ID] = &accounts[i]
}
affinityFlow := newGatewayAffinityFlow(
s,
ctx,
groupID,
sessionHash,
requestedModel,
affinityClientID,
affinityUserID,
platform,
useMixed,
accountByID,
isExcluded,
)
// 获取模型路由配置(仅 anthropic 平台)
var routingAccountIDs []int64
......@@ -1430,76 +1517,53 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
// 粘性账号在路由列表中,优先使用
if stickyAccount, ok := accountByID[stickyAccountID]; ok {
var stickyCacheMissReason string
gatePass := s.isAccountSchedulableForSelection(stickyAccount) &&
if s.isAccountSchedulableForSelection(stickyAccount) &&
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) &&
s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) &&
s.isAccountSchedulableForQuota(stickyAccount) &&
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true)
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) &&
rpmPass := gatePass && s.isAccountSchedulableForRPM(ctx, stickyAccount, true)
if rpmPass { // 粘性会话窗口费用+RPM 检查
s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
result.ReleaseFunc() // 释放槽位
stickyCacheMissReason = "session_limit"
// 继续到负载感知选择
} else {
if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
}
return s.newSelectionResult(ctx, stickyAccount, true, result.ReleaseFunc, nil)
return &AccountSelectionResult{
Account: stickyAccount,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
}
if stickyCacheMissReason == "" {
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
if waitingCount < cfg.StickySessionMaxWaiting {
// 会话数量限制检查(等待计划也需要占用会话配额)
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
stickyCacheMissReason = "session_limit"
// 会话限制已满,继续到负载感知选择
} else {
return &AccountSelectionResult{
Account: stickyAccount,
WaitPlan: &AccountWaitPlan{
AccountID: stickyAccountID,
MaxConcurrency: stickyAccount.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
if waitingCount < cfg.StickySessionMaxWaiting {
// 会话数量限制检查(等待计划也需要占用会话配额)
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
// 会话限制已满,继续到负载感知选择
} else {
stickyCacheMissReason = "wait_queue_full"
return &AccountSelectionResult{
Account: stickyAccount,
WaitPlan: &AccountWaitPlan{
AccountID: stickyAccountID,
MaxConcurrency: stickyAccount.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
} else if !gatePass {
stickyCacheMissReason = "gate_check"
} else {
stickyCacheMissReason = "rpm_red"
}
// 记录粘性缓存未命中的结构化日志
if stickyCacheMissReason != "" {
baseRPM := stickyAccount.GetBaseRPM()
var currentRPM int
if count, ok := rpmFromPrefetchContext(ctx, stickyAccount.ID); ok {
currentRPM = count
}
logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=%s account_id=%d session=%s current_rpm=%d base_rpm=%d",
stickyCacheMissReason, stickyAccountID, shortSessionHash(sessionHash), currentRPM, baseRPM)
}
} else {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=account_cleared account_id=%d session=%s current_rpm=0 base_rpm=0",
stickyAccountID, shortSessionHash(sessionHash))
}
}
}
......@@ -1527,7 +1591,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
if len(routingAvailable) > 0 {
// 排序:优先级 > 负载率 > 最后使用时间
// 批量获取亲和客户端数量
s.populateAffinityCounts(ctx, routingAvailable, derefGroupID(groupID))
// 排序:优先级 > 负载率 > 亲和客户端数 > 最后使用时间
sort.SliceStable(routingAvailable, func(i, j int) bool {
a, b := routingAvailable[i], routingAvailable[j]
if a.account.Priority != b.account.Priority {
......@@ -1536,6 +1603,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
}
if a.affinityCount != b.affinityCount {
return a.affinityCount < b.affinityCount
}
switch {
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
return true
......@@ -1561,10 +1631,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
}
if affinityClientID != "" && affinityUserID > 0 && s.cache != nil && item.account.IsAffinityEnabled() {
_ = s.cache.UpdateAffinity(ctx, derefGroupID(groupID), affinityUserID, affinityClientID, item.account.ID, ClientAffinityTTL)
}
if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
}
return s.newSelectionResult(ctx, item.account, true, result.ReleaseFunc, nil)
return &AccountSelectionResult{
Account: item.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
}
......@@ -1577,12 +1654,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
}
return s.newSelectionResult(ctx, item.account, false, nil, &AccountWaitPlan{
AccountID: item.account.ID,
MaxConcurrency: item.account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
})
return &AccountSelectionResult{
Account: item.account,
WaitPlan: &AccountWaitPlan{
AccountID: item.account.ID,
MaxConcurrency: item.account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
// 所有路由账号会话限制都已满,继续到 Layer 2 回退
}
......@@ -1591,14 +1671,27 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
}
// ============ Layer 1.5: 粘性会话(仅在无模型路由配置时生效) ============
if len(routingAccountIDs) == 0 && sessionHash != "" && stickyAccountID > 0 && !isExcluded(stickyAccountID) {
// ============ Layer 1.3: 用户亲和预处理(pinned_users 自动注入) ============
affinityFlow.preprocessPinnedUsers(accounts)
// ============ Layer 1.4: 客户端亲和调度(优先于粘性会话) ============
affinityHit := false
if affinityResult, hit, err := affinityFlow.trySelectAffinityAccount(); err != nil {
return nil, err
} else {
affinityHit = hit
if affinityResult != nil {
return affinityResult, nil
}
}
// ============ Layer 1.5: 粘性会话(仅在无模型路由配置 且 亲和未命中时生效) ============
if !affinityHit && len(routingAccountIDs) == 0 && sessionHash != "" && stickyAccountID > 0 && !isExcluded(stickyAccountID) {
accountID := stickyAccountID
if accountID > 0 && !isExcluded(accountID) {
account, ok := accountByID[accountID]
if ok {
// 检查账户是否需要清理粘性会话绑定
// Check if the account needs sticky session cleanup
clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
......@@ -1614,31 +1707,32 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
// Session count limit check
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
} else {
if s.cache != nil {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
}
return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting {
// 会话数量限制检查(等待计划也需要占用会话配额)
// Session count limit check (wait plan also requires session quota)
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
// 会话限制已满,继续到 Layer 2
// Session limit full, continue to Layer 2
} else {
return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
})
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
}
......@@ -1697,9 +1791,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil {
if result, ok, legacyErr := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); legacyErr != nil {
return nil, legacyErr
} else if ok {
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok {
if affinityClientID != "" && affinityUserID > 0 && s.cache != nil && result.Account != nil && result.Account.IsAffinityEnabled() {
_ = s.cache.UpdateAffinity(ctx, derefGroupID(groupID), affinityUserID, affinityClientID, result.Account.ID, ClientAffinityTTL)
}
return result, nil
}
} else {
......@@ -1717,13 +1812,37 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
}
// 分层过滤选择:优先级 → 负载率 → LRU
// 批量获取亲和客户端数量(用于均衡分配新客户端)
s.populateAffinityCounts(ctx, available, derefGroupID(groupID))
// 分层过滤选择:优先级 → 亲和三区 → 负载率 → 亲和客户端数 → LRU
for len(available) > 0 {
// 1. 取优先级最小的集合
candidates := filterByMinPriority(available)
// 2. 取负载率最低的集合
// 2. 按亲和三区过滤:绿区优先 → 黄区降级 → 红区移除(在同优先级内)
candidates = classifyByAffinityZone(candidates)
if len(candidates) == 0 {
// 当前优先级组全部在红区,移除后回退到下一优先级组
minPri := available[0].account.Priority
for _, a := range available[1:] {
if a.account.Priority < minPri {
minPri = a.account.Priority
}
}
newAvailable := make([]accountWithLoad, 0, len(available))
for _, a := range available {
if a.account.Priority != minPri {
newAvailable = append(newAvailable, a)
}
}
available = newAvailable
continue
}
// 3. 取负载率最低的集合
candidates = filterByMinLoadRate(candidates)
// 3. LRU 选择最久未用的账号
// 3. 取亲和客户端数最少的集合
candidates = filterByMinAffinityCount(candidates)
// 4. LRU 选择最久未用的账号
selected := selectByLRU(candidates, preferOAuth)
if selected == nil {
break
......@@ -1738,7 +1857,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
}
return s.newSelectionResult(ctx, selected.account, true, result.ReleaseFunc, nil)
// 更新亲和关系
if affinityClientID != "" && affinityUserID > 0 && s.cache != nil && selected.account.IsAffinityEnabled() {
_ = s.cache.UpdateAffinity(ctx, derefGroupID(groupID), affinityUserID, affinityClientID, selected.account.ID, ClientAffinityTTL)
}
return &AccountSelectionResult{
Account: selected.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
}
......@@ -1761,17 +1888,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
continue // 会话限制已满,尝试下一个账号
}
return s.newSelectionResult(ctx, acc, false, nil, &AccountWaitPlan{
AccountID: acc.ID,
MaxConcurrency: acc.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
})
return &AccountSelectionResult{
Account: acc,
WaitPlan: &AccountWaitPlan{
AccountID: acc.ID,
MaxConcurrency: acc.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
}
return nil, ErrNoAvailableAccounts
}
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool, error) {
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
ordered := append([]*Account(nil), candidates...)
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
......@@ -1786,15 +1916,15 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL)
}
selection, err := s.newSelectionResult(ctx, acc, true, result.ReleaseFunc, nil)
if err != nil {
return nil, false, err
}
return selection, true, nil
return &AccountSelectionResult{
Account: acc,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, true
}
}
return nil, false, nil
return nil, false
}
func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
......@@ -1939,6 +2069,9 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr
}
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
if platform == PlatformSora {
return s.listSoraSchedulableAccounts(ctx, groupID)
}
if s.schedulerSnapshot != nil {
accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err == nil {
......@@ -2035,6 +2168,53 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
return accounts, useMixed, nil
}
func (s *GatewayService) listSoraSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, bool, error) {
const useMixed = false
var accounts []Account
var err error
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora)
} else if groupID != nil {
accounts, err = s.accountRepo.ListByGroup(ctx, *groupID)
} else {
accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora)
}
if err != nil {
slog.Debug("account_scheduling_list_failed",
"group_id", derefGroupID(groupID),
"platform", PlatformSora,
"error", err)
return nil, useMixed, err
}
filtered := make([]Account, 0, len(accounts))
for _, acc := range accounts {
if acc.Platform != PlatformSora {
continue
}
if !s.isSoraAccountSchedulable(&acc) {
continue
}
filtered = append(filtered, acc)
}
slog.Debug("account_scheduling_list_sora",
"group_id", derefGroupID(groupID),
"platform", PlatformSora,
"raw_count", len(accounts),
"filtered_count", len(filtered))
for _, acc := range filtered {
slog.Debug("account_scheduling_account_detail",
"account_id", acc.ID,
"name", acc.Name,
"platform", acc.Platform,
"type", acc.Type,
"status", acc.Status,
"tls_fingerprint", acc.IsTLSFingerprintEnabled())
}
return filtered, useMixed, nil
}
// IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。
// 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context,
// 避免单账号分组收到 503 时错误地设置模型限流标记导致后续请求连续快速失败。
......@@ -2059,10 +2239,33 @@ func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform
return account.Platform == platform
}
func (s *GatewayService) isSoraAccountSchedulable(account *Account) bool {
return s.soraUnschedulableReason(account) == ""
}
func (s *GatewayService) soraUnschedulableReason(account *Account) string {
if account == nil {
return "account_nil"
}
if account.Status != StatusActive {
return fmt.Sprintf("status=%s", account.Status)
}
if !account.Schedulable {
return "schedulable=false"
}
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
return fmt.Sprintf("temp_unschedulable_until=%s", account.TempUnschedulableUntil.UTC().Format(time.RFC3339))
}
return ""
}
func (s *GatewayService) isAccountSchedulableForSelection(account *Account) bool {
if account == nil {
return false
}
if account.Platform == PlatformSora {
return s.isSoraAccountSchedulable(account)
}
return account.IsSchedulable()
}
......@@ -2070,6 +2273,12 @@ func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Conte
if account == nil {
return false
}
if account.Platform == PlatformSora {
if !s.isSoraAccountSchedulable(account) {
return false
}
return account.GetRateLimitRemainingTimeWithContext(ctx, requestedModel) <= 0
}
return account.IsSchedulableForModelWithContext(ctx, requestedModel)
}
......@@ -2409,31 +2618,34 @@ func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID in
return s.accountRepo.GetByID(ctx, accountID)
}
func (s *GatewayService) hydrateSelectedAccount(ctx context.Context, account *Account) (*Account, error) {
if account == nil || s.schedulerSnapshot == nil {
return account, nil
// populateAffinityCounts 批量获取账号的亲和客户端数量并填入 accountWithLoad 切片。
// 仅当存在开启了客户端亲和的账号时才查询 Redis,否则跳过。
func (s *GatewayService) populateAffinityCounts(ctx context.Context, accounts []accountWithLoad, groupID int64) {
if s.cache == nil || len(accounts) == 0 {
return
}
hydrated, err := s.schedulerSnapshot.GetAccount(ctx, account.ID)
if err != nil {
return nil, err
// 快速检查:是否有任何账号开启了亲和
hasAffinity := false
for _, acc := range accounts {
if acc.account.IsAffinityEnabled() {
hasAffinity = true
break
}
}
if hydrated == nil {
return nil, fmt.Errorf("selected gateway account %d not found during hydration", account.ID)
if !hasAffinity {
return
}
return hydrated, nil
}
func (s *GatewayService) newSelectionResult(ctx context.Context, account *Account, acquired bool, release func(), waitPlan *AccountWaitPlan) (*AccountSelectionResult, error) {
hydrated, err := s.hydrateSelectedAccount(ctx, account)
accountIDs := make([]int64, len(accounts))
for i, acc := range accounts {
accountIDs[i] = acc.account.ID
}
countMap, err := s.cache.GetAccountAffinityCountBatch(ctx, groupID, accountIDs, ClientAffinityTTL)
if err != nil {
return nil, err
return // 查询失败不影响调度,affinityCount 保持 0
}
for i := range accounts {
accounts[i].affinityCount = countMap[accounts[i].account.ID]
}
return &AccountSelectionResult{
Account: hydrated,
Acquired: acquired,
ReleaseFunc: release,
WaitPlan: waitPlan,
}, nil
}
// filterByMinPriority 过滤出优先级最小的账号集合
......@@ -2476,6 +2688,64 @@ func filterByMinLoadRate(accounts []accountWithLoad) []accountWithLoad {
return result
}
// filterByMinAffinityCount 过滤出亲和客户端数最少的账号集合
func filterByMinAffinityCount(accounts []accountWithLoad) []accountWithLoad {
if len(accounts) == 0 {
return accounts
}
minCount := accounts[0].affinityCount
for _, acc := range accounts[1:] {
if acc.affinityCount < minCount {
minCount = acc.affinityCount
}
}
result := make([]accountWithLoad, 0, len(accounts))
for _, acc := range accounts {
if acc.affinityCount == minCount {
result = append(result, acc)
}
}
return result
}
// classifyByAffinityZone 按亲和分区对候选账号进行分类。
// 返回值:仅绿区账号(有绿区时),否则返回黄区账号。红区账号被移除。
// 如果没有任何账号开启了亲和三区配置(即 affinity_base <= 0),则原样返回所有账号。
func classifyByAffinityZone(accounts []accountWithLoad) []accountWithLoad {
if len(accounts) == 0 {
return accounts
}
// 快速检查:是否有任何账号配置了 affinity_base
hasZoneConfig := false
for _, acc := range accounts {
if acc.account.IsAffinityEnabled() && acc.account.GetAffinityBase() > 0 {
hasZoneConfig = true
break
}
}
if !hasZoneConfig {
return accounts
}
greens := make([]accountWithLoad, 0, len(accounts))
yellows := make([]accountWithLoad, 0, len(accounts))
for _, acc := range accounts {
zone := acc.account.GetAffinityZone(acc.affinityCount)
switch zone {
case AffinityZoneGreen:
greens = append(greens, acc)
case AffinityZoneYellow:
yellows = append(yellows, acc)
case AffinityZoneRed:
// 红区:移除,不参与调度
}
}
if len(greens) > 0 {
return greens
}
return yellows
}
// selectByLRU 从集合中选择最久未用的账号
// 如果有多个账号具有相同的最小 LastUsedAt,则随机选择一个
func selectByLRU(accounts []accountWithLoad, preferOAuth bool) *accountWithLoad {
......@@ -2711,12 +2981,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
preferOAuth := platform == PlatformGemini
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
// require_privacy_set: 获取分组信息
var schedGroup *Group
if groupID != nil && s.groupRepo != nil {
schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID)
}
var accounts []Account
accountsLoaded := false
......@@ -2788,12 +3052,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !s.isAccountSchedulableForSelection(acc) {
continue
}
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
_ = s.accountRepo.SetError(ctx, acc.ID,
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
}
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
......@@ -2885,8 +3143,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
ctx = s.withRPMPrefetch(ctx, accounts)
// 3. 按优先级+最久未用选择(考虑模型支持)
// needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查,
// 因为粘性会话优先保持连接一致性,且 upstream 计费基准极少使用。
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
var selected *Account
for i := range accounts {
......@@ -2899,12 +3155,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !s.isAccountSchedulableForSelection(acc) {
continue
}
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
_ = s.accountRepo.SetError(ctx, acc.ID,
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
}
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue
}
......@@ -2971,12 +3221,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
preferOAuth := nativePlatform == PlatformGemini
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform)
// require_privacy_set: 获取分组信息
var schedGroup *Group
if groupID != nil && s.groupRepo != nil {
schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID)
}
var accounts []Account
accountsLoaded := false
......@@ -3044,12 +3288,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if !s.isAccountSchedulableForSelection(acc) {
continue
}
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
_ = s.accountRepo.SetError(ctx, acc.ID,
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
}
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
......@@ -3143,7 +3381,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
ctx = s.withRPMPrefetch(ctx, accounts)
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
// needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查。
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
var selected *Account
for i := range accounts {
......@@ -3156,12 +3393,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if !s.isAccountSchedulableForSelection(acc) {
continue
}
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
_ = s.accountRepo.SetError(ctx, acc.ID,
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
}
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
......@@ -3273,6 +3504,9 @@ func (s *GatewayService) logDetailedSelectionFailure(
stats.SampleMappingIDs,
stats.SampleRateLimitIDs,
)
if platform == PlatformSora {
s.logSoraSelectionFailureDetails(ctx, groupID, sessionHash, requestedModel, accounts, excludedIDs, allowMixedScheduling)
}
return stats
}
......@@ -3329,7 +3563,11 @@ func (s *GatewayService) diagnoseSelectionFailure(
return selectionFailureDiagnosis{Category: "excluded"}
}
if !s.isAccountSchedulableForSelection(acc) {
return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"}
detail := "generic_unschedulable"
if acc.Platform == PlatformSora {
detail = s.soraUnschedulableReason(acc)
}
return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail}
}
if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) {
return selectionFailureDiagnosis{
......@@ -3353,6 +3591,57 @@ func (s *GatewayService) diagnoseSelectionFailure(
return selectionFailureDiagnosis{Category: "eligible"}
}
func (s *GatewayService) logSoraSelectionFailureDetails(
ctx context.Context,
groupID *int64,
sessionHash string,
requestedModel string,
accounts []Account,
excludedIDs map[int64]struct{},
allowMixedScheduling bool,
) {
const maxLines = 30
logged := 0
for i := range accounts {
if logged >= maxLines {
break
}
acc := &accounts[i]
diagnosis := s.diagnoseSelectionFailure(ctx, acc, requestedModel, PlatformSora, excludedIDs, allowMixedScheduling)
if diagnosis.Category == "eligible" {
continue
}
detail := diagnosis.Detail
if detail == "" {
detail = "-"
}
logger.LegacyPrintf(
"service.gateway",
"[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s account_id=%d account_platform=%s category=%s detail=%s",
derefGroupID(groupID),
requestedModel,
shortSessionHash(sessionHash),
acc.ID,
acc.Platform,
diagnosis.Category,
detail,
)
logged++
}
if len(accounts) > maxLines {
logger.LegacyPrintf(
"service.gateway",
"[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s truncated=true total=%d logged=%d",
derefGroupID(groupID),
requestedModel,
shortSessionHash(sessionHash),
len(accounts),
logged,
)
}
}
func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool {
if acc == nil {
return true
......@@ -3431,10 +3720,17 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
}
return mapAntigravityModel(account, requestedModel) != ""
}
if account.Platform == PlatformSora {
return s.isSoraModelSupportedByAccount(account, requestedModel)
}
if account.IsBedrock() {
_, ok := ResolveBedrockModelID(account, requestedModel)
return ok
}
// OpenAI 透传模式:仅替换认证,允许所有模型
if account.Platform == PlatformOpenAI && account.IsOpenAIPassthroughEnabled() {
return true
}
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
requestedModel = claude.NormalizeModelID(requestedModel)
......@@ -3443,6 +3739,143 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
return account.IsModelSupported(requestedModel)
}
func (s *GatewayService) isSoraModelSupportedByAccount(account *Account, requestedModel string) bool {
if account == nil {
return false
}
if strings.TrimSpace(requestedModel) == "" {
return true
}
// 先走原始精确/通配符匹配。
mapping := account.GetModelMapping()
if len(mapping) == 0 || account.IsModelSupported(requestedModel) {
return true
}
aliases := buildSoraModelAliases(requestedModel)
if len(aliases) == 0 {
return false
}
hasSoraSelector := false
for pattern := range mapping {
if !isSoraModelSelector(pattern) {
continue
}
hasSoraSelector = true
if matchPatternAnyAlias(pattern, aliases) {
return true
}
}
// 兼容旧账号:mapping 存在但未配置任何 Sora 选择器(例如只含 gpt-*),
// 此时不应误拦截 Sora 模型请求。
if !hasSoraSelector {
return true
}
return false
}
func matchPatternAnyAlias(pattern string, aliases []string) bool {
normalizedPattern := strings.ToLower(strings.TrimSpace(pattern))
if normalizedPattern == "" {
return false
}
for _, alias := range aliases {
if matchWildcard(normalizedPattern, alias) {
return true
}
}
return false
}
func isSoraModelSelector(pattern string) bool {
p := strings.ToLower(strings.TrimSpace(pattern))
if p == "" {
return false
}
switch {
case strings.HasPrefix(p, "sora"),
strings.HasPrefix(p, "gpt-image"),
strings.HasPrefix(p, "prompt-enhance"),
strings.HasPrefix(p, "sy_"):
return true
}
return p == "video" || p == "image"
}
func buildSoraModelAliases(requestedModel string) []string {
modelID := strings.ToLower(strings.TrimSpace(requestedModel))
if modelID == "" {
return nil
}
aliases := make([]string, 0, 8)
addAlias := func(value string) {
v := strings.ToLower(strings.TrimSpace(value))
if v == "" {
return
}
for _, existing := range aliases {
if existing == v {
return
}
}
aliases = append(aliases, v)
}
addAlias(modelID)
cfg, ok := GetSoraModelConfig(modelID)
if ok {
addAlias(cfg.Model)
switch cfg.Type {
case "video":
addAlias("video")
addAlias("sora")
addAlias(soraVideoFamilyAlias(modelID))
case "image":
addAlias("image")
addAlias("gpt-image")
case "prompt_enhance":
addAlias("prompt-enhance")
}
return aliases
}
switch {
case strings.HasPrefix(modelID, "sora"):
addAlias("video")
addAlias("sora")
addAlias(soraVideoFamilyAlias(modelID))
case strings.HasPrefix(modelID, "gpt-image"):
addAlias("image")
addAlias("gpt-image")
case strings.HasPrefix(modelID, "prompt-enhance"):
addAlias("prompt-enhance")
default:
return nil
}
return aliases
}
func soraVideoFamilyAlias(modelID string) string {
switch {
case strings.HasPrefix(modelID, "sora2pro-hd"):
return "sora2pro-hd"
case strings.HasPrefix(modelID, "sora2pro"):
return "sora2pro"
case strings.HasPrefix(modelID, "sora2"):
return "sora2"
default:
return ""
}
}
// GetAccessToken 获取账号凭证
func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type {
......@@ -3719,86 +4152,6 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
return result
}
// rewriteSystemForNonClaudeCode 将非 Claude Code 客户端的 system prompt 迁移至 messages,
// system 字段仅保留 Claude Code 标识提示词。
// Anthropic 基于 system 参数内容检测第三方应用,仅前置追加 Claude Code 提示词
// 无法通过检测,因为后续内容仍为非 Claude Code 格式。
// 策略:将原始 system prompt 提取并注入为 user/assistant 消息对,system 仅保留 Claude Code 标识。
func rewriteSystemForNonClaudeCode(body []byte, system any) []byte {
system = normalizeSystemParam(system)
// 1. 提取原始 system prompt 文本
var originalSystemText string
switch v := system.(type) {
case string:
originalSystemText = strings.TrimSpace(v)
case []any:
var parts []string
for _, item := range v {
if m, ok := item.(map[string]any); ok {
if text, ok := m["text"].(string); ok && strings.TrimSpace(text) != "" {
parts = append(parts, text)
}
}
}
originalSystemText = strings.Join(parts, "\n\n")
}
// 2. 将 system 替换为 Claude Code 标准提示词(array 格式,与真实 Claude Code 一致)
// 真实 Claude Code 始终以 [{type: "text", text: "...", cache_control: {type: "ephemeral"}}] 发送 system。
// 使用 string 格式会被 Anthropic 检测为第三方应用。
claudeCodeSystemBlock := []map[string]any{
{
"type": "text",
"text": claudeCodeSystemPrompt,
"cache_control": map[string]string{"type": "ephemeral"},
},
}
out, ok := setJSONValueBytes(body, "system", claudeCodeSystemBlock)
if !ok {
logger.LegacyPrintf("service.gateway", "Warning: failed to set Claude Code system prompt")
return body
}
// 3. 将原始 system prompt 作为 user/assistant 消息对注入到 messages 开头
// 模型仍通过 messages 接收完整指令,保留客户端功能
ccPromptTrimmed := strings.TrimSpace(claudeCodeSystemPrompt)
if originalSystemText != "" && originalSystemText != ccPromptTrimmed && !hasClaudeCodePrefix(originalSystemText) {
instrMsg, err1 := json.Marshal(map[string]any{
"role": "user",
"content": []map[string]any{
{"type": "text", "text": "[System Instructions]\n" + originalSystemText},
},
})
ackMsg, err2 := json.Marshal(map[string]any{
"role": "assistant",
"content": []map[string]any{
{"type": "text", "text": "Understood. I will follow these instructions."},
},
})
if err1 != nil || err2 != nil {
logger.LegacyPrintf("service.gateway", "Warning: failed to marshal system-to-messages injection")
return out
}
// 重建 messages 数组:[instruction, ack, ...originalMessages]
items := [][]byte{instrMsg, ackMsg}
messagesResult := gjson.GetBytes(out, "messages")
if messagesResult.IsArray() {
messagesResult.ForEach(func(_, msg gjson.Result) bool {
items = append(items, []byte(msg.Raw))
return true
})
}
if next, setOk := setJSONRawBytes(out, "messages", buildJSONArrayRaw(items)); setOk {
out = next
}
}
return out
}
type cacheControlPath struct {
path string
log string
......@@ -3960,7 +4313,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest.
// Always overwrite the cache to prevent stale values from a previous retry with a different account.
if account.Platform == PlatformAnthropic && c != nil {
policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account, parsed.Model)
policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account)
if policy.blockErr != nil {
return nil, policy.blockErr
}
......@@ -3990,24 +4343,19 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
if shouldMimicClaudeCode {
// Claude Code 客户端:将 system 替换为 Claude Code 标识,原始 system 迁移至 messages
// 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要)
// 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
systemRewritten := false
if !strings.Contains(strings.ToLower(reqModel), "haiku") &&
!systemIncludesClaudeCodePrompt(parsed.System) {
body = rewriteSystemForNonClaudeCode(body, parsed.System)
systemRewritten = true
body = injectClaudeCodePrompt(body, parsed.System)
}
// system 被重写时保留 CC prompt 的 cache_control: ephemeral(匹配真实 Claude Code 行为);
// 未重写时(haiku / 已含 CC 前缀)剥离客户端 cache_control,与原有行为一致。
// 两种情况下 enforceCacheControlLimit 都会兜底处理上限。
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: !systemRewritten}
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
if s.identityService != nil {
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if err == nil && fp != nil {
// metadata 透传开启时跳过 metadata 注入
_, mimicMPT, _ := s.settingService.GetGatewayForwardingSettings(ctx)
_, mimicMPT := s.settingService.GetGatewayForwardingSettings(ctx)
if !mimicMPT {
if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" {
normalizeOpts.injectMetadata = true
......@@ -4054,12 +4402,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return nil, err
}
// 获取代理URL(自定义 base URL 模式下,proxy 通过 buildCustomRelayURL 作为查询参数传递)
// 获取代理URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
if !account.IsCustomBaseURLEnabled() || account.GetCustomBaseURL() == "" {
proxyURL = account.Proxy.URL()
}
proxyURL = account.Proxy.URL()
}
// 解析 TLS 指纹 profile(同一请求生命周期内不变,避免重试循环中重复解析)
......@@ -4468,6 +4814,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
// 处理正常响应
ctx = withClaudeMaxResponseRewriteContext(ctx, c, parsed)
// 触发上游接受回调(提前释放串行锁,不等流完成)
if parsed.OnUpstreamAccepted != nil {
......@@ -5534,16 +5881,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
targetURL = validatedURL + "/v1/messages?beta=true"
}
} else if account.IsCustomBaseURLEnabled() {
customURL := account.GetCustomBaseURL()
if customURL == "" {
return nil, fmt.Errorf("custom_base_url is enabled but not configured for account %d", account.ID)
}
validatedURL, err := s.validateUpstreamBaseURL(customURL)
if err != nil {
return nil, err
}
targetURL = s.buildCustomRelayURL(validatedURL, "/v1/messages", account)
}
clientHeaders := http.Header{}
......@@ -5553,9 +5890,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// OAuth账号:应用统一指纹和metadata重写(受设置开关控制)
var fingerprint *Fingerprint
enableFP, enableMPT, enableCCH := true, false, false
enableFP, enableMPT := true, false
if s.settingService != nil {
enableFP, enableMPT, enableCCH = s.settingService.GetGatewayForwardingSettings(ctx)
enableFP, enableMPT = s.settingService.GetGatewayForwardingSettings(ctx)
}
if account.IsOAuth() && s.identityService != nil {
// 1. 获取或创建指纹(包含随机生成的ClientID)
......@@ -5582,15 +5919,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
// 同步 billing header cc_version 与实际发送的 User-Agent 版本
if fingerprint != nil {
body = syncBillingHeaderVersion(body, fingerprint.UserAgent)
}
// CCH 签名:将 cch=00000 占位符替换为 xxHash64 签名(需在所有 body 修改之后)
if enableCCH {
body = signBillingHeaderCCH(body)
}
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
return nil, err
......@@ -5631,8 +5959,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
// Build effective drop set: merge static defaults with dynamic beta policy filter rules
policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account, modelID)
policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account)
effectiveDropSet := mergeDropSets(policyFilterSet)
effectiveDropWithClaudeCodeSet := mergeDropSets(policyFilterSet, claude.BetaClaudeCode)
// 处理 anthropic-beta header(OAuth 账号需要包含 oauth beta)
if tokenType == "oauth" {
......@@ -5643,16 +5972,11 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
applyClaudeCodeMimicHeaders(req, reqStream)
incomingBeta := getHeaderRaw(req.Header, "anthropic-beta")
// Claude Code OAuth credentials are scoped to Claude Code.
// Non-haiku models MUST include claude-code beta for Anthropic to recognize
// this as a legitimate Claude Code request; without it, the request is
// rejected as third-party ("out of extra usage").
// Haiku models are exempt from third-party detection and don't need it.
// Match real Claude CLI traffic (per mitmproxy reports):
// messages requests typically use only oauth + interleaved-thinking.
// Also drop claude-code beta if a downstream client added it.
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
if !strings.Contains(strings.ToLower(modelID), "haiku") {
requiredBetas = []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking}
}
setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropSet))
setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropWithClaudeCodeSet))
} else {
// Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta
clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta")
......@@ -5672,15 +5996,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
// 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖
if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" {
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
if parsed := ParseMetadataUserID(uid); parsed != nil {
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
}
}
}
// === DEBUG: 打印上游转发请求(headers + body 摘要),与 CLIENT_ORIGINAL 对比 ===
s.debugLogGatewaySnapshot("UPSTREAM_FORWARD", req.Header, body, map[string]string{
"url": req.URL.String(),
......@@ -5875,7 +6190,7 @@ type betaPolicyResult struct {
}
// evaluateBetaPolicy loads settings once and evaluates all rules against the given request.
func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account, model string) betaPolicyResult {
func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account) betaPolicyResult {
if s.settingService == nil {
return betaPolicyResult{}
}
......@@ -5890,11 +6205,10 @@ func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader stri
if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
continue
}
effectiveAction, effectiveErrMsg := resolveRuleAction(rule, model)
switch effectiveAction {
switch rule.Action {
case BetaPolicyActionBlock:
if result.blockErr == nil && betaHeader != "" && containsBetaToken(betaHeader, rule.BetaToken) {
msg := effectiveErrMsg
msg := rule.ErrorMessage
if msg == "" {
msg = "beta feature " + rule.BetaToken + " is not allowed"
}
......@@ -5936,7 +6250,7 @@ const betaPolicyFilterSetKey = "betaPolicyFilterSet"
// In the /v1/messages path, Forward() evaluates the policy first and caches the result;
// buildUpstreamRequest reuses it (zero extra DB calls). In the count_tokens path, this
// evaluates on demand (one DB call).
func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account, model string) map[string]struct{} {
func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account) map[string]struct{} {
if c != nil {
if v, ok := c.Get(betaPolicyFilterSetKey); ok {
if fs, ok := v.(map[string]struct{}); ok {
......@@ -5944,7 +6258,7 @@ func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Cont
}
}
}
return s.evaluateBetaPolicy(ctx, "", account, model).filterSet
return s.evaluateBetaPolicy(ctx, "", account).filterSet
}
// betaPolicyScopeMatches checks whether a rule's scope matches the current account type.
......@@ -5963,33 +6277,6 @@ func betaPolicyScopeMatches(scope string, isOAuth bool, isBedrock bool) bool {
}
}
// matchModelWhitelist checks if a model matches any pattern in the whitelist.
// Reuses matchModelPattern from group.go which supports exact and wildcard prefix matching.
func matchModelWhitelist(model string, whitelist []string) bool {
for _, pattern := range whitelist {
if matchModelPattern(pattern, model) {
return true
}
}
return false
}
// resolveRuleAction determines the effective action and error message for a rule given the request model.
// When ModelWhitelist is empty, the rule's primary Action/ErrorMessage applies unconditionally.
// When non-empty, Action applies to matching models; FallbackAction/FallbackErrorMessage applies to others.
func resolveRuleAction(rule BetaPolicyRule, model string) (action, errorMessage string) {
if len(rule.ModelWhitelist) == 0 {
return rule.Action, rule.ErrorMessage
}
if matchModelWhitelist(model, rule.ModelWhitelist) {
return rule.Action, rule.ErrorMessage
}
if rule.FallbackAction != "" {
return rule.FallbackAction, rule.FallbackErrorMessage
}
return BetaPolicyActionPass, "" // default fallback: pass (fail-open)
}
// droppedBetaSet returns claude.DroppedBetas as a set, with optional extra tokens.
func droppedBetaSet(extra ...string) map[string]struct{} {
m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(extra))
......@@ -6036,7 +6323,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest(
modelID string,
) ([]string, error) {
// 1. 对原始 header 中的 beta token 做 block 检查(快速失败)
policy := s.evaluateBetaPolicy(ctx, betaHeader, account, modelID)
policy := s.evaluateBetaPolicy(ctx, betaHeader, account)
if policy.blockErr != nil {
return nil, policy.blockErr
}
......@@ -6048,7 +6335,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest(
// 例如:管理员 block 了 interleaved-thinking,客户端不在 header 中带该 token,
// 但请求体中包含 thinking 字段 → autoInjectBedrockBetaTokens 会自动补齐 →
// 如果不做此检查,block 规则会被绕过。
if blockErr := s.checkBetaPolicyBlockForTokens(ctx, betaTokens, account, modelID); blockErr != nil {
if blockErr := s.checkBetaPolicyBlockForTokens(ctx, betaTokens, account); blockErr != nil {
return nil, blockErr
}
......@@ -6057,7 +6344,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest(
// checkBetaPolicyBlockForTokens 检查 token 列表中是否有被管理员 block 规则命中的 token。
// 用于补充 evaluateBetaPolicy 对 header 的检查,覆盖 body 自动注入的 token。
func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, tokens []string, account *Account, model string) *BetaBlockedError {
func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, tokens []string, account *Account) *BetaBlockedError {
if s.settingService == nil || len(tokens) == 0 {
return nil
}
......@@ -6069,15 +6356,14 @@ func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, toke
isBedrock := account.IsBedrock()
tokenSet := buildBetaTokenSet(tokens)
for _, rule := range settings.Rules {
effectiveAction, effectiveErrMsg := resolveRuleAction(rule, model)
if effectiveAction != BetaPolicyActionBlock {
if rule.Action != BetaPolicyActionBlock {
continue
}
if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
continue
}
if _, present := tokenSet[rule.BetaToken]; present {
msg := effectiveErrMsg
msg := rule.ErrorMessage
if msg == "" {
msg = "beta feature " + rule.BetaToken + " is not allowed"
}
......@@ -6709,6 +6995,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
needModelReplace := originalModel != mappedModel
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
sawTerminalEvent := false
skipAccountTTLOverride := false
pendingEventLines := make([]string, 0, 4)
......@@ -6770,17 +7057,25 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
if msg, ok := event["message"].(map[string]any); ok {
if u, ok := msg["usage"].(map[string]any); ok {
eventChanged = reconcileCachedTokens(u) || eventChanged
claudeMaxOutcome := applyClaudeMaxSimulationToUsageJSONMap(ctx, u, originalModel, account.ID)
if claudeMaxOutcome.Simulated {
skipAccountTTLOverride = true
}
}
}
}
if eventType == "message_delta" {
if u, ok := event["usage"].(map[string]any); ok {
eventChanged = reconcileCachedTokens(u) || eventChanged
claudeMaxOutcome := applyClaudeMaxSimulationToUsageJSONMap(ctx, u, originalModel, account.ID)
if claudeMaxOutcome.Simulated {
skipAccountTTLOverride = true
}
}
}
// Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类
if account.IsCacheTTLOverrideEnabled() {
if account.IsCacheTTLOverrideEnabled() && !skipAccountTTLOverride {
overrideTarget := account.GetCacheTTLOverrideTarget()
if eventType == "message_start" {
if msg, ok := event["message"].(map[string]any); ok {
......@@ -7212,8 +7507,13 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
}
}
claudeMaxOutcome := applyClaudeMaxSimulationToUsage(ctx, &response.Usage, originalModel, account.ID)
if claudeMaxOutcome.Simulated {
body = rewriteClaudeUsageJSONBytes(body, response.Usage)
}
// Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类
if account.IsCacheTTLOverrideEnabled() {
if account.IsCacheTTLOverrideEnabled() && !claudeMaxOutcome.Simulated {
overrideTarget := account.GetCacheTTLOverrideTarget()
if applyCacheTTLOverride(&response.Usage, overrideTarget) {
// 同步更新 body JSON 中的嵌套 cache_creation 对象
......@@ -7279,6 +7579,7 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID,
// RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct {
Result *ForwardResult
ParsedRequest *ParsedRequest
APIKey *APIKey
User *User
Account *Account
......@@ -7437,6 +7738,9 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
cmd.CacheCreationTokens = usageLog.CacheCreationTokens
cmd.CacheReadTokens = usageLog.CacheReadTokens
cmd.ImageCount = usageLog.ImageCount
if usageLog.MediaType != nil {
cmd.MediaType = *usageLog.MediaType
}
if usageLog.ServiceTier != nil {
cmd.ServiceTier = *usageLog.ServiceTier
}
......@@ -7592,6 +7896,8 @@ type recordUsageOpts struct {
// EnableClaudePath 启用 Claude 路径特有逻辑:
// - Claude Max 缓存计费策略
// - Sora 媒体类型分支(image/video/prompt)
// - MediaType 字段写入使用日志
EnableClaudePath bool
// 长上下文计费(仅 Gemini 路径需要)
......@@ -7616,6 +7922,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
APIKeyService: input.APIKeyService,
ChannelUsageFields: input.ChannelUsageFields,
}, &recordUsageOpts{
ParsedRequest: input.ParsedRequest,
EnableClaudePath: true,
})
}
......@@ -7682,6 +7989,7 @@ type recordUsageCoreInput struct {
// recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。
// opts 中的字段控制两者之间的差异行为:
// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略
// - EnableSoraMedia → 启用 Sora MediaType 分支(image/video/prompt)
// - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext
func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error {
result := input.Result
......@@ -7699,9 +8007,21 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
result.Usage.InputTokens = 0
}
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
// Claude Max cache billing policy(仅 Claude 路径启用)
cacheTTLOverridden := false
if account.IsCacheTTLOverrideEnabled() {
simulatedClaudeMax := false
if opts.EnableClaudePath {
var apiKeyGroup *Group
if apiKey != nil {
apiKeyGroup = apiKey.Group
}
claudeMaxOutcome := applyClaudeMaxCacheBillingPolicyToUsage(&result.Usage, opts.ParsedRequest, apiKeyGroup, result.Model, account.ID)
simulatedClaudeMax = claudeMaxOutcome.Simulated ||
(shouldApplyClaudeMaxBillingRulesForUsage(apiKeyGroup, result.Model, opts.ParsedRequest) && hasCacheCreationTokens(result.Usage))
}
// Cache TTL Override: 确保计费时 token 分类与账号设置一致
if account.IsCacheTTLOverrideEnabled() && !simulatedClaudeMax {
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
}
......@@ -7783,6 +8103,16 @@ func (s *GatewayService) calculateRecordUsageCost(
multiplier float64,
opts *recordUsageOpts,
) *CostBreakdown {
// Sora 媒体类型分支(仅 Claude 路径启用)
if opts.EnableClaudePath {
if result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo {
return s.calculateSoraMediaCost(result, apiKey, billingModel, multiplier)
}
if result.MediaType == MediaTypePrompt {
return &CostBreakdown{}
}
}
// 图片生成计费
if result.ImageCount > 0 {
return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier)
......@@ -7792,6 +8122,28 @@ func (s *GatewayService) calculateRecordUsageCost(
return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts)
}
// calculateSoraMediaCost 计算 Sora 图片/视频的费用。
func (s *GatewayService) calculateSoraMediaCost(
result *ForwardResult,
apiKey *APIKey,
billingModel string,
multiplier float64,
) *CostBreakdown {
var soraConfig *SoraPriceConfig
if apiKey.Group != nil {
soraConfig = &SoraPriceConfig{
ImagePrice360: apiKey.Group.SoraImagePrice360,
ImagePrice540: apiKey.Group.SoraImagePrice540,
VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
}
}
if result.MediaType == MediaTypeImage {
return s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
}
return s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
}
// resolveChannelPricing 检查指定模型是否存在渠道级别定价。
// 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。
func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing {
......@@ -7814,7 +8166,7 @@ func (s *GatewayService) calculateImageCost(
billingModel string,
multiplier float64,
) *CostBreakdown {
if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil {
if s.resolveChannelPricing(ctx, billingModel, apiKey) != nil {
tokens := UsageTokens{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
......@@ -7829,7 +8181,6 @@ func (s *GatewayService) calculateImageCost(
RequestCount: 1,
RateMultiplier: multiplier,
Resolver: s.resolver,
Resolved: resolved,
})
if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate image token cost failed: %v", err)
......@@ -7872,7 +8223,7 @@ func (s *GatewayService) calculateTokenCost(
var err error
// 优先尝试渠道定价 → CalculateCostUnified
if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil {
if s.resolveChannelPricing(ctx, billingModel, apiKey) != nil {
gid := apiKey.Group.ID
cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx,
......@@ -7882,7 +8233,6 @@ func (s *GatewayService) calculateTokenCost(
RequestCount: 1,
RateMultiplier: multiplier,
Resolver: s.resolver,
Resolved: resolved,
})
} else if opts.LongContextThreshold > 0 {
// 长上下文双倍计费(如 Gemini 200K 阈值)
......@@ -7940,12 +8290,13 @@ func (s *GatewayService) buildRecordUsageLog(
RateMultiplier: multiplier,
AccountRateMultiplier: &accountRateMultiplier,
BillingType: billingType,
BillingMode: resolveBillingMode(result, cost),
BillingMode: resolveBillingMode(opts, result, cost),
Stream: result.Stream,
DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs,
ImageCount: result.ImageCount,
ImageSize: optionalTrimmedStringPtr(result.ImageSize),
MediaType: resolveMediaType(opts, result),
CacheTTLOverridden: cacheTTLOverridden,
ChannelID: optionalInt64Ptr(input.ChannelID),
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
......@@ -7969,7 +8320,13 @@ func (s *GatewayService) buildRecordUsageLog(
}
// resolveBillingMode 根据计费结果和请求类型确定计费模式。
func resolveBillingMode(result *ForwardResult, cost *CostBreakdown) *string {
// Sora 媒体类型自身已确定计费模式(由上游处理),返回 nil 跳过。
func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string {
isSoraMedia := opts.EnableClaudePath &&
(result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo || result.MediaType == MediaTypePrompt)
if isSoraMedia {
return nil
}
var mode string
switch {
case cost != nil && cost.BillingMode != "":
......@@ -7982,6 +8339,13 @@ func resolveBillingMode(result *ForwardResult, cost *CostBreakdown) *string {
return &mode
}
func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string {
if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" {
return &result.MediaType
}
return nil
}
func optionalSubscriptionID(subscription *UserSubscription) *int64 {
if subscription != nil {
return &subscription.ID
......@@ -8010,8 +8374,8 @@ func (s *GatewayService) IsModelRestricted(ctx context.Context, groupID int64, m
return s.channelService.IsModelRestricted(ctx, groupID, model)
}
// ResolveChannelMappingAndRestrict 解析渠道映射。
// 模型限制检查已移至调度阶段(checkChannelPricingRestriction),restricted 始终返回 false
// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制
// 返回映射结果和是否被限制
func (s *GatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
if s.channelService == nil {
return ChannelMappingResult{MappedModel: model}, false
......@@ -8042,9 +8406,7 @@ func billingModelForRestriction(source, requestedModel, channelMappedModel strin
return requestedModel
case BillingModelSourceUpstream:
return ""
case BillingModelSourceChannelMapped:
return channelMappedModel
default:
default: // channel_mapped
return channelMappedModel
}
}
......@@ -8076,11 +8438,7 @@ func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Contex
return false
}
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
if err != nil {
slog.Warn("failed to check channel upstream restriction", "group_id", *groupID, "error", err)
return false
}
if ch == nil || !ch.RestrictModels {
if err != nil || ch == nil || !ch.RestrictModels {
return false
}
return ch.BillingModelSource == BillingModelSourceUpstream
......@@ -8172,12 +8530,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
return err
}
// 获取代理URL(自定义 base URL 模式下,proxy 通过 buildCustomRelayURL 作为查询参数传递)
// 获取代理URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
if !account.IsCustomBaseURLEnabled() || account.GetCustomBaseURL() == "" {
proxyURL = account.Proxy.URL()
}
proxyURL = account.Proxy.URL()
}
// 发送请求
......@@ -8456,16 +8812,6 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
targetURL = validatedURL + "/v1/messages/count_tokens?beta=true"
}
} else if account.IsCustomBaseURLEnabled() {
customURL := account.GetCustomBaseURL()
if customURL == "" {
return nil, fmt.Errorf("custom_base_url is enabled but not configured for account %d", account.ID)
}
validatedURL, err := s.validateUpstreamBaseURL(customURL)
if err != nil {
return nil, err
}
targetURL = s.buildCustomRelayURL(validatedURL, "/v1/messages/count_tokens", account)
}
clientHeaders := http.Header{}
......@@ -8475,9 +8821,9 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:应用统一指纹和重写 userID(受设置开关控制)
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
ctEnableFP, ctEnableMPT, ctEnableCCH := true, false, false
ctEnableFP, ctEnableMPT := true, false
if s.settingService != nil {
ctEnableFP, ctEnableMPT, ctEnableCCH = s.settingService.GetGatewayForwardingSettings(ctx)
ctEnableFP, ctEnableMPT = s.settingService.GetGatewayForwardingSettings(ctx)
}
var ctFingerprint *Fingerprint
if account.IsOAuth() && s.identityService != nil {
......@@ -8495,14 +8841,6 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
// 同步 billing header cc_version 与实际发送的 User-Agent 版本
if ctFingerprint != nil && ctEnableFP {
body = syncBillingHeaderVersion(body, ctFingerprint.UserAgent)
}
if ctEnableCCH {
body = signBillingHeaderCCH(body)
}
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
return nil, err
......@@ -8543,7 +8881,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
// Build effective drop set for count_tokens: merge static defaults with dynamic beta policy filter rules
ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account, modelID))
ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account))
// OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" {
......@@ -8579,15 +8917,6 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
// 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖
if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" {
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
if parsed := ParseMetadataUserID(uid); parsed != nil {
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
}
}
}
if c != nil && tokenType == "oauth" {
c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode))
}
......@@ -8609,19 +8938,6 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m
})
}
// buildCustomRelayURL 构建自定义中继转发 URL
// 在 path 后附加 beta=true 和可选的 proxy 查询参数
func (s *GatewayService) buildCustomRelayURL(baseURL, path string, account *Account) string {
u := strings.TrimRight(baseURL, "/") + path + "?beta=true"
if account.ProxyID != nil && account.Proxy != nil {
proxyURL := account.Proxy.URL()
if proxyURL != "" {
u += "&proxy=" + url.QueryEscape(proxyURL)
}
}
return u
}
func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment