Commit 292f25f9 authored by yangjianbo's avatar yangjianbo
Browse files
parents c92e3777 fbb57294
...@@ -168,6 +168,14 @@ func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error ...@@ -168,6 +168,14 @@ func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error
return nil return nil
} }
func (s *apiKeyCacheStub) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
return nil
}
func (s *apiKeyCacheStub) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
return nil
}
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。 // TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
// 预期行为: // 预期行为:
// - GetKeyAndOwnerID 返回所有者 ID 为 1 // - GetKeyAndOwnerID 返回所有者 ID 为 1
......
...@@ -105,6 +105,9 @@ func (m *mockAccountRepoForPlatform) BatchUpdateLastUsed(ctx context.Context, up ...@@ -105,6 +105,9 @@ func (m *mockAccountRepoForPlatform) BatchUpdateLastUsed(ctx context.Context, up
func (m *mockAccountRepoForPlatform) SetError(ctx context.Context, id int64, errorMsg string) error { func (m *mockAccountRepoForPlatform) SetError(ctx context.Context, id int64, errorMsg string) error {
return nil return nil
} }
func (m *mockAccountRepoForPlatform) ClearError(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForPlatform) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { func (m *mockAccountRepoForPlatform) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return nil return nil
} }
......
...@@ -11,6 +11,8 @@ import ( ...@@ -11,6 +11,8 @@ import (
"fmt" "fmt"
"io" "io"
"log" "log"
"log/slog"
mathrand "math/rand"
"net/http" "net/http"
"os" "os"
"regexp" "regexp"
...@@ -445,11 +447,20 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context ...@@ -445,11 +447,20 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
} }
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan. // SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
// metadataUserID: 原始 metadata.user_id 字段(用于提取会话 UUID 进行会话数量限制) // metadataUserID: 已废弃参数,会话限制现在统一使用 sessionHash
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) { func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) {
// 调试日志:记录调度入口参数
excludedIDsList := make([]int64, 0, len(excludedIDs))
for id := range excludedIDs {
excludedIDsList = append(excludedIDsList, id)
}
slog.Debug("account_scheduling_starting",
"group_id", derefGroupID(groupID),
"model", requestedModel,
"session", shortSessionHash(sessionHash),
"excluded_ids", excludedIDsList)
cfg := s.schedulingConfig() cfg := s.schedulingConfig()
// 提取会话 UUID(用于会话数量限制)
sessionUUID := extractSessionUUID(metadataUserID)
var stickyAccountID int64 var stickyAccountID int64
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
...@@ -475,41 +486,63 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -475,41 +486,63 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
} }
if s.concurrencyService == nil || !cfg.LoadBatchEnabled { if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs) // 复制排除列表,用于会话限制拒绝时的重试
if err != nil { localExcluded := make(map[int64]struct{})
return nil, err for k, v := range excludedIDs {
} localExcluded[k] = v
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err == nil && result.Acquired {
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
} }
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) for {
if waitingCount < cfg.StickySessionMaxWaiting { account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, localExcluded)
if err != nil {
return nil, err
}
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err == nil && result.Acquired {
// 获取槽位后检查会话限制(使用 sessionHash 作为会话标识符)
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位
localExcluded[account.ID] = struct{}{} // 排除此账号
continue // 重新选择
}
return &AccountSelectionResult{ return &AccountSelectionResult{
Account: account, Account: account,
WaitPlan: &AccountWaitPlan{ Acquired: true,
AccountID: account.ID, ReleaseFunc: result.ReleaseFunc,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil }, nil
} }
// 对于等待计划的情况,也需要先检查会话限制
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
localExcluded[account.ID] = struct{}{}
continue
}
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
} }
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
} }
platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID, group) platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID, group)
...@@ -625,7 +658,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -625,7 +658,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
// 会话数量限制检查 // 会话数量限制检查
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionUUID) { if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
result.ReleaseFunc() // 释放槽位 result.ReleaseFunc() // 释放槽位
// 继续到负载感知选择 // 继续到负载感知选择
} else { } else {
...@@ -643,15 +676,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -643,15 +676,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID) waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
if waitingCount < cfg.StickySessionMaxWaiting { if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{ // 会话数量限制检查(等待计划也需要占用会话配额)
Account: stickyAccount, if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
WaitPlan: &AccountWaitPlan{ // 会话限制已满,继续到负载感知选择
AccountID: stickyAccountID, } else {
MaxConcurrency: stickyAccount.Concurrency, return &AccountSelectionResult{
Timeout: cfg.StickySessionWaitTimeout, Account: stickyAccount,
MaxWaiting: cfg.StickySessionMaxWaiting, WaitPlan: &AccountWaitPlan{
}, AccountID: stickyAccountID,
}, nil MaxConcurrency: stickyAccount.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
} }
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择 // 粘性账号槽位满且等待队列已满,继续使用负载感知选择
} }
...@@ -714,7 +752,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -714,7 +752,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
// 会话数量限制检查 // 会话数量限制检查
if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) { if !s.checkAndRegisterSession(ctx, item.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
continue continue
} }
...@@ -732,20 +770,26 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -732,20 +770,26 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
} }
} }
// 5. 所有路由账号槽位满,返回等待计划(选择负载最低的) // 5. 所有路由账号槽位满,尝试返回等待计划(选择负载最低的)
acc := routingAvailable[0].account // 遍历找到第一个满足会话限制的账号
if s.debugModelRoutingEnabled() { for _, item := range routingAvailable {
log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), acc.ID) if !s.checkAndRegisterSession(ctx, item.account, sessionHash) {
continue // 会话限制已满,尝试下一个
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
}
return &AccountSelectionResult{
Account: item.account,
WaitPlan: &AccountWaitPlan{
AccountID: item.account.ID,
MaxConcurrency: item.account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
} }
return &AccountSelectionResult{ // 所有路由账号会话限制都已满,继续到 Layer 2 回退
Account: acc,
WaitPlan: &AccountWaitPlan{
AccountID: acc.ID,
MaxConcurrency: acc.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
} }
// 路由列表中的账号都不可用(负载率 >= 100),继续到 Layer 2 回退 // 路由列表中的账号都不可用(负载率 >= 100),继续到 Layer 2 回退
log.Printf("[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", requestedModel) log.Printf("[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", requestedModel)
...@@ -773,7 +817,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -773,7 +817,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if err == nil && result.Acquired { if err == nil && result.Acquired {
// 会话数量限制检查 // 会话数量限制检查
// Session count limit check // Session count limit check
if !s.checkAndRegisterSession(ctx, account, sessionUUID) { if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续到 Layer 2 result.ReleaseFunc() // 释放槽位,继续到 Layer 2
} else { } else {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
...@@ -787,15 +831,22 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -787,15 +831,22 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting { if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{ // 会话数量限制检查(等待计划也需要占用会话配额)
Account: account, // Session count limit check (wait plan also requires session quota)
WaitPlan: &AccountWaitPlan{ if !s.checkAndRegisterSession(ctx, account, sessionHash) {
AccountID: accountID, // 会话限制已满,继续到 Layer 2
MaxConcurrency: account.Concurrency, // Session limit full, continue to Layer 2
Timeout: cfg.StickySessionWaitTimeout, } else {
MaxWaiting: cfg.StickySessionMaxWaiting, return &AccountSelectionResult{
}, Account: account,
}, nil WaitPlan: &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
} }
} }
} }
...@@ -845,7 +896,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -845,7 +896,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil { if err != nil {
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth, sessionUUID); ok { if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok {
return result, nil return result, nil
} }
} else { } else {
...@@ -895,7 +946,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -895,7 +946,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
// 会话数量限制检查 // 会话数量限制检查
if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) { if !s.checkAndRegisterSession(ctx, item.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
continue continue
} }
...@@ -913,8 +964,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -913,8 +964,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
} }
// ============ Layer 3: 兜底排队 ============ // ============ Layer 3: 兜底排队 ============
sortAccountsByPriorityAndLastUsed(candidates, preferOAuth) s.sortCandidatesForFallback(candidates, preferOAuth, cfg.FallbackSelectionMode)
for _, acc := range candidates { for _, acc := range candidates {
// 会话数量限制检查(等待计划也需要占用会话配额)
if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
continue // 会话限制已满,尝试下一个账号
}
return &AccountSelectionResult{ return &AccountSelectionResult{
Account: acc, Account: acc,
WaitPlan: &AccountWaitPlan{ WaitPlan: &AccountWaitPlan{
...@@ -928,7 +983,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -928,7 +983,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return nil, errors.New("no available accounts") return nil, errors.New("no available accounts")
} }
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool, sessionUUID string) (*AccountSelectionResult, bool) { func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
ordered := append([]*Account(nil), candidates...) ordered := append([]*Account(nil), candidates...)
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth) sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
...@@ -936,7 +991,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates ...@@ -936,7 +991,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
// 会话数量限制检查 // 会话数量限制检查
if !s.checkAndRegisterSession(ctx, acc, sessionUUID) { if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
continue continue
} }
...@@ -1093,7 +1148,24 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr ...@@ -1093,7 +1148,24 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) { func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
if s.schedulerSnapshot != nil { if s.schedulerSnapshot != nil {
return s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err == nil {
slog.Debug("account_scheduling_list_snapshot",
"group_id", derefGroupID(groupID),
"platform", platform,
"use_mixed", useMixed,
"count", len(accounts))
for _, acc := range accounts {
slog.Debug("account_scheduling_account_detail",
"account_id", acc.ID,
"name", acc.Name,
"platform", acc.Platform,
"type", acc.Type,
"status", acc.Status,
"tls_fingerprint", acc.IsTLSFingerprintEnabled())
}
}
return accounts, useMixed, err
} }
useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform
if useMixed { if useMixed {
...@@ -1106,6 +1178,10 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i ...@@ -1106,6 +1178,10 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
} }
if err != nil { if err != nil {
slog.Debug("account_scheduling_list_failed",
"group_id", derefGroupID(groupID),
"platform", platform,
"error", err)
return nil, useMixed, err return nil, useMixed, err
} }
filtered := make([]Account, 0, len(accounts)) filtered := make([]Account, 0, len(accounts))
...@@ -1115,6 +1191,20 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i ...@@ -1115,6 +1191,20 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
} }
filtered = append(filtered, acc) filtered = append(filtered, acc)
} }
slog.Debug("account_scheduling_list_mixed",
"group_id", derefGroupID(groupID),
"platform", platform,
"raw_count", len(accounts),
"filtered_count", len(filtered))
for _, acc := range filtered {
slog.Debug("account_scheduling_account_detail",
"account_id", acc.ID,
"name", acc.Name,
"platform", acc.Platform,
"type", acc.Type,
"status", acc.Status,
"tls_fingerprint", acc.IsTLSFingerprintEnabled())
}
return filtered, useMixed, nil return filtered, useMixed, nil
} }
...@@ -1129,8 +1219,25 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i ...@@ -1129,8 +1219,25 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
} }
if err != nil { if err != nil {
slog.Debug("account_scheduling_list_failed",
"group_id", derefGroupID(groupID),
"platform", platform,
"error", err)
return nil, useMixed, err return nil, useMixed, err
} }
slog.Debug("account_scheduling_list_single",
"group_id", derefGroupID(groupID),
"platform", platform,
"count", len(accounts))
for _, acc := range accounts {
slog.Debug("account_scheduling_account_detail",
"account_id", acc.ID,
"name", acc.Name,
"platform", acc.Platform,
"type", acc.Type,
"status", acc.Status,
"tls_fingerprint", acc.IsTLSFingerprintEnabled())
}
return accounts, useMixed, nil return accounts, useMixed, nil
} }
...@@ -1196,12 +1303,8 @@ func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context, ...@@ -1196,12 +1303,8 @@ func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context,
// 缓存未命中,从数据库查询 // 缓存未命中,从数据库查询
{ {
var startTime time.Time // 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
if account.SessionWindowStart != nil { startTime := account.GetCurrentWindowStartTime()
startTime = *account.SessionWindowStart
} else {
startTime = time.Now().Add(-5 * time.Hour)
}
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime) stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
if err != nil { if err != nil {
...@@ -1234,15 +1337,16 @@ checkSchedulability: ...@@ -1234,15 +1337,16 @@ checkSchedulability:
// checkAndRegisterSession 检查并注册会话,用于会话数量限制 // checkAndRegisterSession 检查并注册会话,用于会话数量限制
// 仅适用于 Anthropic OAuth/SetupToken 账号 // 仅适用于 Anthropic OAuth/SetupToken 账号
// sessionID: 会话标识符(使用粘性会话的 hash)
// 返回 true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话) // 返回 true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话)
func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionUUID string) bool { func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionID string) bool {
// 只检查 Anthropic OAuth/SetupToken 账号 // 只检查 Anthropic OAuth/SetupToken 账号
if !account.IsAnthropicOAuthOrSetupToken() { if !account.IsAnthropicOAuthOrSetupToken() {
return true return true
} }
maxSessions := account.GetMaxSessions() maxSessions := account.GetMaxSessions()
if maxSessions <= 0 || sessionUUID == "" { if maxSessions <= 0 || sessionID == "" {
return true // 未启用会话限制或无会话ID return true // 未启用会话限制或无会话ID
} }
...@@ -1252,7 +1356,7 @@ func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *A ...@@ -1252,7 +1356,7 @@ func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *A
idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute
allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionUUID, maxSessions, idleTimeout) allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionID, maxSessions, idleTimeout)
if err != nil { if err != nil {
// 失败开放:缓存错误时允许通过 // 失败开放:缓存错误时允许通过
return true return true
...@@ -1260,18 +1364,6 @@ func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *A ...@@ -1260,18 +1364,6 @@ func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *A
return allowed return allowed
} }
// extractSessionUUID 从 metadata.user_id 中提取会话 UUID
// 格式: user_{64位hex}_account__session_{uuid}
func extractSessionUUID(metadataUserID string) string {
if metadataUserID == "" {
return ""
}
if match := sessionIDRegex.FindStringSubmatch(metadataUserID); len(match) > 1 {
return match[1]
}
return ""
}
func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) { func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
if s.schedulerSnapshot != nil { if s.schedulerSnapshot != nil {
return s.schedulerSnapshot.GetAccount(ctx, accountID) return s.schedulerSnapshot.GetAccount(ctx, accountID)
...@@ -1301,6 +1393,56 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { ...@@ -1301,6 +1393,56 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
}) })
} }
// sortCandidatesForFallback 根据配置选择排序策略
// mode: "last_used"(按最后使用时间) 或 "random"(随机)
func (s *GatewayService) sortCandidatesForFallback(accounts []*Account, preferOAuth bool, mode string) {
if mode == "random" {
// 先按优先级排序,然后在同优先级内随机打乱
sortAccountsByPriorityOnly(accounts, preferOAuth)
shuffleWithinPriority(accounts)
} else {
// 默认按最后使用时间排序
sortAccountsByPriorityAndLastUsed(accounts, preferOAuth)
}
}
// sortAccountsByPriorityOnly 仅按优先级排序
func sortAccountsByPriorityOnly(accounts []*Account, preferOAuth bool) {
sort.SliceStable(accounts, func(i, j int) bool {
a, b := accounts[i], accounts[j]
if a.Priority != b.Priority {
return a.Priority < b.Priority
}
if preferOAuth && a.Type != b.Type {
return a.Type == AccountTypeOAuth
}
return false
})
}
// shuffleWithinPriority 在同优先级内随机打乱顺序
func shuffleWithinPriority(accounts []*Account) {
if len(accounts) <= 1 {
return
}
r := mathrand.New(mathrand.NewSource(time.Now().UnixNano()))
start := 0
for start < len(accounts) {
priority := accounts[start].Priority
end := start + 1
for end < len(accounts) && accounts[end].Priority == priority {
end++
}
// 对 [start, end) 范围内的账户随机打乱
if end-start > 1 {
r.Shuffle(end-start, func(i, j int) {
accounts[start+i], accounts[start+j] = accounts[start+j], accounts[start+i]
})
}
start = end
}
}
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离) // selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) { func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
preferOAuth := platform == PlatformGemini preferOAuth := platform == PlatformGemini
...@@ -2158,6 +2300,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -2158,6 +2300,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
proxyURL = account.Proxy.URL() proxyURL = account.Proxy.URL()
} }
// 调试日志:记录即将转发的账号信息
log.Printf("[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s",
account.ID, account.Name, account.Platform, account.Type, account.IsTLSFingerprintEnabled(), proxyURL)
// 重试循环 // 重试循环
var resp *http.Response var resp *http.Response
retryStart := time.Now() retryStart := time.Now()
...@@ -2172,7 +2318,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -2172,7 +2318,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
} }
// 发送请求 // 发送请求
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil { if err != nil {
if resp != nil && resp.Body != nil { if resp != nil && resp.Body != nil {
_ = resp.Body.Close() _ = resp.Body.Close()
...@@ -2246,7 +2392,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -2246,7 +2392,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
filteredBody := FilterThinkingBlocksForRetry(body) filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel) retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
if buildErr == nil { if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency) retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr == nil { if retryErr == nil {
if retryResp.StatusCode < 400 { if retryResp.StatusCode < 400 {
log.Printf("Account %d: signature error retry succeeded (thinking downgraded)", account.ID) log.Printf("Account %d: signature error retry succeeded (thinking downgraded)", account.ID)
...@@ -2278,7 +2424,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -2278,7 +2424,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body) filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel) retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel)
if buildErr2 == nil { if buildErr2 == nil {
retryResp2, retryErr2 := s.httpUpstream.Do(retryReq2, proxyURL, account.ID, account.Concurrency) retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr2 == nil { if retryErr2 == nil {
resp = retryResp2 resp = retryResp2
break break
...@@ -2393,6 +2539,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -2393,6 +2539,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
_ = resp.Body.Close() _ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody)) resp.Body = io.NopCloser(bytes.NewReader(respBody))
// 调试日志:打印重试耗尽后的错误响应
log.Printf("[Forward] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000))
s.handleRetryExhaustedSideEffects(ctx, resp, account) s.handleRetryExhaustedSideEffects(ctx, resp, account)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
...@@ -2420,6 +2570,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -2420,6 +2570,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
_ = resp.Body.Close() _ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody)) resp.Body = io.NopCloser(bytes.NewReader(respBody))
// 调试日志:打印上游错误响应
log.Printf("[Forward] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000))
s.handleFailoverSideEffects(ctx, resp, account) s.handleFailoverSideEffects(ctx, resp, account)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
...@@ -2549,9 +2703,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex ...@@ -2549,9 +2703,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
fingerprint = fp fingerprint = fp
// 2. 重写metadata.user_id(需要指纹中的ClientID和账号的account_uuid) // 2. 重写metadata.user_id(需要指纹中的ClientID和账号的account_uuid)
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
accountUUID := account.GetExtraString("account_uuid") accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" { if accountUUID != "" && fp.ClientID != "" {
if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 { if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
body = newBody body = newBody
} }
} }
...@@ -2770,6 +2925,10 @@ func extractUpstreamErrorMessage(body []byte) string { ...@@ -2770,6 +2925,10 @@ func extractUpstreamErrorMessage(body []byte) string {
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
// 调试日志:打印上游错误响应
log.Printf("[Forward] Upstream error (non-retryable): Account=%d(%s) Status=%d RequestID=%s Body=%s",
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(body), 1000))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
...@@ -3478,7 +3637,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, ...@@ -3478,7 +3637,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
} }
// 发送请求 // 发送请求
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil { if err != nil {
setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(err.Error()), "") setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(err.Error()), "")
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed") s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
...@@ -3500,7 +3659,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, ...@@ -3500,7 +3659,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
filteredBody := FilterThinkingBlocksForRetry(body) filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel) retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
if buildErr == nil { if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency) retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr == nil { if retryErr == nil {
resp = retryResp resp = retryResp
respBody, err = io.ReadAll(resp.Body) respBody, err = io.ReadAll(resp.Body)
...@@ -3578,12 +3737,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con ...@@ -3578,12 +3737,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
} }
// OAuth 账号:应用统一指纹和重写 userID // OAuth 账号:应用统一指纹和重写 userID
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
if account.IsOAuth() && s.identityService != nil { if account.IsOAuth() && s.identityService != nil {
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if err == nil { if err == nil {
accountUUID := account.GetExtraString("account_uuid") accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" { if accountUUID != "" && fp.ClientID != "" {
if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 { if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
body = newBody body = newBody
} }
} }
......
...@@ -90,6 +90,9 @@ func (m *mockAccountRepoForGemini) BatchUpdateLastUsed(ctx context.Context, upda ...@@ -90,6 +90,9 @@ func (m *mockAccountRepoForGemini) BatchUpdateLastUsed(ctx context.Context, upda
func (m *mockAccountRepoForGemini) SetError(ctx context.Context, id int64, errorMsg string) error { func (m *mockAccountRepoForGemini) SetError(ctx context.Context, id int64, errorMsg string) error {
return nil return nil
} }
func (m *mockAccountRepoForGemini) ClearError(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForGemini) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { func (m *mockAccountRepoForGemini) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return nil return nil
} }
......
...@@ -10,6 +10,7 @@ import "net/http" ...@@ -10,6 +10,7 @@ import "net/http"
// - 支持可选代理配置 // - 支持可选代理配置
// - 支持账户级连接池隔离 // - 支持账户级连接池隔离
// - 实现类负责连接池管理和复用 // - 实现类负责连接池管理和复用
// - 支持可选的 TLS 指纹伪装
type HTTPUpstream interface { type HTTPUpstream interface {
// Do 执行 HTTP 请求 // Do 执行 HTTP 请求
// //
...@@ -27,4 +28,28 @@ type HTTPUpstream interface { ...@@ -27,4 +28,28 @@ type HTTPUpstream interface {
// - 调用方必须关闭 resp.Body,否则会导致连接泄漏 // - 调用方必须关闭 resp.Body,否则会导致连接泄漏
// - 响应体可能已被包装以跟踪请求生命周期 // - 响应体可能已被包装以跟踪请求生命周期
Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error)
// DoWithTLS 执行带 TLS 指纹伪装的 HTTP 请求
//
// 参数:
// - req: HTTP 请求对象,由调用方构建
// - proxyURL: 代理服务器地址,空字符串表示直连
// - accountID: 账户 ID,用于连接池隔离和 TLS 指纹模板选择
// - accountConcurrency: 账户并发限制,用于动态调整连接池大小
// - enableTLSFingerprint: 是否启用 TLS 指纹伪装
//
// 返回:
// - *http.Response: HTTP 响应,调用方必须关闭 Body
// - error: 请求错误(网络错误、超时等)
//
// TLS 指纹说明:
// - 当 enableTLSFingerprint=true 时,使用 utls 库模拟 Claude CLI 的 TLS 指纹
// - TLS 指纹模板根据 accountID % len(profiles) 自动选择
// - 支持直连、HTTP/HTTPS 代理、SOCKS5 代理三种场景
// - 如果 enableTLSFingerprint=false,行为与 Do 方法相同
//
// 注意:
// - 调用方必须关闭 resp.Body,否则会导致连接泄漏
// - TLS 指纹客户端与普通客户端使用不同的缓存键,互不影响
DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error)
} }
...@@ -8,9 +8,11 @@ import ( ...@@ -8,9 +8,11 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"log" "log"
"log/slog"
"net/http" "net/http"
"regexp" "regexp"
"strconv" "strconv"
"strings"
"time" "time"
) )
...@@ -49,6 +51,13 @@ type Fingerprint struct { ...@@ -49,6 +51,13 @@ type Fingerprint struct {
type IdentityCache interface { type IdentityCache interface {
GetFingerprint(ctx context.Context, accountID int64) (*Fingerprint, error) GetFingerprint(ctx context.Context, accountID int64) (*Fingerprint, error)
SetFingerprint(ctx context.Context, accountID int64, fp *Fingerprint) error SetFingerprint(ctx context.Context, accountID int64, fp *Fingerprint) error
// GetMaskedSessionID 获取固定的会话ID(用于会话ID伪装功能)
// 返回的 sessionID 是一个 UUID 格式的字符串
// 如果不存在或已过期(15分钟无请求),返回空字符串
GetMaskedSessionID(ctx context.Context, accountID int64) (string, error)
// SetMaskedSessionID 设置固定的会话ID,TTL 为 15 分钟
// 每次调用都会刷新 TTL
SetMaskedSessionID(ctx context.Context, accountID int64, sessionID string) error
} }
// IdentityService 管理OAuth账号的请求身份指纹 // IdentityService 管理OAuth账号的请求身份指纹
...@@ -203,6 +212,94 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI ...@@ -203,6 +212,94 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
return json.Marshal(reqMap) return json.Marshal(reqMap)
} }
// RewriteUserIDWithMasking 重写body中的metadata.user_id,支持会话ID伪装
// 如果账号启用了会话ID伪装(session_id_masking_enabled),
// 则在完成常规重写后,将 session 部分替换为固定的伪装ID(15分钟内保持不变)
func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID string) ([]byte, error) {
// 先执行常规的 RewriteUserID 逻辑
newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID)
if err != nil {
return newBody, err
}
// 检查是否启用会话ID伪装
if !account.IsSessionIDMaskingEnabled() {
return newBody, nil
}
// 解析重写后的 body,提取 user_id
var reqMap map[string]any
if err := json.Unmarshal(newBody, &reqMap); err != nil {
return newBody, nil
}
metadata, ok := reqMap["metadata"].(map[string]any)
if !ok {
return newBody, nil
}
userID, ok := metadata["user_id"].(string)
if !ok || userID == "" {
return newBody, nil
}
// 查找 _session_ 的位置,替换其后的内容
const sessionMarker = "_session_"
idx := strings.LastIndex(userID, sessionMarker)
if idx == -1 {
return newBody, nil
}
// 获取或生成固定的伪装 session ID
maskedSessionID, err := s.cache.GetMaskedSessionID(ctx, account.ID)
if err != nil {
log.Printf("Warning: failed to get masked session ID for account %d: %v", account.ID, err)
return newBody, nil
}
if maskedSessionID == "" {
// 首次或已过期,生成新的伪装 session ID
maskedSessionID = generateRandomUUID()
log.Printf("Generated new masked session ID for account %d: %s", account.ID, maskedSessionID)
}
// 刷新 TTL(每次请求都刷新,保持 15 分钟有效期)
if err := s.cache.SetMaskedSessionID(ctx, account.ID, maskedSessionID); err != nil {
log.Printf("Warning: failed to set masked session ID for account %d: %v", account.ID, err)
}
// 替换 session 部分:保留 _session_ 之前的内容,替换之后的内容
newUserID := userID[:idx+len(sessionMarker)] + maskedSessionID
slog.Debug("session_id_masking_applied",
"account_id", account.ID,
"before", userID,
"after", newUserID,
)
metadata["user_id"] = newUserID
reqMap["metadata"] = metadata
return json.Marshal(reqMap)
}
// generateRandomUUID 生成随机 UUID v4 格式字符串
func generateRandomUUID() string {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
// fallback: 使用时间戳生成
h := sha256.Sum256([]byte(fmt.Sprintf("%d", time.Now().UnixNano())))
b = h[:16]
}
// 设置 UUID v4 版本和变体位
b[6] = (b[6] & 0x0f) | 0x40
b[8] = (b[8] & 0x3f) | 0x80
return fmt.Sprintf("%x-%x-%x-%x-%x",
b[0:4], b[4:6], b[6:8], b[8:10], b[10:16])
}
// generateClientID 生成64位十六进制客户端ID(32字节随机数) // generateClientID 生成64位十六进制客户端ID(32字节随机数)
func generateClientID() string { func generateClientID() string {
b := make([]byte, 32) b := make([]byte, 32)
......
...@@ -73,10 +73,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc ...@@ -73,10 +73,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
return false return false
} }
tempMatched := false // 先尝试临时不可调度规则(401除外)
// 如果匹配成功,直接返回,不执行后续禁用逻辑
if statusCode != 401 { if statusCode != 401 {
tempMatched = s.tryTempUnschedulable(ctx, account, statusCode, responseBody) if s.tryTempUnschedulable(ctx, account, statusCode, responseBody) {
return true
}
} }
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(responseBody)) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(responseBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if upstreamMsg != "" { if upstreamMsg != "" {
...@@ -84,6 +88,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc ...@@ -84,6 +88,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
} }
switch statusCode { switch statusCode {
case 400:
// 只有当错误信息包含 "organization has been disabled" 时才禁用
if strings.Contains(strings.ToLower(upstreamMsg), "organization has been disabled") {
msg := "Organization disabled (400): " + upstreamMsg
s.handleAuthError(ctx, account, msg)
shouldDisable = true
}
// 其他 400 错误(如参数问题)不处理,不禁用账号
case 401: case 401:
// 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新 // 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新
if account.Type == AccountTypeOAuth { if account.Type == AccountTypeOAuth {
...@@ -148,9 +160,6 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc ...@@ -148,9 +160,6 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
} }
} }
if tempMatched {
return true
}
return shouldDisable return shouldDisable
} }
......
...@@ -38,8 +38,9 @@ type SessionLimitCache interface { ...@@ -38,8 +38,9 @@ type SessionLimitCache interface {
GetActiveSessionCount(ctx context.Context, accountID int64) (int, error) GetActiveSessionCount(ctx context.Context, accountID int64) (int, error)
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数 // GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
// idleTimeouts: 每个账号的空闲超时时间配置,key 为 accountID;若为 nil 或某账号不在其中,则使用默认超时
// 返回 map[accountID]count,查询失败的账号不在 map 中 // 返回 map[accountID]count,查询失败的账号不在 map 中
GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64, idleTimeouts map[int64]time.Duration) (map[int64]int, error)
// IsSessionActive 检查特定会话是否活跃(未过期) // IsSessionActive 检查特定会话是否活跃(未过期)
IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error) IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error)
......
...@@ -166,11 +166,25 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc ...@@ -166,11 +166,25 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ { for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ {
newCredentials, err := refresher.Refresh(ctx, account) newCredentials, err := refresher.Refresh(ctx, account)
if err == nil {
// 刷新成功,更新账号credentials // 如果有新凭证,先更新(即使有错误也要保存 token)
if newCredentials != nil {
account.Credentials = newCredentials account.Credentials = newCredentials
if err := s.accountRepo.Update(ctx, account); err != nil { if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil {
return fmt.Errorf("failed to save credentials: %w", err) return fmt.Errorf("failed to save credentials: %w", saveErr)
}
}
if err == nil {
// Antigravity 账户:如果之前是因为缺少 project_id 而标记为 error,现在成功获取到了,清除错误状态
if account.Platform == PlatformAntigravity &&
account.Status == StatusError &&
strings.Contains(account.ErrorMessage, "missing_project_id:") {
if clearErr := s.accountRepo.ClearError(ctx, account.ID); clearErr != nil {
log.Printf("[TokenRefresh] Failed to clear error status for account %d: %v", account.ID, clearErr)
} else {
log.Printf("[TokenRefresh] Account %d: cleared missing_project_id error", account.ID)
}
} }
// 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理) // 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理)
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth { if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth {
...@@ -230,6 +244,7 @@ func isNonRetryableRefreshError(err error) bool { ...@@ -230,6 +244,7 @@ func isNonRetryableRefreshError(err error) bool {
"invalid_client", // 客户端配置错误 "invalid_client", // 客户端配置错误
"unauthorized_client", // 客户端未授权 "unauthorized_client", // 客户端未授权
"access_denied", // 访问被拒绝 "access_denied", // 访问被拒绝
"missing_project_id", // 缺少 project_id
} }
for _, needle := range nonRetryable { for _, needle := range nonRetryable {
if strings.Contains(msg, needle) { if strings.Contains(msg, needle) {
......
package service package service
import ( import (
"context"
"database/sql" "database/sql"
"time" "time"
...@@ -196,6 +197,8 @@ func ProvideOpsScheduledReportService( ...@@ -196,6 +197,8 @@ func ProvideOpsScheduledReportService(
// ProvideAPIKeyAuthCacheInvalidator 提供 API Key 认证缓存失效能力 // ProvideAPIKeyAuthCacheInvalidator 提供 API Key 认证缓存失效能力
func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthCacheInvalidator { func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthCacheInvalidator {
// Start Pub/Sub subscriber for L1 cache invalidation across instances
apiKeyService.StartAuthCacheInvalidationSubscriber(context.Background())
return apiKeyService return apiKeyService
} }
......
...@@ -401,3 +401,58 @@ sudo systemctl status redis ...@@ -401,3 +401,58 @@ sudo systemctl status redis
2. **Database connection failed**: Check PostgreSQL is running and credentials are correct 2. **Database connection failed**: Check PostgreSQL is running and credentials are correct
3. **Redis connection failed**: Check Redis is running and password is correct 3. **Redis connection failed**: Check Redis is running and password is correct
4. **Permission denied**: Ensure proper file ownership for binary install 4. **Permission denied**: Ensure proper file ownership for binary install
---
## TLS Fingerprint Configuration
Sub2API supports TLS fingerprint simulation to make requests appear as if they come from the official Claude CLI (Node.js client).
### Default Behavior
- Built-in `claude_cli_v2` profile simulates Node.js 20.x + OpenSSL 3.x
- JA3 Hash: `1a28e69016765d92e3b381168d68922c`
- JA4: `t13d5911h1_a33745022dd6_1f22a2ca17c4`
- Profile selection: `accountID % profileCount`
### Configuration
```yaml
gateway:
tls_fingerprint:
enabled: true # Global switch
profiles:
# Simple profile (uses default cipher suites)
profile_1:
name: "Profile 1"
# Profile with custom cipher suites (use compact array format)
profile_2:
name: "Profile 2"
cipher_suites: [4866, 4867, 4865, 49199, 49195, 49200, 49196]
curves: [29, 23, 24]
point_formats: [0]
# Another custom profile
profile_3:
name: "Profile 3"
cipher_suites: [4865, 4866, 4867, 49199, 49200]
curves: [29, 23, 24, 25]
```
### Profile Fields
| Field | Type | Description |
|-------|------|-------------|
| `name` | string | Display name (required) |
| `cipher_suites` | []uint16 | Cipher suites in decimal. Empty = default |
| `curves` | []uint16 | Elliptic curves in decimal. Empty = default |
| `point_formats` | []uint8 | EC point formats. Empty = default |
### Common Values Reference
**Cipher Suites (TLS 1.3):** `4865` (AES_128_GCM), `4866` (AES_256_GCM), `4867` (CHACHA20)
**Cipher Suites (TLS 1.2):** `49195`, `49196`, `49199`, `49200` (ECDHE variants)
**Curves:** `29` (X25519), `23` (P-256), `24` (P-384), `25` (P-521)
...@@ -210,6 +210,19 @@ gateway: ...@@ -210,6 +210,19 @@ gateway:
outbox_backlog_rebuild_rows: 10000 outbox_backlog_rebuild_rows: 10000
# 全量重建周期(秒),0 表示禁用 # 全量重建周期(秒),0 表示禁用
full_rebuild_interval_seconds: 300 full_rebuild_interval_seconds: 300
# TLS fingerprint simulation / TLS 指纹伪装
# Default profile "claude_cli_v2" simulates Node.js 20.x
# 默认模板 "claude_cli_v2" 模拟 Node.js 20.x 指纹
tls_fingerprint:
enabled: true
# profiles:
# profile_1:
# name: "Custom Profile 1"
# profile_2:
# name: "Custom Profile 2"
# cipher_suites: [4866, 4867, 4865, 49199, 49195, 49200, 49196]
# curves: [29, 23, 24]
# point_formats: [0]
# ============================================================================= # =============================================================================
# API Key Auth Cache Configuration # API Key Auth Cache Configuration
......
...@@ -1191,6 +1191,190 @@ ...@@ -1191,6 +1191,190 @@
</div> </div>
</div> </div>
<!-- Quota Control Section (Anthropic OAuth/SetupToken only) -->
<div
v-if="form.platform === 'anthropic' && accountCategory === 'oauth-based'"
class="border-t border-gray-200 pt-4 dark:border-dark-600 space-y-4"
>
<div class="mb-3">
<h3 class="input-label mb-0 text-base font-semibold">{{ t('admin.accounts.quotaControl.title') }}</h3>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.quotaControl.hint') }}
</p>
</div>
<!-- Window Cost Limit -->
<div class="rounded-lg border border-gray-200 p-4 dark:border-dark-600">
<div class="mb-3 flex items-center justify-between">
<div>
<label class="input-label mb-0">{{ t('admin.accounts.quotaControl.windowCost.label') }}</label>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.quotaControl.windowCost.hint') }}
</p>
</div>
<button
type="button"
@click="windowCostEnabled = !windowCostEnabled"
:class="[
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
windowCostEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
]"
>
<span
:class="[
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
windowCostEnabled ? 'translate-x-5' : 'translate-x-0'
]"
/>
</button>
</div>
<div v-if="windowCostEnabled" class="grid grid-cols-2 gap-4">
<div>
<label class="input-label">{{ t('admin.accounts.quotaControl.windowCost.limit') }}</label>
<div class="relative">
<span class="absolute left-3 top-1/2 -translate-y-1/2 text-gray-500 dark:text-gray-400">$</span>
<input
v-model.number="windowCostLimit"
type="number"
min="0"
step="1"
class="input pl-7"
:placeholder="t('admin.accounts.quotaControl.windowCost.limitPlaceholder')"
/>
</div>
<p class="input-hint">{{ t('admin.accounts.quotaControl.windowCost.limitHint') }}</p>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.quotaControl.windowCost.stickyReserve') }}</label>
<div class="relative">
<span class="absolute left-3 top-1/2 -translate-y-1/2 text-gray-500 dark:text-gray-400">$</span>
<input
v-model.number="windowCostStickyReserve"
type="number"
min="0"
step="1"
class="input pl-7"
:placeholder="t('admin.accounts.quotaControl.windowCost.stickyReservePlaceholder')"
/>
</div>
<p class="input-hint">{{ t('admin.accounts.quotaControl.windowCost.stickyReserveHint') }}</p>
</div>
</div>
</div>
<!-- Session Limit -->
<div class="rounded-lg border border-gray-200 p-4 dark:border-dark-600">
<div class="mb-3 flex items-center justify-between">
<div>
<label class="input-label mb-0">{{ t('admin.accounts.quotaControl.sessionLimit.label') }}</label>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.quotaControl.sessionLimit.hint') }}
</p>
</div>
<button
type="button"
@click="sessionLimitEnabled = !sessionLimitEnabled"
:class="[
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
sessionLimitEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
]"
>
<span
:class="[
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
sessionLimitEnabled ? 'translate-x-5' : 'translate-x-0'
]"
/>
</button>
</div>
<div v-if="sessionLimitEnabled" class="grid grid-cols-2 gap-4">
<div>
<label class="input-label">{{ t('admin.accounts.quotaControl.sessionLimit.maxSessions') }}</label>
<input
v-model.number="maxSessions"
type="number"
min="1"
step="1"
class="input"
:placeholder="t('admin.accounts.quotaControl.sessionLimit.maxSessionsPlaceholder')"
/>
<p class="input-hint">{{ t('admin.accounts.quotaControl.sessionLimit.maxSessionsHint') }}</p>
</div>
<div>
<label class="input-label">{{ t('admin.accounts.quotaControl.sessionLimit.idleTimeout') }}</label>
<div class="relative">
<input
v-model.number="sessionIdleTimeout"
type="number"
min="1"
step="1"
class="input pr-12"
:placeholder="t('admin.accounts.quotaControl.sessionLimit.idleTimeoutPlaceholder')"
/>
<span class="absolute right-3 top-1/2 -translate-y-1/2 text-gray-500 dark:text-gray-400">{{ t('common.minutes') }}</span>
</div>
<p class="input-hint">{{ t('admin.accounts.quotaControl.sessionLimit.idleTimeoutHint') }}</p>
</div>
</div>
</div>
<!-- TLS Fingerprint -->
<div class="rounded-lg border border-gray-200 p-4 dark:border-dark-600">
<div class="flex items-center justify-between">
<div>
<label class="input-label mb-0">{{ t('admin.accounts.quotaControl.tlsFingerprint.label') }}</label>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.quotaControl.tlsFingerprint.hint') }}
</p>
</div>
<button
type="button"
@click="tlsFingerprintEnabled = !tlsFingerprintEnabled"
:class="[
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
tlsFingerprintEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
]"
>
<span
:class="[
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
tlsFingerprintEnabled ? 'translate-x-5' : 'translate-x-0'
]"
/>
</button>
</div>
</div>
<!-- Session ID Masking -->
<div class="rounded-lg border border-gray-200 p-4 dark:border-dark-600">
<div class="flex items-center justify-between">
<div>
<label class="input-label mb-0">{{ t('admin.accounts.quotaControl.sessionIdMasking.label') }}</label>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.quotaControl.sessionIdMasking.hint') }}
</p>
</div>
<button
type="button"
@click="sessionIdMaskingEnabled = !sessionIdMaskingEnabled"
:class="[
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
sessionIdMaskingEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
]"
>
<span
:class="[
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
sessionIdMaskingEnabled ? 'translate-x-5' : 'translate-x-0'
]"
/>
</button>
</div>
</div>
</div>
<div> <div>
<label class="input-label">{{ t('admin.accounts.proxy') }}</label> <label class="input-label">{{ t('admin.accounts.proxy') }}</label>
<ProxySelector v-model="form.proxy_id" :proxies="proxies" /> <ProxySelector v-model="form.proxy_id" :proxies="proxies" />
...@@ -1214,7 +1398,7 @@ ...@@ -1214,7 +1398,7 @@
</div> </div>
<div> <div>
<label class="input-label">{{ t('admin.accounts.billingRateMultiplier') }}</label> <label class="input-label">{{ t('admin.accounts.billingRateMultiplier') }}</label>
<input v-model.number="form.rate_multiplier" type="number" min="0" step="0.01" class="input" /> <input v-model.number="form.rate_multiplier" type="number" min="0" step="0.001" class="input" />
<p class="input-hint">{{ t('admin.accounts.billingRateMultiplierHint') }}</p> <p class="input-hint">{{ t('admin.accounts.billingRateMultiplierHint') }}</p>
</div> </div>
</div> </div>
...@@ -1763,6 +1947,16 @@ const geminiAIStudioOAuthEnabled = ref(false) ...@@ -1763,6 +1947,16 @@ const geminiAIStudioOAuthEnabled = ref(false)
const showAdvancedOAuth = ref(false) const showAdvancedOAuth = ref(false)
const showGeminiHelpDialog = ref(false) const showGeminiHelpDialog = ref(false)
// Quota control state (Anthropic OAuth/SetupToken only)
const windowCostEnabled = ref(false)
const windowCostLimit = ref<number | null>(null)
const windowCostStickyReserve = ref<number | null>(null)
const sessionLimitEnabled = ref(false)
const maxSessions = ref<number | null>(null)
const sessionIdleTimeout = ref<number | null>(null)
const tlsFingerprintEnabled = ref(false)
const sessionIdMaskingEnabled = ref(false)
// Gemini tier selection (used as fallback when auto-detection is unavailable/fails) // Gemini tier selection (used as fallback when auto-detection is unavailable/fails)
const geminiTierGoogleOne = ref<'google_one_free' | 'google_ai_pro' | 'google_ai_ultra'>('google_one_free') const geminiTierGoogleOne = ref<'google_one_free' | 'google_ai_pro' | 'google_ai_ultra'>('google_one_free')
const geminiTierGcp = ref<'gcp_standard' | 'gcp_enterprise'>('gcp_standard') const geminiTierGcp = ref<'gcp_standard' | 'gcp_enterprise'>('gcp_standard')
...@@ -2140,6 +2334,15 @@ const resetForm = () => { ...@@ -2140,6 +2334,15 @@ const resetForm = () => {
customErrorCodeInput.value = null customErrorCodeInput.value = null
interceptWarmupRequests.value = false interceptWarmupRequests.value = false
autoPauseOnExpired.value = true autoPauseOnExpired.value = true
// Reset quota control state
windowCostEnabled.value = false
windowCostLimit.value = null
windowCostStickyReserve.value = null
sessionLimitEnabled.value = false
maxSessions.value = null
sessionIdleTimeout.value = null
tlsFingerprintEnabled.value = false
sessionIdMaskingEnabled.value = false
tempUnschedEnabled.value = false tempUnschedEnabled.value = false
tempUnschedRules.value = [] tempUnschedRules.value = []
geminiOAuthType.value = 'code_assist' geminiOAuthType.value = 'code_assist'
...@@ -2407,7 +2610,32 @@ const handleAnthropicExchange = async (authCode: string) => { ...@@ -2407,7 +2610,32 @@ const handleAnthropicExchange = async (authCode: string) => {
...proxyConfig ...proxyConfig
}) })
const extra = oauth.buildExtraInfo(tokenInfo) // Build extra with quota control settings
const baseExtra = oauth.buildExtraInfo(tokenInfo) || {}
const extra: Record<string, unknown> = { ...baseExtra }
// Add window cost limit settings
if (windowCostEnabled.value && windowCostLimit.value != null && windowCostLimit.value > 0) {
extra.window_cost_limit = windowCostLimit.value
extra.window_cost_sticky_reserve = windowCostStickyReserve.value ?? 10
}
// Add session limit settings
if (sessionLimitEnabled.value && maxSessions.value != null && maxSessions.value > 0) {
extra.max_sessions = maxSessions.value
extra.session_idle_timeout_minutes = sessionIdleTimeout.value ?? 5
}
// Add TLS fingerprint settings
if (tlsFingerprintEnabled.value) {
extra.enable_tls_fingerprint = true
}
// Add session ID masking settings
if (sessionIdMaskingEnabled.value) {
extra.session_id_masking_enabled = true
}
const credentials = { const credentials = {
...tokenInfo, ...tokenInfo,
...(interceptWarmupRequests.value ? { intercept_warmup_requests: true } : {}) ...(interceptWarmupRequests.value ? { intercept_warmup_requests: true } : {})
...@@ -2475,7 +2703,32 @@ const handleCookieAuth = async (sessionKey: string) => { ...@@ -2475,7 +2703,32 @@ const handleCookieAuth = async (sessionKey: string) => {
...proxyConfig ...proxyConfig
}) })
const extra = oauth.buildExtraInfo(tokenInfo) // Build extra with quota control settings
const baseExtra = oauth.buildExtraInfo(tokenInfo) || {}
const extra: Record<string, unknown> = { ...baseExtra }
// Add window cost limit settings
if (windowCostEnabled.value && windowCostLimit.value != null && windowCostLimit.value > 0) {
extra.window_cost_limit = windowCostLimit.value
extra.window_cost_sticky_reserve = windowCostStickyReserve.value ?? 10
}
// Add session limit settings
if (sessionLimitEnabled.value && maxSessions.value != null && maxSessions.value > 0) {
extra.max_sessions = maxSessions.value
extra.session_idle_timeout_minutes = sessionIdleTimeout.value ?? 5
}
// Add TLS fingerprint settings
if (tlsFingerprintEnabled.value) {
extra.enable_tls_fingerprint = true
}
// Add session ID masking settings
if (sessionIdMaskingEnabled.value) {
extra.session_id_masking_enabled = true
}
const accountName = keys.length > 1 ? `${form.name} #${i + 1}` : form.name const accountName = keys.length > 1 ? `${form.name} #${i + 1}` : form.name
// Merge interceptWarmupRequests into credentials // Merge interceptWarmupRequests into credentials
......
...@@ -566,7 +566,7 @@ ...@@ -566,7 +566,7 @@
</div> </div>
<div> <div>
<label class="input-label">{{ t('admin.accounts.billingRateMultiplier') }}</label> <label class="input-label">{{ t('admin.accounts.billingRateMultiplier') }}</label>
<input v-model.number="form.rate_multiplier" type="number" min="0" step="0.01" class="input" /> <input v-model.number="form.rate_multiplier" type="number" min="0" step="0.001" class="input" />
<p class="input-hint">{{ t('admin.accounts.billingRateMultiplierHint') }}</p> <p class="input-hint">{{ t('admin.accounts.billingRateMultiplierHint') }}</p>
</div> </div>
</div> </div>
...@@ -732,6 +732,60 @@ ...@@ -732,6 +732,60 @@
</div> </div>
</div> </div>
</div> </div>
<!-- TLS Fingerprint -->
<div class="rounded-lg border border-gray-200 p-4 dark:border-dark-600">
<div class="flex items-center justify-between">
<div>
<label class="input-label mb-0">{{ t('admin.accounts.quotaControl.tlsFingerprint.label') }}</label>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.quotaControl.tlsFingerprint.hint') }}
</p>
</div>
<button
type="button"
@click="tlsFingerprintEnabled = !tlsFingerprintEnabled"
:class="[
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
tlsFingerprintEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
]"
>
<span
:class="[
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
tlsFingerprintEnabled ? 'translate-x-5' : 'translate-x-0'
]"
/>
</button>
</div>
</div>
<!-- Session ID Masking -->
<div class="rounded-lg border border-gray-200 p-4 dark:border-dark-600">
<div class="flex items-center justify-between">
<div>
<label class="input-label mb-0">{{ t('admin.accounts.quotaControl.sessionIdMasking.label') }}</label>
<p class="mt-1 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.accounts.quotaControl.sessionIdMasking.hint') }}
</p>
</div>
<button
type="button"
@click="sessionIdMaskingEnabled = !sessionIdMaskingEnabled"
:class="[
'relative inline-flex h-6 w-11 flex-shrink-0 cursor-pointer rounded-full border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none focus:ring-2 focus:ring-primary-500 focus:ring-offset-2',
sessionIdMaskingEnabled ? 'bg-primary-600' : 'bg-gray-200 dark:bg-dark-600'
]"
>
<span
:class="[
'pointer-events-none inline-block h-5 w-5 transform rounded-full bg-white shadow ring-0 transition duration-200 ease-in-out',
sessionIdMaskingEnabled ? 'translate-x-5' : 'translate-x-0'
]"
/>
</button>
</div>
</div>
</div> </div>
<div class="border-t border-gray-200 pt-4 dark:border-dark-600"> <div class="border-t border-gray-200 pt-4 dark:border-dark-600">
...@@ -904,6 +958,8 @@ const windowCostStickyReserve = ref<number | null>(null) ...@@ -904,6 +958,8 @@ const windowCostStickyReserve = ref<number | null>(null)
const sessionLimitEnabled = ref(false) const sessionLimitEnabled = ref(false)
const maxSessions = ref<number | null>(null) const maxSessions = ref<number | null>(null)
const sessionIdleTimeout = ref<number | null>(null) const sessionIdleTimeout = ref<number | null>(null)
const tlsFingerprintEnabled = ref(false)
const sessionIdMaskingEnabled = ref(false)
// Computed: current preset mappings based on platform // Computed: current preset mappings based on platform
const presetMappings = computed(() => getPresetMappingsByPlatform(props.account?.platform || 'anthropic')) const presetMappings = computed(() => getPresetMappingsByPlatform(props.account?.platform || 'anthropic'))
...@@ -1237,6 +1293,8 @@ function loadQuotaControlSettings(account: Account) { ...@@ -1237,6 +1293,8 @@ function loadQuotaControlSettings(account: Account) {
sessionLimitEnabled.value = false sessionLimitEnabled.value = false
maxSessions.value = null maxSessions.value = null
sessionIdleTimeout.value = null sessionIdleTimeout.value = null
tlsFingerprintEnabled.value = false
sessionIdMaskingEnabled.value = false
// Only applies to Anthropic OAuth/SetupToken accounts // Only applies to Anthropic OAuth/SetupToken accounts
if (account.platform !== 'anthropic' || (account.type !== 'oauth' && account.type !== 'setup-token')) { if (account.platform !== 'anthropic' || (account.type !== 'oauth' && account.type !== 'setup-token')) {
...@@ -1255,6 +1313,16 @@ function loadQuotaControlSettings(account: Account) { ...@@ -1255,6 +1313,16 @@ function loadQuotaControlSettings(account: Account) {
maxSessions.value = account.max_sessions maxSessions.value = account.max_sessions
sessionIdleTimeout.value = account.session_idle_timeout_minutes ?? 5 sessionIdleTimeout.value = account.session_idle_timeout_minutes ?? 5
} }
// Load TLS fingerprint setting
if (account.enable_tls_fingerprint === true) {
tlsFingerprintEnabled.value = true
}
// Load session ID masking setting
if (account.session_id_masking_enabled === true) {
sessionIdMaskingEnabled.value = true
}
} }
function formatTempUnschedKeywords(value: unknown) { function formatTempUnschedKeywords(value: unknown) {
...@@ -1407,6 +1475,20 @@ const handleSubmit = async () => { ...@@ -1407,6 +1475,20 @@ const handleSubmit = async () => {
delete newExtra.session_idle_timeout_minutes delete newExtra.session_idle_timeout_minutes
} }
// TLS fingerprint setting
if (tlsFingerprintEnabled.value) {
newExtra.enable_tls_fingerprint = true
} else {
delete newExtra.enable_tls_fingerprint
}
// Session ID masking setting
if (sessionIdMaskingEnabled.value) {
newExtra.session_id_masking_enabled = true
} else {
delete newExtra.session_id_masking_enabled
}
updatePayload.extra = newExtra updatePayload.extra = newExtra
} }
......
...@@ -21,8 +21,20 @@ ...@@ -21,8 +21,20 @@
</div> </div>
</div> </div>
<!-- Right: Language + Subscriptions + Balance + User Dropdown --> <!-- Right: Docs + Language + Subscriptions + Balance + User Dropdown -->
<div class="flex items-center gap-3"> <div class="flex items-center gap-3">
<!-- Docs Link -->
<a
v-if="docUrl"
:href="docUrl"
target="_blank"
rel="noopener noreferrer"
class="flex items-center gap-1.5 rounded-lg px-2.5 py-1.5 text-sm font-medium text-gray-600 transition-colors hover:bg-gray-100 hover:text-gray-900 dark:text-dark-400 dark:hover:bg-dark-800 dark:hover:text-white"
>
<Icon name="book" size="sm" />
<span class="hidden sm:inline">{{ t('nav.docs') }}</span>
</a>
<!-- Language Switcher --> <!-- Language Switcher -->
<LocaleSwitcher /> <LocaleSwitcher />
...@@ -211,6 +223,7 @@ const user = computed(() => authStore.user) ...@@ -211,6 +223,7 @@ const user = computed(() => authStore.user)
const dropdownOpen = ref(false) const dropdownOpen = ref(false)
const dropdownRef = ref<HTMLElement | null>(null) const dropdownRef = ref<HTMLElement | null>(null)
const contactInfo = computed(() => appStore.contactInfo) const contactInfo = computed(() => appStore.contactInfo)
const docUrl = computed(() => appStore.docUrl)
// 只在标准模式的管理员下显示新手引导按钮 // 只在标准模式的管理员下显示新手引导按钮
const showOnboardingButton = computed(() => { const showOnboardingButton = computed(() => {
......
...@@ -196,7 +196,8 @@ export default { ...@@ -196,7 +196,8 @@ export default {
expand: 'Expand', expand: 'Expand',
logout: 'Logout', logout: 'Logout',
github: 'GitHub', github: 'GitHub',
mySubscriptions: 'My Subscriptions' mySubscriptions: 'My Subscriptions',
docs: 'Docs'
}, },
// Auth // Auth
...@@ -1288,6 +1289,14 @@ export default { ...@@ -1288,6 +1289,14 @@ export default {
idleTimeout: 'Idle Timeout', idleTimeout: 'Idle Timeout',
idleTimeoutPlaceholder: '5', idleTimeoutPlaceholder: '5',
idleTimeoutHint: 'Sessions will be released after idle timeout' idleTimeoutHint: 'Sessions will be released after idle timeout'
},
tlsFingerprint: {
label: 'TLS Fingerprint Simulation',
hint: 'Simulate Node.js/Claude Code client TLS fingerprint'
},
sessionIdMasking: {
label: 'Session ID Masking',
hint: 'When enabled, fixes the session ID in metadata.user_id for 15 minutes, making upstream think requests come from the same session'
} }
}, },
expired: 'Expired', expired: 'Expired',
......
...@@ -193,7 +193,8 @@ export default { ...@@ -193,7 +193,8 @@ export default {
expand: '展开', expand: '展开',
logout: '退出登录', logout: '退出登录',
github: 'GitHub', github: 'GitHub',
mySubscriptions: '我的订阅' mySubscriptions: '我的订阅',
docs: '文档'
}, },
// Auth // Auth
...@@ -1420,6 +1421,14 @@ export default { ...@@ -1420,6 +1421,14 @@ export default {
idleTimeout: '空闲超时', idleTimeout: '空闲超时',
idleTimeoutPlaceholder: '5', idleTimeoutPlaceholder: '5',
idleTimeoutHint: '会话空闲超时后自动释放' idleTimeoutHint: '会话空闲超时后自动释放'
},
tlsFingerprint: {
label: 'TLS 指纹模拟',
hint: '模拟 Node.js/Claude Code 客户端的 TLS 指纹'
},
sessionIdMasking: {
label: '会话 ID 伪装',
hint: '启用后将在 15 分钟内固定 metadata.user_id 中的 session ID,使上游认为请求来自同一会话'
} }
}, },
expired: '已过期', expired: '已过期',
......
...@@ -480,6 +480,13 @@ export interface Account { ...@@ -480,6 +480,13 @@ export interface Account {
max_sessions?: number | null max_sessions?: number | null
session_idle_timeout_minutes?: number | null session_idle_timeout_minutes?: number | null
// TLS指纹伪装(仅 Anthropic OAuth/SetupToken 账号有效)
enable_tls_fingerprint?: boolean | null
// 会话ID伪装(仅 Anthropic OAuth/SetupToken 账号有效)
// 启用后将在15分钟内固定 metadata.user_id 中的 session ID
session_id_masking_enabled?: boolean | null
// 运行时状态(仅当启用对应限制时返回) // 运行时状态(仅当启用对应限制时返回)
current_window_cost?: number | null // 当前窗口费用 current_window_cost?: number | null // 当前窗口费用
active_sessions?: number | null // 当前活跃会话数 active_sessions?: number | null // 当前活跃会话数
......
...@@ -243,7 +243,7 @@ ...@@ -243,7 +243,7 @@
/> />
<p class="input-hint">{{ t('admin.groups.platformHint') }}</p> <p class="input-hint">{{ t('admin.groups.platformHint') }}</p>
</div> </div>
<div v-if="createForm.subscription_type !== 'subscription'"> <div>
<label class="input-label">{{ t('admin.groups.form.rateMultiplier') }}</label> <label class="input-label">{{ t('admin.groups.form.rateMultiplier') }}</label>
<input <input
v-model.number="createForm.rate_multiplier" v-model.number="createForm.rate_multiplier"
...@@ -680,7 +680,7 @@ ...@@ -680,7 +680,7 @@
/> />
<p class="input-hint">{{ t('admin.groups.platformNotEditable') }}</p> <p class="input-hint">{{ t('admin.groups.platformNotEditable') }}</p>
</div> </div>
<div v-if="editForm.subscription_type !== 'subscription'"> <div>
<label class="input-label">{{ t('admin.groups.form.rateMultiplier') }}</label> <label class="input-label">{{ t('admin.groups.form.rateMultiplier') }}</label>
<input <input
v-model.number="editForm.rate_multiplier" v-model.number="editForm.rate_multiplier"
...@@ -1605,12 +1605,11 @@ const confirmDelete = async () => { ...@@ -1605,12 +1605,11 @@ const confirmDelete = async () => {
} }
} }
// 监听 subscription_type 变化,订阅模式时重置 rate_multiplier 为 1,is_exclusive 为 true // 监听 subscription_type 变化,订阅模式时 is_exclusive 默认为 true
watch( watch(
() => createForm.subscription_type, () => createForm.subscription_type,
(newVal) => { (newVal) => {
if (newVal === 'subscription') { if (newVal === 'subscription') {
createForm.rate_multiplier = 1.0
createForm.is_exclusive = true createForm.is_exclusive = true
} }
} }
......
...@@ -21,5 +21,6 @@ ...@@ -21,5 +21,6 @@
"types": ["vite/client"] "types": ["vite/client"]
}, },
"include": ["src/**/*.ts", "src/**/*.tsx", "src/**/*.vue"], "include": ["src/**/*.ts", "src/**/*.tsx", "src/**/*.vue"],
"exclude": ["src/**/__tests__/**", "src/**/*.spec.ts", "src/**/*.test.ts"],
"references": [{ "path": "./tsconfig.node.json" }] "references": [{ "path": "./tsconfig.node.json" }]
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment