"frontend/src/vscode:/vscode.git/clone" did not exist on "29d58f2414f06c74f49fe8992b01a5ca76b7ad2a"
Commit f345b0f5 authored by erio's avatar erio
Browse files

fix: use upstream versions of shared files and remove only Sora code

Restore gateway_service.go, setting_handler.go, routes/admin.go,
dto/settings.go, group_repo.go, api_key_repo.go, wire_gen.go to
upstream/main versions and surgically remove only Sora references.

This preserves upstream-only features (RequireOauthOnly, RequirePrivacySet,
GroupResolution, etc.) that were missing when using release branch versions.
parent 58707f8a
...@@ -651,6 +651,8 @@ func groupEntityToService(g *dbent.Group) *service.Group { ...@@ -651,6 +651,8 @@ func groupEntityToService(g *dbent.Group) *service.Group {
SupportedModelScopes: g.SupportedModelScopes, SupportedModelScopes: g.SupportedModelScopes,
SortOrder: g.SortOrder, SortOrder: g.SortOrder,
AllowMessagesDispatch: g.AllowMessagesDispatch, AllowMessagesDispatch: g.AllowMessagesDispatch,
RequireOAuthOnly: g.RequireOauthOnly,
RequirePrivacySet: g.RequirePrivacySet,
DefaultMappedModel: g.DefaultMappedModel, DefaultMappedModel: g.DefaultMappedModel,
CreatedAt: g.CreatedAt, CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt, UpdatedAt: g.UpdatedAt,
......
...@@ -56,6 +56,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er ...@@ -56,6 +56,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
SetMcpXMLInject(groupIn.MCPXMLInject). SetMcpXMLInject(groupIn.MCPXMLInject).
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch). SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
SetRequirePrivacySet(groupIn.RequirePrivacySet).
SetDefaultMappedModel(groupIn.DefaultMappedModel) SetDefaultMappedModel(groupIn.DefaultMappedModel)
// 设置模型路由配置 // 设置模型路由配置
...@@ -120,6 +122,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er ...@@ -120,6 +122,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
SetMcpXMLInject(groupIn.MCPXMLInject). SetMcpXMLInject(groupIn.MCPXMLInject).
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch). SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
SetRequirePrivacySet(groupIn.RequirePrivacySet).
SetDefaultMappedModel(groupIn.DefaultMappedModel) SetDefaultMappedModel(groupIn.DefaultMappedModel)
// 显式处理可空字段:nil 需要 clear,非 nil 需要 set。 // 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
......
...@@ -60,6 +60,7 @@ const ( ...@@ -60,6 +60,7 @@ const (
claudeMimicDebugInfoKey = "claude_mimic_debug_info" claudeMimicDebugInfoKey = "claude_mimic_debug_info"
) )
// ForceCacheBillingContextKey 强制缓存计费上下文键 // ForceCacheBillingContextKey 强制缓存计费上下文键
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费 // 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
type forceCacheBillingKeyType struct{} type forceCacheBillingKeyType struct{}
...@@ -503,6 +504,7 @@ type ForwardResult struct { ...@@ -503,6 +504,7 @@ type ForwardResult struct {
// 图片生成计费字段(图片生成模型使用) // 图片生成计费字段(图片生成模型使用)
ImageCount int // 生成的图片数量 ImageCount int // 生成的图片数量
ImageSize string // 图片尺寸 "1K", "2K", "4K" ImageSize string // 图片尺寸 "1K", "2K", "4K"
} }
// UpstreamFailoverError indicates an upstream error that should trigger account failover. // UpstreamFailoverError indicates an upstream error that should trigger account failover.
...@@ -1330,11 +1332,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1330,11 +1332,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
ctx = s.withWindowCostPrefetch(ctx, accounts) ctx = s.withWindowCostPrefetch(ctx, accounts)
ctx = s.withRPMPrefetch(ctx, accounts) ctx = s.withRPMPrefetch(ctx, accounts)
// 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用)
accountByID := make(map[int64]*Account, len(accounts))
for i := range accounts {
accountByID[accounts[i].ID] = &accounts[i]
}
isExcluded := func(accountID int64) bool { isExcluded := func(accountID int64) bool {
if excludedIDs == nil { if excludedIDs == nil {
return false return false
...@@ -1343,6 +1340,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1343,6 +1340,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return excluded return excluded
} }
// 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用)
accountByID := make(map[int64]*Account, len(accounts))
for i := range accounts {
accountByID[accounts[i].ID] = &accounts[i]
}
// 获取模型路由配置(仅 anthropic 平台) // 获取模型路由配置(仅 anthropic 平台)
var routingAccountIDs []int64 var routingAccountIDs []int64
if group != nil && requestedModel != "" && group.Platform == PlatformAnthropic { if group != nil && requestedModel != "" && group.Platform == PlatformAnthropic {
...@@ -1430,19 +1433,24 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1430,19 +1433,24 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) { if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
// 粘性账号在路由列表中,优先使用 // 粘性账号在路由列表中,优先使用
if stickyAccount, ok := accountByID[stickyAccountID]; ok { if stickyAccount, ok := accountByID[stickyAccountID]; ok {
if s.isAccountSchedulableForSelection(stickyAccount) && var stickyCacheMissReason string
gatePass := s.isAccountSchedulableForSelection(stickyAccount) &&
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) && s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) &&
s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) && s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) &&
s.isAccountSchedulableForQuota(stickyAccount) && s.isAccountSchedulableForQuota(stickyAccount) &&
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) && s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true)
rpmPass := gatePass && s.isAccountSchedulableForRPM(ctx, stickyAccount, true)
s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查 if rpmPass { // 粘性会话窗口费用+RPM 检查
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
// 会话数量限制检查 // 会话数量限制检查
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
result.ReleaseFunc() // 释放槽位 result.ReleaseFunc() // 释放槽位
stickyCacheMissReason = "session_limit"
// 继续到负载感知选择 // 继续到负载感知选择
} else { } else {
if s.debugModelRoutingEnabled() { if s.debugModelRoutingEnabled() {
...@@ -1456,10 +1464,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1456,10 +1464,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
} }
} }
if stickyCacheMissReason == "" {
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID) waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
if waitingCount < cfg.StickySessionMaxWaiting { if waitingCount < cfg.StickySessionMaxWaiting {
// 会话数量限制检查(等待计划也需要占用会话配额) // 会话数量限制检查(等待计划也需要占用会话配额)
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
stickyCacheMissReason = "session_limit"
// 会话限制已满,继续到负载感知选择 // 会话限制已满,继续到负载感知选择
} else { } else {
return &AccountSelectionResult{ return &AccountSelectionResult{
...@@ -1472,11 +1482,31 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1472,11 +1482,31 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}, },
}, nil }, nil
} }
} else {
stickyCacheMissReason = "wait_queue_full"
}
} }
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择 // 粘性账号槽位满且等待队列已满,继续使用负载感知选择
} else if !gatePass {
stickyCacheMissReason = "gate_check"
} else {
stickyCacheMissReason = "rpm_red"
}
// 记录粘性缓存未命中的结构化日志
if stickyCacheMissReason != "" {
baseRPM := stickyAccount.GetBaseRPM()
var currentRPM int
if count, ok := rpmFromPrefetchContext(ctx, stickyAccount.ID); ok {
currentRPM = count
}
logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=%s account_id=%d session=%s current_rpm=%d base_rpm=%d",
stickyCacheMissReason, stickyAccountID, shortSessionHash(sessionHash), currentRPM, baseRPM)
} }
} else { } else {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=account_cleared account_id=%d session=%s current_rpm=0 base_rpm=0",
stickyAccountID, shortSessionHash(sessionHash))
} }
} }
} }
...@@ -1582,6 +1612,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1582,6 +1612,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
account, ok := accountByID[accountID] account, ok := accountByID[accountID]
if ok { if ok {
// 检查账户是否需要清理粘性会话绑定 // 检查账户是否需要清理粘性会话绑定
// Check if the account needs sticky session cleanup
clearSticky := shouldClearStickySession(account, requestedModel) clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky { if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
...@@ -1597,6 +1628,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1597,6 +1628,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
// 会话数量限制检查 // 会话数量限制检查
// Session count limit check
if !s.checkAndRegisterSession(ctx, account, sessionHash) { if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续到 Layer 2 result.ReleaseFunc() // 释放槽位,继续到 Layer 2
} else { } else {
...@@ -1611,8 +1643,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1611,8 +1643,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting { if waitingCount < cfg.StickySessionMaxWaiting {
// 会话数量限制检查(等待计划也需要占用会话配额) // 会话数量限制检查(等待计划也需要占用会话配额)
// Session count limit check (wait plan also requires session quota)
if !s.checkAndRegisterSession(ctx, account, sessionHash) { if !s.checkAndRegisterSession(ctx, account, sessionHash) {
// 会话限制已满,继续到 Layer 2 // 会话限制已满,继续到 Layer 2
// Session limit full, continue to Layer 2
} else { } else {
return &AccountSelectionResult{ return &AccountSelectionResult{
Account: account, Account: account,
...@@ -2673,6 +2707,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ...@@ -2673,6 +2707,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
preferOAuth := platform == PlatformGemini preferOAuth := platform == PlatformGemini
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform) routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
// require_privacy_set: 获取分组信息
var schedGroup *Group
if groupID != nil && s.groupRepo != nil {
schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID)
}
var accounts []Account var accounts []Account
accountsLoaded := false accountsLoaded := false
...@@ -2696,7 +2736,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ...@@ -2696,7 +2736,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if clearSticky { if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
} }
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) && !s.isStickyAccountUpstreamRestricted(ctx, groupID, account, requestedModel) { if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if s.debugModelRoutingEnabled() { if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
} }
...@@ -2744,6 +2784,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ...@@ -2744,6 +2784,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !s.isAccountSchedulableForSelection(acc) { if !s.isAccountSchedulableForSelection(acc) {
continue continue
} }
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
_ = s.accountRepo.SetError(ctx, acc.ID,
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
}
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue continue
} }
...@@ -2849,6 +2895,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ...@@ -2849,6 +2895,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !s.isAccountSchedulableForSelection(acc) { if !s.isAccountSchedulableForSelection(acc) {
continue continue
} }
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
_ = s.accountRepo.SetError(ctx, acc.ID,
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
}
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
continue continue
} }
...@@ -2915,6 +2967,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ...@@ -2915,6 +2967,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
preferOAuth := nativePlatform == PlatformGemini preferOAuth := nativePlatform == PlatformGemini
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform) routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform)
// require_privacy_set: 获取分组信息
var schedGroup *Group
if groupID != nil && s.groupRepo != nil {
schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID)
}
var accounts []Account var accounts []Account
accountsLoaded := false accountsLoaded := false
...@@ -2982,6 +3040,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ...@@ -2982,6 +3040,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if !s.isAccountSchedulableForSelection(acc) { if !s.isAccountSchedulableForSelection(acc) {
continue continue
} }
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
_ = s.accountRepo.SetError(ctx, acc.ID,
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
}
// 过滤:原生平台直接通过,antigravity 需要启用混合调度 // 过滤:原生平台直接通过,antigravity 需要启用混合调度
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue continue
...@@ -3051,7 +3115,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ...@@ -3051,7 +3115,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if clearSticky { if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
} }
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) && !s.isStickyAccountUpstreamRestricted(ctx, groupID, account, requestedModel) { if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
return account, nil return account, nil
} }
...@@ -3075,6 +3139,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ...@@ -3075,6 +3139,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
ctx = s.withRPMPrefetch(ctx, accounts) ctx = s.withRPMPrefetch(ctx, accounts)
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度) // 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
// needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查。
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
var selected *Account var selected *Account
for i := range accounts { for i := range accounts {
...@@ -3087,6 +3152,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ...@@ -3087,6 +3152,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if !s.isAccountSchedulableForSelection(acc) { if !s.isAccountSchedulableForSelection(acc) {
continue continue
} }
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
_ = s.accountRepo.SetError(ctx, acc.ID,
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
}
// 过滤:原生平台直接通过,antigravity 需要启用混合调度 // 过滤:原生平台直接通过,antigravity 需要启用混合调度
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue continue
...@@ -3254,8 +3325,7 @@ func (s *GatewayService) diagnoseSelectionFailure( ...@@ -3254,8 +3325,7 @@ func (s *GatewayService) diagnoseSelectionFailure(
return selectionFailureDiagnosis{Category: "excluded"} return selectionFailureDiagnosis{Category: "excluded"}
} }
if !s.isAccountSchedulableForSelection(acc) { if !s.isAccountSchedulableForSelection(acc) {
detail := "generic_unschedulable" return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"}
return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail}
} }
if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) { if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) {
return selectionFailureDiagnosis{ return selectionFailureDiagnosis{
...@@ -3279,7 +3349,6 @@ func (s *GatewayService) diagnoseSelectionFailure( ...@@ -3279,7 +3349,6 @@ func (s *GatewayService) diagnoseSelectionFailure(
return selectionFailureDiagnosis{Category: "eligible"} return selectionFailureDiagnosis{Category: "eligible"}
} }
// GetAccessToken 获取账号凭证
func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool { func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool {
if acc == nil { if acc == nil {
return true return true
...@@ -3362,10 +3431,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo ...@@ -3362,10 +3431,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
_, ok := ResolveBedrockModelID(account, requestedModel) _, ok := ResolveBedrockModelID(account, requestedModel)
return ok return ok
} }
// OpenAI 透传模式:仅替换认证,允许所有模型
if account.Platform == PlatformOpenAI && account.IsOpenAIPassthroughEnabled() {
return true
}
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID) // OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
requestedModel = claude.NormalizeModelID(requestedModel) requestedModel = claude.NormalizeModelID(requestedModel)
...@@ -7083,7 +7148,6 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, ...@@ -7083,7 +7148,6 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID,
// RecordUsageInput 记录使用量的输入参数 // RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct { type RecordUsageInput struct {
Result *ForwardResult Result *ForwardResult
ParsedRequest *ParsedRequest
APIKey *APIKey APIKey *APIKey
User *User User *User
Account *Account Account *Account
...@@ -7242,9 +7306,6 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage ...@@ -7242,9 +7306,6 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
cmd.CacheCreationTokens = usageLog.CacheCreationTokens cmd.CacheCreationTokens = usageLog.CacheCreationTokens
cmd.CacheReadTokens = usageLog.CacheReadTokens cmd.CacheReadTokens = usageLog.CacheReadTokens
cmd.ImageCount = usageLog.ImageCount cmd.ImageCount = usageLog.ImageCount
if usageLog.MediaType != nil {
cmd.MediaType = *usageLog.MediaType
}
if usageLog.ServiceTier != nil { if usageLog.ServiceTier != nil {
cmd.ServiceTier = *usageLog.ServiceTier cmd.ServiceTier = *usageLog.ServiceTier
} }
...@@ -7395,11 +7456,11 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage ...@@ -7395,11 +7456,11 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage
// recordUsageOpts 内部选项,参数化 RecordUsage 与 RecordUsageWithLongContext 的差异点。 // recordUsageOpts 内部选项,参数化 RecordUsage 与 RecordUsageWithLongContext 的差异点。
type recordUsageOpts struct { type recordUsageOpts struct {
// ParsedRequest(可选,仅 Claude 路径传入) // Claude Max 策略所需的 ParsedRequest(可选,仅 Claude 路径传入)
ParsedRequest *ParsedRequest ParsedRequest *ParsedRequest
// EnableClaudePath 启用 Claude 路径特有逻辑: // EnableClaudePath 启用 Claude 路径特有逻辑:
// - MediaType 字段写入使用日志 // - Claude Max 缓存计费策略
EnableClaudePath bool EnableClaudePath bool
// 长上下文计费(仅 Gemini 路径需要) // 长上下文计费(仅 Gemini 路径需要)
...@@ -7424,7 +7485,6 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -7424,7 +7485,6 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
APIKeyService: input.APIKeyService, APIKeyService: input.APIKeyService,
ChannelUsageFields: input.ChannelUsageFields, ChannelUsageFields: input.ChannelUsageFields,
}, &recordUsageOpts{ }, &recordUsageOpts{
ParsedRequest: input.ParsedRequest,
EnableClaudePath: true, EnableClaudePath: true,
}) })
} }
...@@ -7490,6 +7550,7 @@ type recordUsageCoreInput struct { ...@@ -7490,6 +7550,7 @@ type recordUsageCoreInput struct {
// recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。 // recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。
// opts 中的字段控制两者之间的差异行为: // opts 中的字段控制两者之间的差异行为:
// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略
// - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext // - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext
func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error { func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error {
result := input.Result result := input.Result
...@@ -7748,13 +7809,12 @@ func (s *GatewayService) buildRecordUsageLog( ...@@ -7748,13 +7809,12 @@ func (s *GatewayService) buildRecordUsageLog(
RateMultiplier: multiplier, RateMultiplier: multiplier,
AccountRateMultiplier: &accountRateMultiplier, AccountRateMultiplier: &accountRateMultiplier,
BillingType: billingType, BillingType: billingType,
BillingMode: resolveBillingMode(opts, result, cost), BillingMode: resolveBillingMode(result, cost),
Stream: result.Stream, Stream: result.Stream,
DurationMs: &durationMs, DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs, FirstTokenMs: result.FirstTokenMs,
ImageCount: result.ImageCount, ImageCount: result.ImageCount,
ImageSize: optionalTrimmedStringPtr(result.ImageSize), ImageSize: optionalTrimmedStringPtr(result.ImageSize),
MediaType: resolveMediaType(opts, result),
CacheTTLOverridden: cacheTTLOverridden, CacheTTLOverridden: cacheTTLOverridden,
ChannelID: optionalInt64Ptr(input.ChannelID), ChannelID: optionalInt64Ptr(input.ChannelID),
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain), ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
...@@ -7778,7 +7838,7 @@ func (s *GatewayService) buildRecordUsageLog( ...@@ -7778,7 +7838,7 @@ func (s *GatewayService) buildRecordUsageLog(
} }
// resolveBillingMode 根据计费结果和请求类型确定计费模式。 // resolveBillingMode 根据计费结果和请求类型确定计费模式。
func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string { func resolveBillingMode(result *ForwardResult, cost *CostBreakdown) *string {
var mode string var mode string
switch { switch {
case cost != nil && cost.BillingMode != "": case cost != nil && cost.BillingMode != "":
...@@ -7791,10 +7851,6 @@ func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *Cost ...@@ -7791,10 +7851,6 @@ func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *Cost
return &mode return &mode
} }
func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string {
return nil
}
func optionalSubscriptionID(subscription *UserSubscription) *int64 { func optionalSubscriptionID(subscription *UserSubscription) *int64 {
if subscription != nil { if subscription != nil {
return &subscription.ID return &subscription.ID
...@@ -7899,19 +7955,6 @@ func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Contex ...@@ -7899,19 +7955,6 @@ func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Contex
return ch.BillingModelSource == BillingModelSourceUpstream return ch.BillingModelSource == BillingModelSourceUpstream
} }
// isStickyAccountUpstreamRestricted 检查粘性会话命中的账号是否受 upstream 渠道限制。
// 合并 needsUpstreamChannelRestrictionCheck + isUpstreamModelRestrictedByChannel 两步调用,
// 供 sticky session 条件链使用,避免内联多个函数调用导致行过长。
func (s *GatewayService) isStickyAccountUpstreamRestricted(ctx context.Context, groupID *int64, account *Account, requestedModel string) bool {
if groupID == nil {
return false
}
if !s.needsUpstreamChannelRestrictionCheck(ctx, groupID) {
return false
}
return s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel)
}
// ForwardCountTokens 转发 count_tokens 请求到上游 API // ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应 // 特点:不记录使用量、仅支持非流式响应
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error { func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment