Commit 37c23ecc authored by erio's avatar erio
Browse files

fix: gofmt formatting

parent e3748741
...@@ -134,7 +134,7 @@ func (r ChannelMappingResult) ToUsageFields(reqModel, upstreamModel string) Chan ...@@ -134,7 +134,7 @@ func (r ChannelMappingResult) ToUsageFields(reqModel, upstreamModel string) Chan
const ( const (
channelCacheTTL = 10 * time.Minute channelCacheTTL = 10 * time.Minute
channelErrorTTL = 5 * time.Second // DB 错误时的短缓存 channelErrorTTL = 5 * time.Second // DB 错误时的短缓存
channelCacheDBTimeout = 10 * time.Second channelCacheDBTimeout = 10 * time.Second
) )
......
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
"log/slog" "log/slog"
mathrand "math/rand" mathrand "math/rand"
"net/http" "net/http"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"regexp" "regexp"
...@@ -41,8 +42,7 @@ import ( ...@@ -41,8 +42,7 @@ import (
const ( const (
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true" claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true" claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL = time.Hour // 粘性会话TTL stickySessionTTL = time.Hour // 粘性会话TTL
ClientAffinityTTL = 24 * time.Hour // 客户端亲和TTL
defaultMaxLineSize = 500 * 1024 * 1024 defaultMaxLineSize = 500 * 1024 * 1024
// Canonical Claude Code banner. Keep it EXACT (no trailing whitespace/newlines) // Canonical Claude Code banner. Keep it EXACT (no trailing whitespace/newlines)
// to match real Claude CLI traffic as closely as possible. When we need a visual // to match real Claude CLI traffic as closely as possible. When we need a visual
...@@ -60,28 +60,14 @@ const ( ...@@ -60,28 +60,14 @@ const (
claudeMimicDebugInfoKey = "claude_mimic_debug_info" claudeMimicDebugInfoKey = "claude_mimic_debug_info"
) )
// MediaType 媒体类型常量
const (
MediaTypeImage = "image"
MediaTypeVideo = "video"
MediaTypePrompt = "prompt"
)
const (
claudeMaxMessageOverheadTokens = 3
claudeMaxBlockOverheadTokens = 1
claudeMaxUnknownContentTokens = 4
)
// ForceCacheBillingContextKey 强制缓存计费上下文键 // ForceCacheBillingContextKey 强制缓存计费上下文键
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费 // 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
type forceCacheBillingKeyType struct{} type forceCacheBillingKeyType struct{}
// accountWithLoad 账号与负载信息的组合,用于负载感知调度 // accountWithLoad 账号与负载信息的组合,用于负载感知调度
type accountWithLoad struct { type accountWithLoad struct {
account *Account account *Account
loadInfo *AccountLoadInfo loadInfo *AccountLoadInfo
affinityCount int64 // 亲和客户端数量(反向索引),越少越优先
} }
var ForceCacheBillingContextKey = forceCacheBillingKeyType{} var ForceCacheBillingContextKey = forceCacheBillingKeyType{}
...@@ -345,10 +331,6 @@ var ( ...@@ -345,10 +331,6 @@ var (
sseDataRe = regexp.MustCompile(`^data:\s*`) sseDataRe = regexp.MustCompile(`^data:\s*`)
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`) claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
// clientIDFromMetadataRegex 从 metadata.user_id 中提取客户端 ID(64位 hex)
// 格式: user_{64位hex}_account_...
clientIDFromMetadataRegex = regexp.MustCompile(`^user_([a-f0-9]{64})_account_`)
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表 // claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等 // 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
// 注意:前缀之间不应存在包含关系,否则会导致冗余匹配 // 注意:前缀之间不应存在包含关系,否则会导致冗余匹配
...@@ -366,12 +348,6 @@ var ErrNoAvailableAccounts = errors.New("no available accounts") ...@@ -366,12 +348,6 @@ var ErrNoAvailableAccounts = errors.New("no available accounts")
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问 // ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients") var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
// ErrAffinityNoSwitch 表示亲和账号不可用且不允许切换到其他账号
var ErrAffinityNoSwitch = errors.New("affinity account unavailable and switching is disabled")
// ErrAffinityLimitExceeded 表示亲和客户端限制已达上限
var ErrAffinityLimitExceeded = errors.New("affinity client limit exceeded")
// allowedHeaders 白名单headers(参考CRS项目) // allowedHeaders 白名单headers(参考CRS项目)
var allowedHeaders = map[string]bool{ var allowedHeaders = map[string]bool{
"accept": true, "accept": true,
...@@ -393,6 +369,8 @@ var allowedHeaders = map[string]bool{ ...@@ -393,6 +369,8 @@ var allowedHeaders = map[string]bool{
"user-agent": true, "user-agent": true,
"content-type": true, "content-type": true,
"accept-encoding": true, "accept-encoding": true,
"x-claude-code-session-id": true,
"x-client-request-id": true,
} }
// GatewayCache 定义网关服务的缓存操作接口。 // GatewayCache 定义网关服务的缓存操作接口。
...@@ -413,39 +391,6 @@ type GatewayCache interface { ...@@ -413,39 +391,6 @@ type GatewayCache interface {
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理 // DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
// Delete sticky session binding, used to proactively clean up when account becomes unavailable // Delete sticky session binding, used to proactively clean up when account becomes unavailable
DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
// GetAffinityAccounts 获取亲和账号列表(按最近使用降序),同时清理过期成员
GetAffinityAccounts(ctx context.Context, groupID int64, userID int64, clientID string, ttl time.Duration) ([]int64, error)
// UpdateAffinity 添加/更新亲和关系(更新 score 为当前时间戳,刷新 key TTL)
UpdateAffinity(ctx context.Context, groupID int64, userID int64, clientID string, accountID int64, ttl time.Duration) error
// GetAccountAffinityCountBatch 批量获取账号的亲和成员数量(惰性清理过期成员)
GetAccountAffinityCountBatch(ctx context.Context, groupID int64, accountIDs []int64, ttl time.Duration) (map[int64]int64, error)
// GetAccountAffinityClientsBatch 批量获取每个账号跨所有分组的亲和成员列表(去重)
// accountGroups: map[accountID][]groupID
// 返回值成员格式为 {userID}/{clientID}
GetAccountAffinityClientsBatch(ctx context.Context, accountGroups map[int64][]int64, ttl time.Duration) (map[int64][]string, error)
// GetAccountAffinityClientsWithScores 获取单个账号跨所有分组的亲和客户端列表(含最后活跃时间)
GetAccountAffinityClientsWithScores(ctx context.Context, accountID int64, groupIDs []int64, ttl time.Duration) ([]AffinityClient, error)
// ClearAccountAffinity 清除指定账号在所有分组的亲和记录(正向+反向索引)
// 用于账号关闭亲和时立即清理旧绑定
ClearAccountAffinity(ctx context.Context, accountID int64, groupIDs []int64) error
// GetAffinityMultiCount 获取账号的多维度亲和计数
// 返回: uniqueUsers, uniqueClients, perUserClients
GetAffinityMultiCount(ctx context.Context, groupID int64, accountID int64, targetUserID int64, ttl time.Duration) (users, clients, perUser int64, err error)
}
// AffinityClient 亲和客户端信息(含用户 ID 和最后活跃时间)
type AffinityClient struct {
UserID int64 `json:"user_id"`
ClientID string `json:"client_id"`
LastActive time.Time `json:"last_active"`
}
// SortAffinityClients 按最后活跃时间降序排序
func SortAffinityClients(clients []AffinityClient) {
sort.Slice(clients, func(i, j int) bool {
return clients[i].LastActive.After(clients[j].LastActive)
})
} }
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil // derefGroupID safely dereferences *int64 to int64, returning 0 if nil
...@@ -516,20 +461,6 @@ func shouldClearStickySession(account *Account, requestedModel string) bool { ...@@ -516,20 +461,6 @@ func shouldClearStickySession(account *Account, requestedModel string) bool {
return false return false
} }
// extractClientIDFromMetadata 从 metadata.user_id 中提取客户端 ID(64位 hex)。
// 格式: user_{64位hex}_account_..._session_...
// 返回空字符串表示无法提取(非 Claude Code/Console 客户端)。
func extractClientIDFromMetadata(metadataUserID string) string {
if metadataUserID == "" {
return ""
}
matches := clientIDFromMetadataRegex.FindStringSubmatch(metadataUserID)
if matches == nil {
return ""
}
return matches[1]
}
type AccountWaitPlan struct { type AccountWaitPlan struct {
AccountID int64 AccountID int64
MaxConcurrency int MaxConcurrency int
...@@ -572,10 +503,6 @@ type ForwardResult struct { ...@@ -572,10 +503,6 @@ type ForwardResult struct {
// 图片生成计费字段(图片生成模型使用) // 图片生成计费字段(图片生成模型使用)
ImageCount int // 生成的图片数量 ImageCount int // 生成的图片数量
ImageSize string // 图片尺寸 "1K", "2K", "4K" ImageSize string // 图片尺寸 "1K", "2K", "4K"
// Sora 媒体字段
MediaType string // image / video / prompt
MediaURL string // 生成后的媒体地址(可选)
} }
// UpstreamFailoverError indicates an upstream error that should trigger account failover. // UpstreamFailoverError indicates an upstream error that should trigger account failover.
...@@ -1315,10 +1242,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1315,10 +1242,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
} }
} }
// 提取客户端 ID(用于客户端亲和调度)
affinityClientID := extractClientIDFromMetadata(metadataUserID)
affinityUserID := sub2apiUserID
if s.debugModelRoutingEnabled() && requestedModel != "" { if s.debugModelRoutingEnabled() && requestedModel != "" {
groupPlatform := "" groupPlatform := ""
if group != nil { if group != nil {
...@@ -1340,10 +1263,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1340,10 +1263,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if err != nil { if err != nil {
return nil, err return nil, err
} }
if shouldFilterAccountWithoutClientID(account, affinityClientID) {
localExcluded[account.ID] = struct{}{}
continue
}
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
...@@ -1405,7 +1324,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1405,7 +1324,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if err != nil { if err != nil {
return nil, err return nil, err
} }
accounts = filterAccountsWithoutClientID(accounts, affinityClientID)
if len(accounts) == 0 { if len(accounts) == 0 {
return nil, ErrNoAvailableAccounts return nil, ErrNoAvailableAccounts
} }
...@@ -1424,19 +1342,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1424,19 +1342,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
_, excluded := excludedIDs[accountID] _, excluded := excludedIDs[accountID]
return excluded return excluded
} }
affinityFlow := newGatewayAffinityFlow(
s,
ctx,
groupID,
sessionHash,
requestedModel,
affinityClientID,
affinityUserID,
platform,
useMixed,
accountByID,
isExcluded,
)
// 获取模型路由配置(仅 anthropic 平台) // 获取模型路由配置(仅 anthropic 平台)
var routingAccountIDs []int64 var routingAccountIDs []int64
...@@ -1599,10 +1504,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1599,10 +1504,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
} }
if len(routingAvailable) > 0 { if len(routingAvailable) > 0 {
// 批量获取亲和客户端数量 // 排序:优先级 > 负载率 > 最后使用时间
s.populateAffinityCounts(ctx, routingAvailable, derefGroupID(groupID))
// 排序:优先级 > 负载率 > 亲和客户端数 > 最后使用时间
sort.SliceStable(routingAvailable, func(i, j int) bool { sort.SliceStable(routingAvailable, func(i, j int) bool {
a, b := routingAvailable[i], routingAvailable[j] a, b := routingAvailable[i], routingAvailable[j]
if a.account.Priority != b.account.Priority { if a.account.Priority != b.account.Priority {
...@@ -1611,9 +1513,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1611,9 +1513,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if a.loadInfo.LoadRate != b.loadInfo.LoadRate { if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return a.loadInfo.LoadRate < b.loadInfo.LoadRate return a.loadInfo.LoadRate < b.loadInfo.LoadRate
} }
if a.affinityCount != b.affinityCount {
return a.affinityCount < b.affinityCount
}
switch { switch {
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
return true return true
...@@ -1639,9 +1538,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1639,9 +1538,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL) _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
} }
if affinityClientID != "" && affinityUserID > 0 && s.cache != nil && item.account.IsAffinityEnabled() {
_ = s.cache.UpdateAffinity(ctx, derefGroupID(groupID), affinityUserID, affinityClientID, item.account.ID, ClientAffinityTTL)
}
if s.debugModelRoutingEnabled() { if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
} }
...@@ -1679,22 +1575,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1679,22 +1575,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
} }
} }
// ============ Layer 1.3: 用户亲和预处理(pinned_users 自动注入) ============ // ============ Layer 1.5: 粘性会话(仅在无模型路由配置时生效) ============
affinityFlow.preprocessPinnedUsers(accounts) if len(routingAccountIDs) == 0 && sessionHash != "" && stickyAccountID > 0 && !isExcluded(stickyAccountID) {
// ============ Layer 1.4: 客户端亲和调度(优先于粘性会话) ============
affinityHit := false
if affinityResult, hit, err := affinityFlow.trySelectAffinityAccount(); err != nil {
return nil, err
} else {
affinityHit = hit
if affinityResult != nil {
return affinityResult, nil
}
}
// ============ Layer 1.5: 粘性会话(仅在无模型路由配置 且 亲和未命中时生效) ============
if !affinityHit && len(routingAccountIDs) == 0 && sessionHash != "" && stickyAccountID > 0 && !isExcluded(stickyAccountID) {
accountID := stickyAccountID accountID := stickyAccountID
if accountID > 0 && !isExcluded(accountID) { if accountID > 0 && !isExcluded(accountID) {
account, ok := accountByID[accountID] account, ok := accountByID[accountID]
...@@ -1800,9 +1682,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1800,9 +1682,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil { if err != nil {
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok { if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok {
if affinityClientID != "" && affinityUserID > 0 && s.cache != nil && result.Account != nil && result.Account.IsAffinityEnabled() {
_ = s.cache.UpdateAffinity(ctx, derefGroupID(groupID), affinityUserID, affinityClientID, result.Account.ID, ClientAffinityTTL)
}
return result, nil return result, nil
} }
} else { } else {
...@@ -1820,37 +1699,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1820,37 +1699,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
} }
} }
// 批量获取亲和客户端数量(用于均衡分配新客户端) // 分层过滤选择:优先级 → 负载率 → LRU
s.populateAffinityCounts(ctx, available, derefGroupID(groupID))
// 分层过滤选择:优先级 → 亲和三区 → 负载率 → 亲和客户端数 → LRU
for len(available) > 0 { for len(available) > 0 {
// 1. 取优先级最小的集合 // 1. 取优先级最小的集合
candidates := filterByMinPriority(available) candidates := filterByMinPriority(available)
// 2. 按亲和三区过滤:绿区优先 → 黄区降级 → 红区移除(在同优先级内) // 2. 取负载率最低的集合
candidates = classifyByAffinityZone(candidates)
if len(candidates) == 0 {
// 当前优先级组全部在红区,移除后回退到下一优先级组
minPri := available[0].account.Priority
for _, a := range available[1:] {
if a.account.Priority < minPri {
minPri = a.account.Priority
}
}
newAvailable := make([]accountWithLoad, 0, len(available))
for _, a := range available {
if a.account.Priority != minPri {
newAvailable = append(newAvailable, a)
}
}
available = newAvailable
continue
}
// 3. 取负载率最低的集合
candidates = filterByMinLoadRate(candidates) candidates = filterByMinLoadRate(candidates)
// 3. 取亲和客户端数最少的集合 // 3. LRU 选择最久未用的账号
candidates = filterByMinAffinityCount(candidates)
// 4. LRU 选择最久未用的账号
selected := selectByLRU(candidates, preferOAuth) selected := selectByLRU(candidates, preferOAuth)
if selected == nil { if selected == nil {
break break
...@@ -1865,10 +1720,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1865,10 +1720,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL) _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
} }
// 更新亲和关系
if affinityClientID != "" && affinityUserID > 0 && s.cache != nil && selected.account.IsAffinityEnabled() {
_ = s.cache.UpdateAffinity(ctx, derefGroupID(groupID), affinityUserID, affinityClientID, selected.account.ID, ClientAffinityTTL)
}
return &AccountSelectionResult{ return &AccountSelectionResult{
Account: selected.account, Account: selected.account,
Acquired: true, Acquired: true,
...@@ -2077,9 +1928,6 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr ...@@ -2077,9 +1928,6 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr
} }
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) { func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
if platform == PlatformSora {
return s.listSoraSchedulableAccounts(ctx, groupID)
}
if s.schedulerSnapshot != nil { if s.schedulerSnapshot != nil {
accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err == nil { if err == nil {
...@@ -2176,53 +2024,6 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i ...@@ -2176,53 +2024,6 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
return accounts, useMixed, nil return accounts, useMixed, nil
} }
func (s *GatewayService) listSoraSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, bool, error) {
const useMixed = false
var accounts []Account
var err error
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora)
} else if groupID != nil {
accounts, err = s.accountRepo.ListByGroup(ctx, *groupID)
} else {
accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora)
}
if err != nil {
slog.Debug("account_scheduling_list_failed",
"group_id", derefGroupID(groupID),
"platform", PlatformSora,
"error", err)
return nil, useMixed, err
}
filtered := make([]Account, 0, len(accounts))
for _, acc := range accounts {
if acc.Platform != PlatformSora {
continue
}
if !s.isSoraAccountSchedulable(&acc) {
continue
}
filtered = append(filtered, acc)
}
slog.Debug("account_scheduling_list_sora",
"group_id", derefGroupID(groupID),
"platform", PlatformSora,
"raw_count", len(accounts),
"filtered_count", len(filtered))
for _, acc := range filtered {
slog.Debug("account_scheduling_account_detail",
"account_id", acc.ID,
"name", acc.Name,
"platform", acc.Platform,
"type", acc.Type,
"status", acc.Status,
"tls_fingerprint", acc.IsTLSFingerprintEnabled())
}
return filtered, useMixed, nil
}
// IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。 // IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。
// 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context, // 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context,
// 避免单账号分组收到 503 时错误地设置模型限流标记导致后续请求连续快速失败。 // 避免单账号分组收到 503 时错误地设置模型限流标记导致后续请求连续快速失败。
...@@ -2247,33 +2048,10 @@ func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform ...@@ -2247,33 +2048,10 @@ func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform
return account.Platform == platform return account.Platform == platform
} }
func (s *GatewayService) isSoraAccountSchedulable(account *Account) bool {
return s.soraUnschedulableReason(account) == ""
}
func (s *GatewayService) soraUnschedulableReason(account *Account) string {
if account == nil {
return "account_nil"
}
if account.Status != StatusActive {
return fmt.Sprintf("status=%s", account.Status)
}
if !account.Schedulable {
return "schedulable=false"
}
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
return fmt.Sprintf("temp_unschedulable_until=%s", account.TempUnschedulableUntil.UTC().Format(time.RFC3339))
}
return ""
}
func (s *GatewayService) isAccountSchedulableForSelection(account *Account) bool { func (s *GatewayService) isAccountSchedulableForSelection(account *Account) bool {
if account == nil { if account == nil {
return false return false
} }
if account.Platform == PlatformSora {
return s.isSoraAccountSchedulable(account)
}
return account.IsSchedulable() return account.IsSchedulable()
} }
...@@ -2281,12 +2059,6 @@ func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Conte ...@@ -2281,12 +2059,6 @@ func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Conte
if account == nil { if account == nil {
return false return false
} }
if account.Platform == PlatformSora {
if !s.isSoraAccountSchedulable(account) {
return false
}
return account.GetRateLimitRemainingTimeWithContext(ctx, requestedModel) <= 0
}
return account.IsSchedulableForModelWithContext(ctx, requestedModel) return account.IsSchedulableForModelWithContext(ctx, requestedModel)
} }
...@@ -2626,36 +2398,6 @@ func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID in ...@@ -2626,36 +2398,6 @@ func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID in
return s.accountRepo.GetByID(ctx, accountID) return s.accountRepo.GetByID(ctx, accountID)
} }
// populateAffinityCounts 批量获取账号的亲和客户端数量并填入 accountWithLoad 切片。
// 仅当存在开启了客户端亲和的账号时才查询 Redis,否则跳过。
func (s *GatewayService) populateAffinityCounts(ctx context.Context, accounts []accountWithLoad, groupID int64) {
if s.cache == nil || len(accounts) == 0 {
return
}
// 快速检查:是否有任何账号开启了亲和
hasAffinity := false
for _, acc := range accounts {
if acc.account.IsAffinityEnabled() {
hasAffinity = true
break
}
}
if !hasAffinity {
return
}
accountIDs := make([]int64, len(accounts))
for i, acc := range accounts {
accountIDs[i] = acc.account.ID
}
countMap, err := s.cache.GetAccountAffinityCountBatch(ctx, groupID, accountIDs, ClientAffinityTTL)
if err != nil {
return // 查询失败不影响调度,affinityCount 保持 0
}
for i := range accounts {
accounts[i].affinityCount = countMap[accounts[i].account.ID]
}
}
// filterByMinPriority 过滤出优先级最小的账号集合 // filterByMinPriority 过滤出优先级最小的账号集合
func filterByMinPriority(accounts []accountWithLoad) []accountWithLoad { func filterByMinPriority(accounts []accountWithLoad) []accountWithLoad {
if len(accounts) == 0 { if len(accounts) == 0 {
...@@ -2696,64 +2438,6 @@ func filterByMinLoadRate(accounts []accountWithLoad) []accountWithLoad { ...@@ -2696,64 +2438,6 @@ func filterByMinLoadRate(accounts []accountWithLoad) []accountWithLoad {
return result return result
} }
// filterByMinAffinityCount 过滤出亲和客户端数最少的账号集合
func filterByMinAffinityCount(accounts []accountWithLoad) []accountWithLoad {
if len(accounts) == 0 {
return accounts
}
minCount := accounts[0].affinityCount
for _, acc := range accounts[1:] {
if acc.affinityCount < minCount {
minCount = acc.affinityCount
}
}
result := make([]accountWithLoad, 0, len(accounts))
for _, acc := range accounts {
if acc.affinityCount == minCount {
result = append(result, acc)
}
}
return result
}
// classifyByAffinityZone 按亲和分区对候选账号进行分类。
// 返回值:仅绿区账号(有绿区时),否则返回黄区账号。红区账号被移除。
// 如果没有任何账号开启了亲和三区配置(即 affinity_base <= 0),则原样返回所有账号。
func classifyByAffinityZone(accounts []accountWithLoad) []accountWithLoad {
if len(accounts) == 0 {
return accounts
}
// 快速检查:是否有任何账号配置了 affinity_base
hasZoneConfig := false
for _, acc := range accounts {
if acc.account.IsAffinityEnabled() && acc.account.GetAffinityBase() > 0 {
hasZoneConfig = true
break
}
}
if !hasZoneConfig {
return accounts
}
greens := make([]accountWithLoad, 0, len(accounts))
yellows := make([]accountWithLoad, 0, len(accounts))
for _, acc := range accounts {
zone := acc.account.GetAffinityZone(acc.affinityCount)
switch zone {
case AffinityZoneGreen:
greens = append(greens, acc)
case AffinityZoneYellow:
yellows = append(yellows, acc)
case AffinityZoneRed:
// 红区:移除,不参与调度
}
}
if len(greens) > 0 {
return greens
}
return yellows
}
// selectByLRU 从集合中选择最久未用的账号 // selectByLRU 从集合中选择最久未用的账号
// 如果有多个账号具有相同的最小 LastUsedAt,则随机选择一个 // 如果有多个账号具有相同的最小 LastUsedAt,则随机选择一个
func selectByLRU(accounts []accountWithLoad, preferOAuth bool) *accountWithLoad { func selectByLRU(accounts []accountWithLoad, preferOAuth bool) *accountWithLoad {
...@@ -3514,9 +3198,6 @@ func (s *GatewayService) logDetailedSelectionFailure( ...@@ -3514,9 +3198,6 @@ func (s *GatewayService) logDetailedSelectionFailure(
stats.SampleMappingIDs, stats.SampleMappingIDs,
stats.SampleRateLimitIDs, stats.SampleRateLimitIDs,
) )
if platform == PlatformSora {
s.logSoraSelectionFailureDetails(ctx, groupID, sessionHash, requestedModel, accounts, excludedIDs, allowMixedScheduling)
}
return stats return stats
} }
...@@ -3574,9 +3255,6 @@ func (s *GatewayService) diagnoseSelectionFailure( ...@@ -3574,9 +3255,6 @@ func (s *GatewayService) diagnoseSelectionFailure(
} }
if !s.isAccountSchedulableForSelection(acc) { if !s.isAccountSchedulableForSelection(acc) {
detail := "generic_unschedulable" detail := "generic_unschedulable"
if acc.Platform == PlatformSora {
detail = s.soraUnschedulableReason(acc)
}
return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail} return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail}
} }
if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) { if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) {
...@@ -3601,57 +3279,7 @@ func (s *GatewayService) diagnoseSelectionFailure( ...@@ -3601,57 +3279,7 @@ func (s *GatewayService) diagnoseSelectionFailure(
return selectionFailureDiagnosis{Category: "eligible"} return selectionFailureDiagnosis{Category: "eligible"}
} }
func (s *GatewayService) logSoraSelectionFailureDetails( // GetAccessToken 获取账号凭证
ctx context.Context,
groupID *int64,
sessionHash string,
requestedModel string,
accounts []Account,
excludedIDs map[int64]struct{},
allowMixedScheduling bool,
) {
const maxLines = 30
logged := 0
for i := range accounts {
if logged >= maxLines {
break
}
acc := &accounts[i]
diagnosis := s.diagnoseSelectionFailure(ctx, acc, requestedModel, PlatformSora, excludedIDs, allowMixedScheduling)
if diagnosis.Category == "eligible" {
continue
}
detail := diagnosis.Detail
if detail == "" {
detail = "-"
}
logger.LegacyPrintf(
"service.gateway",
"[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s account_id=%d account_platform=%s category=%s detail=%s",
derefGroupID(groupID),
requestedModel,
shortSessionHash(sessionHash),
acc.ID,
acc.Platform,
diagnosis.Category,
detail,
)
logged++
}
if len(accounts) > maxLines {
logger.LegacyPrintf(
"service.gateway",
"[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s truncated=true total=%d logged=%d",
derefGroupID(groupID),
requestedModel,
shortSessionHash(sessionHash),
len(accounts),
logged,
)
}
}
func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool { func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool {
if acc == nil { if acc == nil {
return true return true
...@@ -3730,9 +3358,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo ...@@ -3730,9 +3358,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
} }
return mapAntigravityModel(account, requestedModel) != "" return mapAntigravityModel(account, requestedModel) != ""
} }
if account.Platform == PlatformSora {
return s.isSoraModelSupportedByAccount(account, requestedModel)
}
if account.IsBedrock() { if account.IsBedrock() {
_, ok := ResolveBedrockModelID(account, requestedModel) _, ok := ResolveBedrockModelID(account, requestedModel)
return ok return ok
...@@ -3749,143 +3374,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo ...@@ -3749,143 +3374,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
return account.IsModelSupported(requestedModel) return account.IsModelSupported(requestedModel)
} }
func (s *GatewayService) isSoraModelSupportedByAccount(account *Account, requestedModel string) bool {
if account == nil {
return false
}
if strings.TrimSpace(requestedModel) == "" {
return true
}
// 先走原始精确/通配符匹配。
mapping := account.GetModelMapping()
if len(mapping) == 0 || account.IsModelSupported(requestedModel) {
return true
}
aliases := buildSoraModelAliases(requestedModel)
if len(aliases) == 0 {
return false
}
hasSoraSelector := false
for pattern := range mapping {
if !isSoraModelSelector(pattern) {
continue
}
hasSoraSelector = true
if matchPatternAnyAlias(pattern, aliases) {
return true
}
}
// 兼容旧账号:mapping 存在但未配置任何 Sora 选择器(例如只含 gpt-*),
// 此时不应误拦截 Sora 模型请求。
if !hasSoraSelector {
return true
}
return false
}
func matchPatternAnyAlias(pattern string, aliases []string) bool {
normalizedPattern := strings.ToLower(strings.TrimSpace(pattern))
if normalizedPattern == "" {
return false
}
for _, alias := range aliases {
if matchWildcard(normalizedPattern, alias) {
return true
}
}
return false
}
func isSoraModelSelector(pattern string) bool {
p := strings.ToLower(strings.TrimSpace(pattern))
if p == "" {
return false
}
switch {
case strings.HasPrefix(p, "sora"),
strings.HasPrefix(p, "gpt-image"),
strings.HasPrefix(p, "prompt-enhance"),
strings.HasPrefix(p, "sy_"):
return true
}
return p == "video" || p == "image"
}
func buildSoraModelAliases(requestedModel string) []string {
modelID := strings.ToLower(strings.TrimSpace(requestedModel))
if modelID == "" {
return nil
}
aliases := make([]string, 0, 8)
addAlias := func(value string) {
v := strings.ToLower(strings.TrimSpace(value))
if v == "" {
return
}
for _, existing := range aliases {
if existing == v {
return
}
}
aliases = append(aliases, v)
}
addAlias(modelID)
cfg, ok := GetSoraModelConfig(modelID)
if ok {
addAlias(cfg.Model)
switch cfg.Type {
case "video":
addAlias("video")
addAlias("sora")
addAlias(soraVideoFamilyAlias(modelID))
case "image":
addAlias("image")
addAlias("gpt-image")
case "prompt_enhance":
addAlias("prompt-enhance")
}
return aliases
}
switch {
case strings.HasPrefix(modelID, "sora"):
addAlias("video")
addAlias("sora")
addAlias(soraVideoFamilyAlias(modelID))
case strings.HasPrefix(modelID, "gpt-image"):
addAlias("image")
addAlias("gpt-image")
case strings.HasPrefix(modelID, "prompt-enhance"):
addAlias("prompt-enhance")
default:
return nil
}
return aliases
}
func soraVideoFamilyAlias(modelID string) string {
switch {
case strings.HasPrefix(modelID, "sora2pro-hd"):
return "sora2pro-hd"
case strings.HasPrefix(modelID, "sora2pro"):
return "sora2pro"
case strings.HasPrefix(modelID, "sora2"):
return "sora2"
default:
return ""
}
}
// GetAccessToken 获取账号凭证 // GetAccessToken 获取账号凭证
func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type { switch account.Type {
...@@ -4412,10 +3900,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -4412,10 +3900,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return nil, err return nil, err
} }
// 获取代理URL // 获取代理URL(自定义 base URL 模式下,proxy 通过 buildCustomRelayURL 作为查询参数传递)
proxyURL := "" proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil { if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL() if !account.IsCustomBaseURLEnabled() || account.GetCustomBaseURL() == "" {
proxyURL = account.Proxy.URL()
}
} }
// 解析 TLS 指纹 profile(同一请求生命周期内不变,避免重试循环中重复解析) // 解析 TLS 指纹 profile(同一请求生命周期内不变,避免重试循环中重复解析)
...@@ -4824,7 +4314,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -4824,7 +4314,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
} }
// 处理正常响应 // 处理正常响应
ctx = withClaudeMaxResponseRewriteContext(ctx, c, parsed)
// 触发上游接受回调(提前释放串行锁,不等流完成) // 触发上游接受回调(提前释放串行锁,不等流完成)
if parsed.OnUpstreamAccepted != nil { if parsed.OnUpstreamAccepted != nil {
...@@ -5891,6 +5380,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex ...@@ -5891,6 +5380,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
} }
targetURL = validatedURL + "/v1/messages?beta=true" targetURL = validatedURL + "/v1/messages?beta=true"
} }
} else if account.IsCustomBaseURLEnabled() {
customURL := account.GetCustomBaseURL()
if customURL == "" {
return nil, fmt.Errorf("custom_base_url is enabled but not configured for account %d", account.ID)
}
validatedURL, err := s.validateUpstreamBaseURL(customURL)
if err != nil {
return nil, err
}
targetURL = s.buildCustomRelayURL(validatedURL, "/v1/messages", account)
} }
clientHeaders := http.Header{} clientHeaders := http.Header{}
...@@ -6006,6 +5505,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex ...@@ -6006,6 +5505,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
} }
} }
// 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖
if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" {
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
if parsed := ParseMetadataUserID(uid); parsed != nil {
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
}
}
}
// === DEBUG: 打印上游转发请求(headers + body 摘要),与 CLIENT_ORIGINAL 对比 === // === DEBUG: 打印上游转发请求(headers + body 摘要),与 CLIENT_ORIGINAL 对比 ===
s.debugLogGatewaySnapshot("UPSTREAM_FORWARD", req.Header, body, map[string]string{ s.debugLogGatewaySnapshot("UPSTREAM_FORWARD", req.Header, body, map[string]string{
"url": req.URL.String(), "url": req.URL.String(),
...@@ -7005,7 +6513,6 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http ...@@ -7005,7 +6513,6 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
needModelReplace := originalModel != mappedModel needModelReplace := originalModel != mappedModel
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
sawTerminalEvent := false sawTerminalEvent := false
skipAccountTTLOverride := false
pendingEventLines := make([]string, 0, 4) pendingEventLines := make([]string, 0, 4)
...@@ -7067,25 +6574,17 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http ...@@ -7067,25 +6574,17 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
if msg, ok := event["message"].(map[string]any); ok { if msg, ok := event["message"].(map[string]any); ok {
if u, ok := msg["usage"].(map[string]any); ok { if u, ok := msg["usage"].(map[string]any); ok {
eventChanged = reconcileCachedTokens(u) || eventChanged eventChanged = reconcileCachedTokens(u) || eventChanged
claudeMaxOutcome := applyClaudeMaxSimulationToUsageJSONMap(ctx, u, originalModel, account.ID)
if claudeMaxOutcome.Simulated {
skipAccountTTLOverride = true
}
} }
} }
} }
if eventType == "message_delta" { if eventType == "message_delta" {
if u, ok := event["usage"].(map[string]any); ok { if u, ok := event["usage"].(map[string]any); ok {
eventChanged = reconcileCachedTokens(u) || eventChanged eventChanged = reconcileCachedTokens(u) || eventChanged
claudeMaxOutcome := applyClaudeMaxSimulationToUsageJSONMap(ctx, u, originalModel, account.ID)
if claudeMaxOutcome.Simulated {
skipAccountTTLOverride = true
}
} }
} }
// Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类 // Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类
if account.IsCacheTTLOverrideEnabled() && !skipAccountTTLOverride { if account.IsCacheTTLOverrideEnabled() {
overrideTarget := account.GetCacheTTLOverrideTarget() overrideTarget := account.GetCacheTTLOverrideTarget()
if eventType == "message_start" { if eventType == "message_start" {
if msg, ok := event["message"].(map[string]any); ok { if msg, ok := event["message"].(map[string]any); ok {
...@@ -7517,13 +7016,8 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h ...@@ -7517,13 +7016,8 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
} }
} }
claudeMaxOutcome := applyClaudeMaxSimulationToUsage(ctx, &response.Usage, originalModel, account.ID)
if claudeMaxOutcome.Simulated {
body = rewriteClaudeUsageJSONBytes(body, response.Usage)
}
// Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类 // Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类
if account.IsCacheTTLOverrideEnabled() && !claudeMaxOutcome.Simulated { if account.IsCacheTTLOverrideEnabled() {
overrideTarget := account.GetCacheTTLOverrideTarget() overrideTarget := account.GetCacheTTLOverrideTarget()
if applyCacheTTLOverride(&response.Usage, overrideTarget) { if applyCacheTTLOverride(&response.Usage, overrideTarget) {
// 同步更新 body JSON 中的嵌套 cache_creation 对象 // 同步更新 body JSON 中的嵌套 cache_creation 对象
...@@ -7901,12 +7395,10 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage ...@@ -7901,12 +7395,10 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage
// recordUsageOpts 内部选项,参数化 RecordUsage 与 RecordUsageWithLongContext 的差异点。 // recordUsageOpts 内部选项,参数化 RecordUsage 与 RecordUsageWithLongContext 的差异点。
type recordUsageOpts struct { type recordUsageOpts struct {
// Claude Max 策略所需的 ParsedRequest(可选,仅 Claude 路径传入) // ParsedRequest(可选,仅 Claude 路径传入)
ParsedRequest *ParsedRequest ParsedRequest *ParsedRequest
// EnableClaudePath 启用 Claude 路径特有逻辑: // EnableClaudePath 启用 Claude 路径特有逻辑:
// - Claude Max 缓存计费策略
// - Sora 媒体类型分支(image/video/prompt)
// - MediaType 字段写入使用日志 // - MediaType 字段写入使用日志
EnableClaudePath bool EnableClaudePath bool
...@@ -7998,8 +7490,6 @@ type recordUsageCoreInput struct { ...@@ -7998,8 +7490,6 @@ type recordUsageCoreInput struct {
// recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。 // recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。
// opts 中的字段控制两者之间的差异行为: // opts 中的字段控制两者之间的差异行为:
// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略
// - EnableSoraMedia → 启用 Sora MediaType 分支(image/video/prompt)
// - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext // - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext
func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error { func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error {
result := input.Result result := input.Result
...@@ -8017,21 +7507,9 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage ...@@ -8017,21 +7507,9 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
result.Usage.InputTokens = 0 result.Usage.InputTokens = 0
} }
// Claude Max cache billing policy(仅 Claude 路径启用)
cacheTTLOverridden := false
simulatedClaudeMax := false
if opts.EnableClaudePath {
var apiKeyGroup *Group
if apiKey != nil {
apiKeyGroup = apiKey.Group
}
claudeMaxOutcome := applyClaudeMaxCacheBillingPolicyToUsage(&result.Usage, opts.ParsedRequest, apiKeyGroup, result.Model, account.ID)
simulatedClaudeMax = claudeMaxOutcome.Simulated ||
(shouldApplyClaudeMaxBillingRulesForUsage(apiKeyGroup, result.Model, opts.ParsedRequest) && hasCacheCreationTokens(result.Usage))
}
// Cache TTL Override: 确保计费时 token 分类与账号设置一致 // Cache TTL Override: 确保计费时 token 分类与账号设置一致
if account.IsCacheTTLOverrideEnabled() && !simulatedClaudeMax { cacheTTLOverridden := false
if account.IsCacheTTLOverrideEnabled() {
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget()) applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0 cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
} }
...@@ -8113,16 +7591,6 @@ func (s *GatewayService) calculateRecordUsageCost( ...@@ -8113,16 +7591,6 @@ func (s *GatewayService) calculateRecordUsageCost(
multiplier float64, multiplier float64,
opts *recordUsageOpts, opts *recordUsageOpts,
) *CostBreakdown { ) *CostBreakdown {
// Sora 媒体类型分支(仅 Claude 路径启用)
if opts.EnableClaudePath {
if result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo {
return s.calculateSoraMediaCost(result, apiKey, billingModel, multiplier)
}
if result.MediaType == MediaTypePrompt {
return &CostBreakdown{}
}
}
// 图片生成计费 // 图片生成计费
if result.ImageCount > 0 { if result.ImageCount > 0 {
return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier) return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier)
...@@ -8132,28 +7600,6 @@ func (s *GatewayService) calculateRecordUsageCost( ...@@ -8132,28 +7600,6 @@ func (s *GatewayService) calculateRecordUsageCost(
return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts) return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts)
} }
// calculateSoraMediaCost 计算 Sora 图片/视频的费用。
func (s *GatewayService) calculateSoraMediaCost(
result *ForwardResult,
apiKey *APIKey,
billingModel string,
multiplier float64,
) *CostBreakdown {
var soraConfig *SoraPriceConfig
if apiKey.Group != nil {
soraConfig = &SoraPriceConfig{
ImagePrice360: apiKey.Group.SoraImagePrice360,
ImagePrice540: apiKey.Group.SoraImagePrice540,
VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
}
}
if result.MediaType == MediaTypeImage {
return s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier)
}
return s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier)
}
// resolveChannelPricing 检查指定模型是否存在渠道级别定价。 // resolveChannelPricing 检查指定模型是否存在渠道级别定价。
// 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。 // 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。
func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing { func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing {
...@@ -8176,7 +7622,7 @@ func (s *GatewayService) calculateImageCost( ...@@ -8176,7 +7622,7 @@ func (s *GatewayService) calculateImageCost(
billingModel string, billingModel string,
multiplier float64, multiplier float64,
) *CostBreakdown { ) *CostBreakdown {
if s.resolveChannelPricing(ctx, billingModel, apiKey) != nil { if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil {
tokens := UsageTokens{ tokens := UsageTokens{
InputTokens: result.Usage.InputTokens, InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens, OutputTokens: result.Usage.OutputTokens,
...@@ -8191,6 +7637,7 @@ func (s *GatewayService) calculateImageCost( ...@@ -8191,6 +7637,7 @@ func (s *GatewayService) calculateImageCost(
RequestCount: 1, RequestCount: 1,
RateMultiplier: multiplier, RateMultiplier: multiplier,
Resolver: s.resolver, Resolver: s.resolver,
Resolved: resolved,
}) })
if err != nil { if err != nil {
logger.LegacyPrintf("service.gateway", "Calculate image token cost failed: %v", err) logger.LegacyPrintf("service.gateway", "Calculate image token cost failed: %v", err)
...@@ -8233,7 +7680,7 @@ func (s *GatewayService) calculateTokenCost( ...@@ -8233,7 +7680,7 @@ func (s *GatewayService) calculateTokenCost(
var err error var err error
// 优先尝试渠道定价 → CalculateCostUnified // 优先尝试渠道定价 → CalculateCostUnified
if s.resolveChannelPricing(ctx, billingModel, apiKey) != nil { if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil {
gid := apiKey.Group.ID gid := apiKey.Group.ID
cost, err = s.billingService.CalculateCostUnified(CostInput{ cost, err = s.billingService.CalculateCostUnified(CostInput{
Ctx: ctx, Ctx: ctx,
...@@ -8243,6 +7690,7 @@ func (s *GatewayService) calculateTokenCost( ...@@ -8243,6 +7690,7 @@ func (s *GatewayService) calculateTokenCost(
RequestCount: 1, RequestCount: 1,
RateMultiplier: multiplier, RateMultiplier: multiplier,
Resolver: s.resolver, Resolver: s.resolver,
Resolved: resolved,
}) })
} else if opts.LongContextThreshold > 0 { } else if opts.LongContextThreshold > 0 {
// 长上下文双倍计费(如 Gemini 200K 阈值) // 长上下文双倍计费(如 Gemini 200K 阈值)
...@@ -8330,13 +7778,7 @@ func (s *GatewayService) buildRecordUsageLog( ...@@ -8330,13 +7778,7 @@ func (s *GatewayService) buildRecordUsageLog(
} }
// resolveBillingMode 根据计费结果和请求类型确定计费模式。 // resolveBillingMode 根据计费结果和请求类型确定计费模式。
// Sora 媒体类型自身已确定计费模式(由上游处理),返回 nil 跳过。
func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string { func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string {
isSoraMedia := opts.EnableClaudePath &&
(result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo || result.MediaType == MediaTypePrompt)
if isSoraMedia {
return nil
}
var mode string var mode string
switch { switch {
case cost != nil && cost.BillingMode != "": case cost != nil && cost.BillingMode != "":
...@@ -8350,9 +7792,6 @@ func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *Cost ...@@ -8350,9 +7792,6 @@ func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *Cost
} }
func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string { func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string {
if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" {
return &result.MediaType
}
return nil return nil
} }
...@@ -8559,10 +7998,12 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, ...@@ -8559,10 +7998,12 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
return err return err
} }
// 获取代理URL // 获取代理URL(自定义 base URL 模式下,proxy 通过 buildCustomRelayURL 作为查询参数传递)
proxyURL := "" proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil { if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL() if !account.IsCustomBaseURLEnabled() || account.GetCustomBaseURL() == "" {
proxyURL = account.Proxy.URL()
}
} }
// 发送请求 // 发送请求
...@@ -8841,6 +8282,16 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con ...@@ -8841,6 +8282,16 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
} }
targetURL = validatedURL + "/v1/messages/count_tokens?beta=true" targetURL = validatedURL + "/v1/messages/count_tokens?beta=true"
} }
} else if account.IsCustomBaseURLEnabled() {
customURL := account.GetCustomBaseURL()
if customURL == "" {
return nil, fmt.Errorf("custom_base_url is enabled but not configured for account %d", account.ID)
}
validatedURL, err := s.validateUpstreamBaseURL(customURL)
if err != nil {
return nil, err
}
targetURL = s.buildCustomRelayURL(validatedURL, "/v1/messages/count_tokens", account)
} }
clientHeaders := http.Header{} clientHeaders := http.Header{}
...@@ -8946,6 +8397,15 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con ...@@ -8946,6 +8397,15 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
} }
} }
// 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖
if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" {
if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
if parsed := ParseMetadataUserID(uid); parsed != nil {
setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
}
}
}
if c != nil && tokenType == "oauth" { if c != nil && tokenType == "oauth" {
c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode)) c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode))
} }
...@@ -8967,6 +8427,19 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m ...@@ -8967,6 +8427,19 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m
}) })
} }
// buildCustomRelayURL 构建自定义中继转发 URL
// 在 path 后附加 beta=true 和可选的 proxy 查询参数
func (s *GatewayService) buildCustomRelayURL(baseURL, path string, account *Account) string {
u := strings.TrimRight(baseURL, "/") + path + "?beta=true"
if account.ProxyID != nil && account.Proxy != nil {
proxyURL := account.Proxy.URL()
if proxyURL != "" {
u += "&proxy=" + url.QueryEscape(proxyURL)
}
}
return u
}
func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) { func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled { if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP) normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment