"backend/vscode:/vscode.git/clone" did not exist on "8e834fd9f53b6cd8cf3735e3a3ae7a66f42b9dc1"
Commit c8e2f614 authored by cyhhao's avatar cyhhao
Browse files

Merge branch 'main' of github.com:Wei-Shaw/sub2api

parents c0347cde c95a8649
...@@ -101,6 +101,10 @@ func (s *dashboardAggregationRepoStub) AggregateRange(ctx context.Context, start ...@@ -101,6 +101,10 @@ func (s *dashboardAggregationRepoStub) AggregateRange(ctx context.Context, start
return nil return nil
} }
func (s *dashboardAggregationRepoStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
return nil
}
func (s *dashboardAggregationRepoStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) { func (s *dashboardAggregationRepoStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
if s.err != nil { if s.err != nil {
return time.Time{}, s.err return time.Time{}, s.err
......
...@@ -93,13 +93,14 @@ const ( ...@@ -93,13 +93,14 @@ const (
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url" SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
// OEM设置 // OEM设置
SettingKeySiteName = "site_name" // 网站名称 SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64) SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题 SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
SettingKeyAPIBaseURL = "api_base_url" // API端点地址(用于客户端配置和导入) SettingKeyAPIBaseURL = "api_base_url" // API端点地址(用于客户端配置和导入)
SettingKeyContactInfo = "contact_info" // 客服联系方式 SettingKeyContactInfo = "contact_info" // 客服联系方式
SettingKeyDocURL = "doc_url" // 文档链接 SettingKeyDocURL = "doc_url" // 文档链接
SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src) SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src)
SettingKeyHideCcsImportButton = "hide_ccs_import_button" // 是否隐藏 API Keys 页面的导入 CCS 按钮
// 默认配置 // 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
......
...@@ -105,6 +105,9 @@ func (m *mockAccountRepoForPlatform) BatchUpdateLastUsed(ctx context.Context, up ...@@ -105,6 +105,9 @@ func (m *mockAccountRepoForPlatform) BatchUpdateLastUsed(ctx context.Context, up
func (m *mockAccountRepoForPlatform) SetError(ctx context.Context, id int64, errorMsg string) error { func (m *mockAccountRepoForPlatform) SetError(ctx context.Context, id int64, errorMsg string) error {
return nil return nil
} }
func (m *mockAccountRepoForPlatform) ClearError(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForPlatform) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { func (m *mockAccountRepoForPlatform) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return nil return nil
} }
......
...@@ -11,6 +11,8 @@ import ( ...@@ -11,6 +11,8 @@ import (
"fmt" "fmt"
"io" "io"
"log" "log"
"log/slog"
mathrand "math/rand"
"net/http" "net/http"
"os" "os"
"regexp" "regexp"
...@@ -819,11 +821,20 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context ...@@ -819,11 +821,20 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
} }
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan. // SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
// metadataUserID: 原始 metadata.user_id 字段(用于提取会话 UUID 进行会话数量限制) // metadataUserID: 已废弃参数,会话限制现在统一使用 sessionHash
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) { func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) {
// 调试日志:记录调度入口参数
excludedIDsList := make([]int64, 0, len(excludedIDs))
for id := range excludedIDs {
excludedIDsList = append(excludedIDsList, id)
}
slog.Debug("account_scheduling_starting",
"group_id", derefGroupID(groupID),
"model", requestedModel,
"session", shortSessionHash(sessionHash),
"excluded_ids", excludedIDsList)
cfg := s.schedulingConfig() cfg := s.schedulingConfig()
// 提取会话 UUID(用于会话数量限制)
sessionUUID := extractSessionUUID(metadataUserID)
var stickyAccountID int64 var stickyAccountID int64
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
...@@ -849,41 +860,63 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -849,41 +860,63 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
} }
if s.concurrencyService == nil || !cfg.LoadBatchEnabled { if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs) // 复制排除列表,用于会话限制拒绝时的重试
if err != nil { localExcluded := make(map[int64]struct{})
return nil, err for k, v := range excludedIDs {
} localExcluded[k] = v
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err == nil && result.Acquired {
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
} }
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) for {
if waitingCount < cfg.StickySessionMaxWaiting { account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, localExcluded)
if err != nil {
return nil, err
}
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err == nil && result.Acquired {
// 获取槽位后检查会话限制(使用 sessionHash 作为会话标识符)
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位
localExcluded[account.ID] = struct{}{} // 排除此账号
continue // 重新选择
}
return &AccountSelectionResult{ return &AccountSelectionResult{
Account: account, Account: account,
WaitPlan: &AccountWaitPlan{ Acquired: true,
AccountID: account.ID, ReleaseFunc: result.ReleaseFunc,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil }, nil
} }
// 对于等待计划的情况,也需要先检查会话限制
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
localExcluded[account.ID] = struct{}{}
continue
}
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
} }
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
} }
platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID, group) platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID, group)
...@@ -999,7 +1032,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -999,7 +1032,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
// 会话数量限制检查 // 会话数量限制检查
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionUUID) { if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
result.ReleaseFunc() // 释放槽位 result.ReleaseFunc() // 释放槽位
// 继续到负载感知选择 // 继续到负载感知选择
} else { } else {
...@@ -1017,15 +1050,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1017,15 +1050,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID) waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
if waitingCount < cfg.StickySessionMaxWaiting { if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{ // 会话数量限制检查(等待计划也需要占用会话配额)
Account: stickyAccount, if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
WaitPlan: &AccountWaitPlan{ // 会话限制已满,继续到负载感知选择
AccountID: stickyAccountID, } else {
MaxConcurrency: stickyAccount.Concurrency, return &AccountSelectionResult{
Timeout: cfg.StickySessionWaitTimeout, Account: stickyAccount,
MaxWaiting: cfg.StickySessionMaxWaiting, WaitPlan: &AccountWaitPlan{
}, AccountID: stickyAccountID,
}, nil MaxConcurrency: stickyAccount.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
} }
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择 // 粘性账号槽位满且等待队列已满,继续使用负载感知选择
} }
...@@ -1086,7 +1124,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1086,7 +1124,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
// 会话数量限制检查 // 会话数量限制检查
if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) { if !s.checkAndRegisterSession(ctx, item.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
continue continue
} }
...@@ -1104,20 +1142,26 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1104,20 +1142,26 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
} }
} }
// 5. 所有路由账号槽位满,返回等待计划(选择负载最低的) // 5. 所有路由账号槽位满,尝试返回等待计划(选择负载最低的)
acc := routingAvailable[0].account // 遍历找到第一个满足会话限制的账号
if s.debugModelRoutingEnabled() { for _, item := range routingAvailable {
log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), acc.ID) if !s.checkAndRegisterSession(ctx, item.account, sessionHash) {
continue // 会话限制已满,尝试下一个
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
}
return &AccountSelectionResult{
Account: item.account,
WaitPlan: &AccountWaitPlan{
AccountID: item.account.ID,
MaxConcurrency: item.account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
} }
return &AccountSelectionResult{ // 所有路由账号会话限制都已满,继续到 Layer 2 回退
Account: acc,
WaitPlan: &AccountWaitPlan{
AccountID: acc.ID,
MaxConcurrency: acc.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
} }
// 路由列表中的账号都不可用(负载率 >= 100),继续到 Layer 2 回退 // 路由列表中的账号都不可用(负载率 >= 100),继续到 Layer 2 回退
log.Printf("[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", requestedModel) log.Printf("[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", requestedModel)
...@@ -1137,7 +1181,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1137,7 +1181,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
// 会话数量限制检查 // 会话数量限制检查
if !s.checkAndRegisterSession(ctx, account, sessionUUID) { if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续到 Layer 2 result.ReleaseFunc() // 释放槽位,继续到 Layer 2
} else { } else {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
...@@ -1151,15 +1195,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1151,15 +1195,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting { if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{ // 会话数量限制检查(等待计划也需要占用会话配额)
Account: account, if !s.checkAndRegisterSession(ctx, account, sessionHash) {
WaitPlan: &AccountWaitPlan{ // 会话限制已满,继续到 Layer 2
AccountID: accountID, } else {
MaxConcurrency: account.Concurrency, return &AccountSelectionResult{
Timeout: cfg.StickySessionWaitTimeout, Account: account,
MaxWaiting: cfg.StickySessionMaxWaiting, WaitPlan: &AccountWaitPlan{
}, AccountID: accountID,
}, nil MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
} }
} }
} }
...@@ -1208,7 +1257,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1208,7 +1257,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil { if err != nil {
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth, sessionUUID); ok { if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok {
return result, nil return result, nil
} }
} else { } else {
...@@ -1258,7 +1307,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1258,7 +1307,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
// 会话数量限制检查 // 会话数量限制检查
if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) { if !s.checkAndRegisterSession(ctx, item.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
continue continue
} }
...@@ -1276,8 +1325,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1276,8 +1325,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
} }
// ============ Layer 3: 兜底排队 ============ // ============ Layer 3: 兜底排队 ============
sortAccountsByPriorityAndLastUsed(candidates, preferOAuth) s.sortCandidatesForFallback(candidates, preferOAuth, cfg.FallbackSelectionMode)
for _, acc := range candidates { for _, acc := range candidates {
// 会话数量限制检查(等待计划也需要占用会话配额)
if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
continue // 会话限制已满,尝试下一个账号
}
return &AccountSelectionResult{ return &AccountSelectionResult{
Account: acc, Account: acc,
WaitPlan: &AccountWaitPlan{ WaitPlan: &AccountWaitPlan{
...@@ -1291,7 +1344,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1291,7 +1344,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return nil, errors.New("no available accounts") return nil, errors.New("no available accounts")
} }
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool, sessionUUID string) (*AccountSelectionResult, bool) { func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
ordered := append([]*Account(nil), candidates...) ordered := append([]*Account(nil), candidates...)
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth) sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
...@@ -1299,7 +1352,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates ...@@ -1299,7 +1352,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
// 会话数量限制检查 // 会话数量限制检查
if !s.checkAndRegisterSession(ctx, acc, sessionUUID) { if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
continue continue
} }
...@@ -1456,7 +1509,24 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr ...@@ -1456,7 +1509,24 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) { func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
if s.schedulerSnapshot != nil { if s.schedulerSnapshot != nil {
return s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err == nil {
slog.Debug("account_scheduling_list_snapshot",
"group_id", derefGroupID(groupID),
"platform", platform,
"use_mixed", useMixed,
"count", len(accounts))
for _, acc := range accounts {
slog.Debug("account_scheduling_account_detail",
"account_id", acc.ID,
"name", acc.Name,
"platform", acc.Platform,
"type", acc.Type,
"status", acc.Status,
"tls_fingerprint", acc.IsTLSFingerprintEnabled())
}
}
return accounts, useMixed, err
} }
useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform
if useMixed { if useMixed {
...@@ -1469,6 +1539,10 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i ...@@ -1469,6 +1539,10 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
} }
if err != nil { if err != nil {
slog.Debug("account_scheduling_list_failed",
"group_id", derefGroupID(groupID),
"platform", platform,
"error", err)
return nil, useMixed, err return nil, useMixed, err
} }
filtered := make([]Account, 0, len(accounts)) filtered := make([]Account, 0, len(accounts))
...@@ -1478,6 +1552,20 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i ...@@ -1478,6 +1552,20 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
} }
filtered = append(filtered, acc) filtered = append(filtered, acc)
} }
slog.Debug("account_scheduling_list_mixed",
"group_id", derefGroupID(groupID),
"platform", platform,
"raw_count", len(accounts),
"filtered_count", len(filtered))
for _, acc := range filtered {
slog.Debug("account_scheduling_account_detail",
"account_id", acc.ID,
"name", acc.Name,
"platform", acc.Platform,
"type", acc.Type,
"status", acc.Status,
"tls_fingerprint", acc.IsTLSFingerprintEnabled())
}
return filtered, useMixed, nil return filtered, useMixed, nil
} }
...@@ -1492,8 +1580,25 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i ...@@ -1492,8 +1580,25 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
} }
if err != nil { if err != nil {
slog.Debug("account_scheduling_list_failed",
"group_id", derefGroupID(groupID),
"platform", platform,
"error", err)
return nil, useMixed, err return nil, useMixed, err
} }
slog.Debug("account_scheduling_list_single",
"group_id", derefGroupID(groupID),
"platform", platform,
"count", len(accounts))
for _, acc := range accounts {
slog.Debug("account_scheduling_account_detail",
"account_id", acc.ID,
"name", acc.Name,
"platform", acc.Platform,
"type", acc.Type,
"status", acc.Status,
"tls_fingerprint", acc.IsTLSFingerprintEnabled())
}
return accounts, useMixed, nil return accounts, useMixed, nil
} }
...@@ -1559,12 +1664,8 @@ func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context, ...@@ -1559,12 +1664,8 @@ func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context,
// 缓存未命中,从数据库查询 // 缓存未命中,从数据库查询
{ {
var startTime time.Time // 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
if account.SessionWindowStart != nil { startTime := account.GetCurrentWindowStartTime()
startTime = *account.SessionWindowStart
} else {
startTime = time.Now().Add(-5 * time.Hour)
}
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime) stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
if err != nil { if err != nil {
...@@ -1597,15 +1698,16 @@ checkSchedulability: ...@@ -1597,15 +1698,16 @@ checkSchedulability:
// checkAndRegisterSession 检查并注册会话,用于会话数量限制 // checkAndRegisterSession 检查并注册会话,用于会话数量限制
// 仅适用于 Anthropic OAuth/SetupToken 账号 // 仅适用于 Anthropic OAuth/SetupToken 账号
// sessionID: 会话标识符(使用粘性会话的 hash)
// 返回 true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话) // 返回 true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话)
func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionUUID string) bool { func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionID string) bool {
// 只检查 Anthropic OAuth/SetupToken 账号 // 只检查 Anthropic OAuth/SetupToken 账号
if !account.IsAnthropicOAuthOrSetupToken() { if !account.IsAnthropicOAuthOrSetupToken() {
return true return true
} }
maxSessions := account.GetMaxSessions() maxSessions := account.GetMaxSessions()
if maxSessions <= 0 || sessionUUID == "" { if maxSessions <= 0 || sessionID == "" {
return true // 未启用会话限制或无会话ID return true // 未启用会话限制或无会话ID
} }
...@@ -1615,7 +1717,7 @@ func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *A ...@@ -1615,7 +1717,7 @@ func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *A
idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute
allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionUUID, maxSessions, idleTimeout) allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionID, maxSessions, idleTimeout)
if err != nil { if err != nil {
// 失败开放:缓存错误时允许通过 // 失败开放:缓存错误时允许通过
return true return true
...@@ -1623,18 +1725,6 @@ func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *A ...@@ -1623,18 +1725,6 @@ func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *A
return allowed return allowed
} }
// extractSessionUUID 从 metadata.user_id 中提取会话 UUID
// 格式: user_{64位hex}_account__session_{uuid}
func extractSessionUUID(metadataUserID string) string {
if metadataUserID == "" {
return ""
}
if match := sessionIDRegex.FindStringSubmatch(metadataUserID); len(match) > 1 {
return match[1]
}
return ""
}
func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) { func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
if s.schedulerSnapshot != nil { if s.schedulerSnapshot != nil {
return s.schedulerSnapshot.GetAccount(ctx, accountID) return s.schedulerSnapshot.GetAccount(ctx, accountID)
...@@ -1664,6 +1754,56 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { ...@@ -1664,6 +1754,56 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
}) })
} }
// sortCandidatesForFallback 根据配置选择排序策略
// mode: "last_used"(按最后使用时间) 或 "random"(随机)
func (s *GatewayService) sortCandidatesForFallback(accounts []*Account, preferOAuth bool, mode string) {
if mode == "random" {
// 先按优先级排序,然后在同优先级内随机打乱
sortAccountsByPriorityOnly(accounts, preferOAuth)
shuffleWithinPriority(accounts)
} else {
// 默认按最后使用时间排序
sortAccountsByPriorityAndLastUsed(accounts, preferOAuth)
}
}
// sortAccountsByPriorityOnly 仅按优先级排序
func sortAccountsByPriorityOnly(accounts []*Account, preferOAuth bool) {
sort.SliceStable(accounts, func(i, j int) bool {
a, b := accounts[i], accounts[j]
if a.Priority != b.Priority {
return a.Priority < b.Priority
}
if preferOAuth && a.Type != b.Type {
return a.Type == AccountTypeOAuth
}
return false
})
}
// shuffleWithinPriority 在同优先级内随机打乱顺序
func shuffleWithinPriority(accounts []*Account) {
if len(accounts) <= 1 {
return
}
r := mathrand.New(mathrand.NewSource(time.Now().UnixNano()))
start := 0
for start < len(accounts) {
priority := accounts[start].Priority
end := start + 1
for end < len(accounts) && accounts[end].Priority == priority {
end++
}
// 对 [start, end) 范围内的账户随机打乱
if end-start > 1 {
r.Shuffle(end-start, func(i, j int) {
accounts[start+i], accounts[start+j] = accounts[start+j], accounts[start+i]
})
}
start = end
}
}
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离) // selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) { func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
preferOAuth := platform == PlatformGemini preferOAuth := platform == PlatformGemini
...@@ -2524,6 +2664,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -2524,6 +2664,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
proxyURL = account.Proxy.URL() proxyURL = account.Proxy.URL()
} }
// 调试日志:记录即将转发的账号信息
log.Printf("[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s",
account.ID, account.Name, account.Platform, account.Type, account.IsTLSFingerprintEnabled(), proxyURL)
// 重试循环 // 重试循环
var resp *http.Response var resp *http.Response
retryStart := time.Now() retryStart := time.Now()
...@@ -2537,7 +2681,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -2537,7 +2681,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
} }
// 发送请求 // 发送请求
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil { if err != nil {
if resp != nil && resp.Body != nil { if resp != nil && resp.Body != nil {
_ = resp.Body.Close() _ = resp.Body.Close()
...@@ -2611,7 +2755,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -2611,7 +2755,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
filteredBody := FilterThinkingBlocksForRetry(body) filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
if buildErr == nil { if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency) retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr == nil { if retryErr == nil {
if retryResp.StatusCode < 400 { if retryResp.StatusCode < 400 {
log.Printf("Account %d: signature error retry succeeded (thinking downgraded)", account.ID) log.Printf("Account %d: signature error retry succeeded (thinking downgraded)", account.ID)
...@@ -2643,7 +2787,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -2643,7 +2787,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body) filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
if buildErr2 == nil { if buildErr2 == nil {
retryResp2, retryErr2 := s.httpUpstream.Do(retryReq2, proxyURL, account.ID, account.Concurrency) retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr2 == nil { if retryErr2 == nil {
resp = retryResp2 resp = retryResp2
break break
...@@ -2758,6 +2902,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -2758,6 +2902,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
_ = resp.Body.Close() _ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody)) resp.Body = io.NopCloser(bytes.NewReader(respBody))
// 调试日志:打印重试耗尽后的错误响应
log.Printf("[Forward] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000))
s.handleRetryExhaustedSideEffects(ctx, resp, account) s.handleRetryExhaustedSideEffects(ctx, resp, account)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
...@@ -2785,6 +2933,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -2785,6 +2933,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
_ = resp.Body.Close() _ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody)) resp.Body = io.NopCloser(bytes.NewReader(respBody))
// 调试日志:打印上游错误响应
log.Printf("[Forward] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000))
s.handleFailoverSideEffects(ctx, resp, account) s.handleFailoverSideEffects(ctx, resp, account)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
...@@ -2914,9 +3066,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex ...@@ -2914,9 +3066,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
fingerprint = fp fingerprint = fp
// 2. 重写metadata.user_id(需要指纹中的ClientID和账号的account_uuid) // 2. 重写metadata.user_id(需要指纹中的ClientID和账号的account_uuid)
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
accountUUID := account.GetExtraString("account_uuid") accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" { if accountUUID != "" && fp.ClientID != "" {
if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 { if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
body = newBody body = newBody
} }
} }
...@@ -3183,6 +3336,10 @@ func extractUpstreamErrorMessage(body []byte) string { ...@@ -3183,6 +3336,10 @@ func extractUpstreamErrorMessage(body []byte) string {
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
// 调试日志:打印上游错误响应
log.Printf("[Forward] Upstream error (non-retryable): Account=%d(%s) Status=%d RequestID=%s Body=%s",
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(body), 1000))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
...@@ -4171,7 +4328,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, ...@@ -4171,7 +4328,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
} }
// 发送请求 // 发送请求
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil { if err != nil {
setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(err.Error()), "") setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(err.Error()), "")
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed") s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
...@@ -4193,7 +4350,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, ...@@ -4193,7 +4350,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
filteredBody := FilterThinkingBlocksForRetry(body) filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, shouldMimicClaudeCode) retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, shouldMimicClaudeCode)
if buildErr == nil { if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency) retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr == nil { if retryErr == nil {
resp = retryResp resp = retryResp
respBody, err = io.ReadAll(resp.Body) respBody, err = io.ReadAll(resp.Body)
...@@ -4271,12 +4428,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con ...@@ -4271,12 +4428,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
} }
// OAuth 账号:应用统一指纹和重写 userID // OAuth 账号:应用统一指纹和重写 userID
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
if account.IsOAuth() && s.identityService != nil { if account.IsOAuth() && s.identityService != nil {
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if err == nil { if err == nil {
accountUUID := account.GetExtraString("account_uuid") accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" { if accountUUID != "" && fp.ClientID != "" {
if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 { if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
body = newBody body = newBody
} }
} }
......
...@@ -88,6 +88,9 @@ func (m *mockAccountRepoForGemini) BatchUpdateLastUsed(ctx context.Context, upda ...@@ -88,6 +88,9 @@ func (m *mockAccountRepoForGemini) BatchUpdateLastUsed(ctx context.Context, upda
func (m *mockAccountRepoForGemini) SetError(ctx context.Context, id int64, errorMsg string) error { func (m *mockAccountRepoForGemini) SetError(ctx context.Context, id int64, errorMsg string) error {
return nil return nil
} }
func (m *mockAccountRepoForGemini) ClearError(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForGemini) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { func (m *mockAccountRepoForGemini) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return nil return nil
} }
...@@ -599,7 +602,7 @@ func TestGeminiMessagesCompatService_isModelSupportedByAccount(t *testing.T) { ...@@ -599,7 +602,7 @@ func TestGeminiMessagesCompatService_isModelSupportedByAccount(t *testing.T) {
name: "Gemini平台-有映射配置-只支持配置的模型", name: "Gemini平台-有映射配置-只支持配置的模型",
account: &Account{ account: &Account{
Platform: PlatformGemini, Platform: PlatformGemini,
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-1.5-pro": "x"}}, Credentials: map[string]any{"model_mapping": map[string]any{"gemini-2.5-pro": "x"}},
}, },
model: "gemini-2.5-flash", model: "gemini-2.5-flash",
expected: false, expected: false,
......
...@@ -10,6 +10,7 @@ import "net/http" ...@@ -10,6 +10,7 @@ import "net/http"
// - 支持可选代理配置 // - 支持可选代理配置
// - 支持账户级连接池隔离 // - 支持账户级连接池隔离
// - 实现类负责连接池管理和复用 // - 实现类负责连接池管理和复用
// - 支持可选的 TLS 指纹伪装
type HTTPUpstream interface { type HTTPUpstream interface {
// Do 执行 HTTP 请求 // Do 执行 HTTP 请求
// //
...@@ -27,4 +28,28 @@ type HTTPUpstream interface { ...@@ -27,4 +28,28 @@ type HTTPUpstream interface {
// - 调用方必须关闭 resp.Body,否则会导致连接泄漏 // - 调用方必须关闭 resp.Body,否则会导致连接泄漏
// - 响应体可能已被包装以跟踪请求生命周期 // - 响应体可能已被包装以跟踪请求生命周期
Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error)
// DoWithTLS 执行带 TLS 指纹伪装的 HTTP 请求
//
// 参数:
// - req: HTTP 请求对象,由调用方构建
// - proxyURL: 代理服务器地址,空字符串表示直连
// - accountID: 账户 ID,用于连接池隔离和 TLS 指纹模板选择
// - accountConcurrency: 账户并发限制,用于动态调整连接池大小
// - enableTLSFingerprint: 是否启用 TLS 指纹伪装
//
// 返回:
// - *http.Response: HTTP 响应,调用方必须关闭 Body
// - error: 请求错误(网络错误、超时等)
//
// TLS 指纹说明:
// - 当 enableTLSFingerprint=true 时,使用 utls 库模拟 Claude CLI 的 TLS 指纹
// - TLS 指纹模板根据 accountID % len(profiles) 自动选择
// - 支持直连、HTTP/HTTPS 代理、SOCKS5 代理三种场景
// - 如果 enableTLSFingerprint=false,行为与 Do 方法相同
//
// 注意:
// - 调用方必须关闭 resp.Body,否则会导致连接泄漏
// - TLS 指纹客户端与普通客户端使用不同的缓存键,互不影响
DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error)
} }
...@@ -8,9 +8,11 @@ import ( ...@@ -8,9 +8,11 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"log" "log"
"log/slog"
"net/http" "net/http"
"regexp" "regexp"
"strconv" "strconv"
"strings"
"time" "time"
) )
...@@ -49,6 +51,13 @@ type Fingerprint struct { ...@@ -49,6 +51,13 @@ type Fingerprint struct {
type IdentityCache interface { type IdentityCache interface {
GetFingerprint(ctx context.Context, accountID int64) (*Fingerprint, error) GetFingerprint(ctx context.Context, accountID int64) (*Fingerprint, error)
SetFingerprint(ctx context.Context, accountID int64, fp *Fingerprint) error SetFingerprint(ctx context.Context, accountID int64, fp *Fingerprint) error
// GetMaskedSessionID 获取固定的会话ID(用于会话ID伪装功能)
// 返回的 sessionID 是一个 UUID 格式的字符串
// 如果不存在或已过期(15分钟无请求),返回空字符串
GetMaskedSessionID(ctx context.Context, accountID int64) (string, error)
// SetMaskedSessionID 设置固定的会话ID,TTL 为 15 分钟
// 每次调用都会刷新 TTL
SetMaskedSessionID(ctx context.Context, accountID int64, sessionID string) error
} }
// IdentityService 管理OAuth账号的请求身份指纹 // IdentityService 管理OAuth账号的请求身份指纹
...@@ -203,6 +212,94 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI ...@@ -203,6 +212,94 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
return json.Marshal(reqMap) return json.Marshal(reqMap)
} }
// RewriteUserIDWithMasking 重写body中的metadata.user_id,支持会话ID伪装
// 如果账号启用了会话ID伪装(session_id_masking_enabled),
// 则在完成常规重写后,将 session 部分替换为固定的伪装ID(15分钟内保持不变)
func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID string) ([]byte, error) {
// 先执行常规的 RewriteUserID 逻辑
newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID)
if err != nil {
return newBody, err
}
// 检查是否启用会话ID伪装
if !account.IsSessionIDMaskingEnabled() {
return newBody, nil
}
// 解析重写后的 body,提取 user_id
var reqMap map[string]any
if err := json.Unmarshal(newBody, &reqMap); err != nil {
return newBody, nil
}
metadata, ok := reqMap["metadata"].(map[string]any)
if !ok {
return newBody, nil
}
userID, ok := metadata["user_id"].(string)
if !ok || userID == "" {
return newBody, nil
}
// 查找 _session_ 的位置,替换其后的内容
const sessionMarker = "_session_"
idx := strings.LastIndex(userID, sessionMarker)
if idx == -1 {
return newBody, nil
}
// 获取或生成固定的伪装 session ID
maskedSessionID, err := s.cache.GetMaskedSessionID(ctx, account.ID)
if err != nil {
log.Printf("Warning: failed to get masked session ID for account %d: %v", account.ID, err)
return newBody, nil
}
if maskedSessionID == "" {
// 首次或已过期,生成新的伪装 session ID
maskedSessionID = generateRandomUUID()
log.Printf("Generated new masked session ID for account %d: %s", account.ID, maskedSessionID)
}
// 刷新 TTL(每次请求都刷新,保持 15 分钟有效期)
if err := s.cache.SetMaskedSessionID(ctx, account.ID, maskedSessionID); err != nil {
log.Printf("Warning: failed to set masked session ID for account %d: %v", account.ID, err)
}
// 替换 session 部分:保留 _session_ 之前的内容,替换之后的内容
newUserID := userID[:idx+len(sessionMarker)] + maskedSessionID
slog.Debug("session_id_masking_applied",
"account_id", account.ID,
"before", userID,
"after", newUserID,
)
metadata["user_id"] = newUserID
reqMap["metadata"] = metadata
return json.Marshal(reqMap)
}
// generateRandomUUID 生成随机 UUID v4 格式字符串
func generateRandomUUID() string {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
// fallback: 使用时间戳生成
h := sha256.Sum256([]byte(fmt.Sprintf("%d", time.Now().UnixNano())))
b = h[:16]
}
// 设置 UUID v4 版本和变体位
b[6] = (b[6] & 0x0f) | 0x40
b[8] = (b[8] & 0x3f) | 0x80
return fmt.Sprintf("%x-%x-%x-%x-%x",
b[0:4], b[4:6], b[6:8], b[8:10], b[10:16])
}
// generateClientID 生成64位十六进制客户端ID(32字节随机数) // generateClientID 生成64位十六进制客户端ID(32字节随机数)
func generateClientID() string { func generateClientID() string {
b := make([]byte, 32) b := make([]byte, 32)
......
...@@ -48,8 +48,7 @@ type GenerateAuthURLResult struct { ...@@ -48,8 +48,7 @@ type GenerateAuthURLResult struct {
// GenerateAuthURL generates an OAuth authorization URL with full scope // GenerateAuthURL generates an OAuth authorization URL with full scope
func (s *OAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64) (*GenerateAuthURLResult, error) { func (s *OAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64) (*GenerateAuthURLResult, error) {
scope := fmt.Sprintf("%s %s", oauth.ScopeProfile, oauth.ScopeInference) return s.generateAuthURLWithScope(ctx, oauth.ScopeOAuth, proxyID)
return s.generateAuthURLWithScope(ctx, scope, proxyID)
} }
// GenerateSetupTokenURL generates an OAuth authorization URL for setup token (inference only) // GenerateSetupTokenURL generates an OAuth authorization URL for setup token (inference only)
...@@ -176,7 +175,8 @@ func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) ( ...@@ -176,7 +175,8 @@ func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) (
} }
// Determine scope and if this is a setup token // Determine scope and if this is a setup token
scope := fmt.Sprintf("%s %s", oauth.ScopeProfile, oauth.ScopeInference) // Internal API call uses ScopeAPI (org:create_api_key not supported)
scope := oauth.ScopeAPI
isSetupToken := false isSetupToken := false
if input.Scope == "inference" { if input.Scope == "inference" {
scope = oauth.ScopeInference scope = oauth.ScopeInference
......
...@@ -394,19 +394,35 @@ func normalizeCodexTools(reqBody map[string]any) bool { ...@@ -394,19 +394,35 @@ func normalizeCodexTools(reqBody map[string]any) bool {
} }
modified := false modified := false
for idx, tool := range tools { validTools := make([]any, 0, len(tools))
for _, tool := range tools {
toolMap, ok := tool.(map[string]any) toolMap, ok := tool.(map[string]any)
if !ok { if !ok {
// Keep unknown structure as-is to avoid breaking upstream behavior.
validTools = append(validTools, tool)
continue continue
} }
toolType, _ := toolMap["type"].(string) toolType, _ := toolMap["type"].(string)
if strings.TrimSpace(toolType) != "function" { toolType = strings.TrimSpace(toolType)
if toolType != "function" {
validTools = append(validTools, toolMap)
continue continue
} }
function, ok := toolMap["function"].(map[string]any) // OpenAI Responses-style tools use top-level name/parameters.
if !ok { if name, ok := toolMap["name"].(string); ok && strings.TrimSpace(name) != "" {
validTools = append(validTools, toolMap)
continue
}
// ChatCompletions-style tools use {type:"function", function:{...}}.
functionValue, hasFunction := toolMap["function"]
function, ok := functionValue.(map[string]any)
if !hasFunction || functionValue == nil || !ok || function == nil {
// Drop invalid function tools.
modified = true
continue continue
} }
...@@ -435,11 +451,11 @@ func normalizeCodexTools(reqBody map[string]any) bool { ...@@ -435,11 +451,11 @@ func normalizeCodexTools(reqBody map[string]any) bool {
} }
} }
tools[idx] = toolMap validTools = append(validTools, toolMap)
} }
if modified { if modified {
reqBody["tools"] = tools reqBody["tools"] = validTools
} }
return modified return modified
......
...@@ -129,6 +129,37 @@ func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) { ...@@ -129,6 +129,37 @@ func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) {
require.False(t, hasID) require.False(t, hasID)
} }
func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunctionTools(t *testing.T) {
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
"tools": []any{
map[string]any{
"type": "function",
"name": "bash",
"description": "desc",
"parameters": map[string]any{"type": "object"},
},
map[string]any{
"type": "function",
"function": nil,
},
},
}
applyCodexOAuthTransform(reqBody)
tools, ok := reqBody["tools"].([]any)
require.True(t, ok)
require.Len(t, tools, 1)
first, ok := tools[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "function", first["type"])
require.Equal(t, "bash", first["name"])
}
func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) { func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
// 空 input 应保持为空且不触发异常。 // 空 input 应保持为空且不触发异常。
setupCodexCache(t) setupCodexCache(t)
......
...@@ -133,12 +133,30 @@ func NewOpenAIGatewayService( ...@@ -133,12 +133,30 @@ func NewOpenAIGatewayService(
} }
} }
// GenerateSessionHash generates session hash from header (OpenAI uses session_id header) // GenerateSessionHash generates a sticky-session hash for OpenAI requests.
func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string { //
sessionID := c.GetHeader("session_id") // Priority:
// 1. Header: session_id
// 2. Header: conversation_id
// 3. Body: prompt_cache_key (opencode)
func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, reqBody map[string]any) string {
if c == nil {
return ""
}
sessionID := strings.TrimSpace(c.GetHeader("session_id"))
if sessionID == "" {
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
}
if sessionID == "" && reqBody != nil {
if v, ok := reqBody["prompt_cache_key"].(string); ok {
sessionID = strings.TrimSpace(v)
}
}
if sessionID == "" { if sessionID == "" {
return "" return ""
} }
hash := sha256.Sum256([]byte(sessionID)) hash := sha256.Sum256([]byte(sessionID))
return hex.EncodeToString(hash[:]) return hex.EncodeToString(hash[:])
} }
......
...@@ -68,6 +68,49 @@ func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts ...@@ -68,6 +68,49 @@ func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts
return out, nil return out, nil
} }
func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
svc := &OpenAIGatewayService{}
// 1) session_id header wins
c.Request.Header.Set("session_id", "sess-123")
c.Request.Header.Set("conversation_id", "conv-456")
h1 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"})
if h1 == "" {
t.Fatalf("expected non-empty hash")
}
// 2) conversation_id used when session_id absent
c.Request.Header.Del("session_id")
h2 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"})
if h2 == "" {
t.Fatalf("expected non-empty hash")
}
if h1 == h2 {
t.Fatalf("expected different hashes for different keys")
}
// 3) prompt_cache_key used when both headers absent
c.Request.Header.Del("conversation_id")
h3 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"})
if h3 == "" {
t.Fatalf("expected non-empty hash")
}
if h2 == h3 {
t.Fatalf("expected different hashes for different keys")
}
// 4) empty when no signals
h4 := svc.GenerateSessionHash(c, map[string]any{})
if h4 != "" {
t.Fatalf("expected empty hash when no signals")
}
}
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) { func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
now := time.Now() now := time.Now()
resetAt := now.Add(10 * time.Minute) resetAt := now.Add(10 * time.Minute)
......
...@@ -27,6 +27,11 @@ var codexToolNameMapping = map[string]string{ ...@@ -27,6 +27,11 @@ var codexToolNameMapping = map[string]string{
"executeBash": "bash", "executeBash": "bash",
"exec_bash": "bash", "exec_bash": "bash",
"execBash": "bash", "execBash": "bash",
// Some clients output generic fetch names.
"fetch": "webfetch",
"web_fetch": "webfetch",
"webFetch": "webfetch",
} }
// ToolCorrectionStats 记录工具修正的统计信息(导出用于 JSON 序列化) // ToolCorrectionStats 记录工具修正的统计信息(导出用于 JSON 序列化)
...@@ -208,27 +213,67 @@ func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall ...@@ -208,27 +213,67 @@ func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall
// 根据工具名称应用特定的参数修正规则 // 根据工具名称应用特定的参数修正规则
switch toolName { switch toolName {
case "bash": case "bash":
// 移除 workdir 参数(OpenCode 不支持) // OpenCode bash 支持 workdir;有些来源会输出 work_dir。
if _, exists := argsMap["workdir"]; exists { if _, hasWorkdir := argsMap["workdir"]; !hasWorkdir {
delete(argsMap, "workdir") if workDir, exists := argsMap["work_dir"]; exists {
corrected = true argsMap["workdir"] = workDir
log.Printf("[CodexToolCorrector] Removed 'workdir' parameter from bash tool") delete(argsMap, "work_dir")
} corrected = true
if _, exists := argsMap["work_dir"]; exists { log.Printf("[CodexToolCorrector] Renamed 'work_dir' to 'workdir' in bash tool")
delete(argsMap, "work_dir") }
corrected = true } else {
log.Printf("[CodexToolCorrector] Removed 'work_dir' parameter from bash tool") if _, exists := argsMap["work_dir"]; exists {
delete(argsMap, "work_dir")
corrected = true
log.Printf("[CodexToolCorrector] Removed duplicate 'work_dir' parameter from bash tool")
}
} }
case "edit": case "edit":
// OpenCode edit 使用 old_string/new_string,Codex 可能使用其他名称 // OpenCode edit 参数为 filePath/oldString/newString(camelCase)。
// 这里可以添加参数名称的映射逻辑 if _, exists := argsMap["filePath"]; !exists {
if _, exists := argsMap["file_path"]; !exists { if filePath, exists := argsMap["file_path"]; exists {
if path, exists := argsMap["path"]; exists { argsMap["filePath"] = filePath
argsMap["file_path"] = path delete(argsMap, "file_path")
corrected = true
log.Printf("[CodexToolCorrector] Renamed 'file_path' to 'filePath' in edit tool")
} else if filePath, exists := argsMap["path"]; exists {
argsMap["filePath"] = filePath
delete(argsMap, "path") delete(argsMap, "path")
corrected = true corrected = true
log.Printf("[CodexToolCorrector] Renamed 'path' to 'file_path' in edit tool") log.Printf("[CodexToolCorrector] Renamed 'path' to 'filePath' in edit tool")
} else if filePath, exists := argsMap["file"]; exists {
argsMap["filePath"] = filePath
delete(argsMap, "file")
corrected = true
log.Printf("[CodexToolCorrector] Renamed 'file' to 'filePath' in edit tool")
}
}
if _, exists := argsMap["oldString"]; !exists {
if oldString, exists := argsMap["old_string"]; exists {
argsMap["oldString"] = oldString
delete(argsMap, "old_string")
corrected = true
log.Printf("[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool")
}
}
if _, exists := argsMap["newString"]; !exists {
if newString, exists := argsMap["new_string"]; exists {
argsMap["newString"] = newString
delete(argsMap, "new_string")
corrected = true
log.Printf("[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit tool")
}
}
if _, exists := argsMap["replaceAll"]; !exists {
if replaceAll, exists := argsMap["replace_all"]; exists {
argsMap["replaceAll"] = replaceAll
delete(argsMap, "replace_all")
corrected = true
log.Printf("[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool")
} }
} }
} }
......
...@@ -416,22 +416,23 @@ func TestCorrectToolParameters(t *testing.T) { ...@@ -416,22 +416,23 @@ func TestCorrectToolParameters(t *testing.T) {
expected map[string]bool // key: 期待存在的参数, value: true表示应该存在 expected map[string]bool // key: 期待存在的参数, value: true表示应该存在
}{ }{
{ {
name: "remove workdir from bash tool", name: "rename work_dir to workdir in bash tool",
input: `{ input: `{
"tool_calls": [{ "tool_calls": [{
"function": { "function": {
"name": "bash", "name": "bash",
"arguments": "{\"command\":\"ls\",\"workdir\":\"/tmp\"}" "arguments": "{\"command\":\"ls\",\"work_dir\":\"/tmp\"}"
} }
}] }]
}`, }`,
expected: map[string]bool{ expected: map[string]bool{
"command": true, "command": true,
"workdir": false, "workdir": true,
"work_dir": false,
}, },
}, },
{ {
name: "rename path to file_path in edit tool", name: "rename snake_case edit params to camelCase",
input: `{ input: `{
"tool_calls": [{ "tool_calls": [{
"function": { "function": {
...@@ -441,10 +442,12 @@ func TestCorrectToolParameters(t *testing.T) { ...@@ -441,10 +442,12 @@ func TestCorrectToolParameters(t *testing.T) {
}] }]
}`, }`,
expected: map[string]bool{ expected: map[string]bool{
"file_path": true, "filePath": true,
"path": false, "path": false,
"old_string": true, "oldString": true,
"new_string": true, "old_string": false,
"newString": true,
"new_string": false,
}, },
}, },
} }
......
...@@ -531,8 +531,8 @@ func (s *PricingService) buildModelLookupCandidates(modelLower string) []string ...@@ -531,8 +531,8 @@ func (s *PricingService) buildModelLookupCandidates(modelLower string) []string
func normalizeModelNameForPricing(model string) string { func normalizeModelNameForPricing(model string) string {
// Common Gemini/VertexAI forms: // Common Gemini/VertexAI forms:
// - models/gemini-2.0-flash-exp // - models/gemini-2.0-flash-exp
// - publishers/google/models/gemini-1.5-pro // - publishers/google/models/gemini-2.5-pro
// - projects/.../locations/.../publishers/google/models/gemini-1.5-pro // - projects/.../locations/.../publishers/google/models/gemini-2.5-pro
model = strings.TrimSpace(model) model = strings.TrimSpace(model)
model = strings.TrimLeft(model, "/") model = strings.TrimLeft(model, "/")
model = strings.TrimPrefix(model, "models/") model = strings.TrimPrefix(model, "models/")
......
...@@ -73,10 +73,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc ...@@ -73,10 +73,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
return false return false
} }
tempMatched := false // 先尝试临时不可调度规则(401除外)
// 如果匹配成功,直接返回,不执行后续禁用逻辑
if statusCode != 401 { if statusCode != 401 {
tempMatched = s.tryTempUnschedulable(ctx, account, statusCode, responseBody) if s.tryTempUnschedulable(ctx, account, statusCode, responseBody) {
return true
}
} }
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(responseBody)) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(responseBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if upstreamMsg != "" { if upstreamMsg != "" {
...@@ -84,6 +88,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc ...@@ -84,6 +88,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
} }
switch statusCode { switch statusCode {
case 400:
// 只有当错误信息包含 "organization has been disabled" 时才禁用
if strings.Contains(strings.ToLower(upstreamMsg), "organization has been disabled") {
msg := "Organization disabled (400): " + upstreamMsg
s.handleAuthError(ctx, account, msg)
shouldDisable = true
}
// 其他 400 错误(如参数问题)不处理,不禁用账号
case 401: case 401:
// 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新 // 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新
if account.Type == AccountTypeOAuth { if account.Type == AccountTypeOAuth {
...@@ -148,9 +160,6 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc ...@@ -148,9 +160,6 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
} }
} }
if tempMatched {
return true
}
return shouldDisable return shouldDisable
} }
...@@ -190,7 +199,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, ...@@ -190,7 +199,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
start := geminiDailyWindowStart(now) start := geminiDailyWindowStart(now)
totals, ok := s.getGeminiUsageTotals(account.ID, start, now) totals, ok := s.getGeminiUsageTotals(account.ID, start, now)
if !ok { if !ok {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil) stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil)
if err != nil { if err != nil {
return true, err return true, err
} }
...@@ -237,7 +246,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, ...@@ -237,7 +246,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
if limit > 0 { if limit > 0 {
start := now.Truncate(time.Minute) start := now.Truncate(time.Minute)
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil) stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil)
if err != nil { if err != nil {
return true, err return true, err
} }
......
...@@ -38,8 +38,9 @@ type SessionLimitCache interface { ...@@ -38,8 +38,9 @@ type SessionLimitCache interface {
GetActiveSessionCount(ctx context.Context, accountID int64) (int, error) GetActiveSessionCount(ctx context.Context, accountID int64) (int, error)
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数 // GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
// idleTimeouts: 每个账号的空闲超时时间配置,key 为 accountID;若为 nil 或某账号不在其中,则使用默认超时
// 返回 map[accountID]count,查询失败的账号不在 map 中 // 返回 map[accountID]count,查询失败的账号不在 map 中
GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64, idleTimeouts map[int64]time.Duration) (map[int64]int, error)
// IsSessionActive 检查特定会话是否活跃(未过期) // IsSessionActive 检查特定会话是否活跃(未过期)
IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error) IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error)
......
...@@ -69,6 +69,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ...@@ -69,6 +69,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyContactInfo, SettingKeyContactInfo,
SettingKeyDocURL, SettingKeyDocURL,
SettingKeyHomeContent, SettingKeyHomeContent,
SettingKeyHideCcsImportButton,
SettingKeyLinuxDoConnectEnabled, SettingKeyLinuxDoConnectEnabled,
} }
...@@ -96,6 +97,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ...@@ -96,6 +97,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
ContactInfo: settings[SettingKeyContactInfo], ContactInfo: settings[SettingKeyContactInfo],
DocURL: settings[SettingKeyDocURL], DocURL: settings[SettingKeyDocURL],
HomeContent: settings[SettingKeyHomeContent], HomeContent: settings[SettingKeyHomeContent],
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
LinuxDoOAuthEnabled: linuxDoEnabled, LinuxDoOAuthEnabled: linuxDoEnabled,
}, nil }, nil
} }
...@@ -132,6 +134,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any ...@@ -132,6 +134,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
ContactInfo string `json:"contact_info,omitempty"` ContactInfo string `json:"contact_info,omitempty"`
DocURL string `json:"doc_url,omitempty"` DocURL string `json:"doc_url,omitempty"`
HomeContent string `json:"home_content,omitempty"` HomeContent string `json:"home_content,omitempty"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version,omitempty"` Version string `json:"version,omitempty"`
}{ }{
...@@ -146,6 +149,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any ...@@ -146,6 +149,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
ContactInfo: settings.ContactInfo, ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL, DocURL: settings.DocURL,
HomeContent: settings.HomeContent, HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: s.version, Version: s.version,
}, nil }, nil
...@@ -193,6 +197,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet ...@@ -193,6 +197,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyContactInfo] = settings.ContactInfo updates[SettingKeyContactInfo] = settings.ContactInfo
updates[SettingKeyDocURL] = settings.DocURL updates[SettingKeyDocURL] = settings.DocURL
updates[SettingKeyHomeContent] = settings.HomeContent updates[SettingKeyHomeContent] = settings.HomeContent
updates[SettingKeyHideCcsImportButton] = strconv.FormatBool(settings.HideCcsImportButton)
// 默认配置 // 默认配置
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency) updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
...@@ -339,6 +344,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin ...@@ -339,6 +344,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
ContactInfo: settings[SettingKeyContactInfo], ContactInfo: settings[SettingKeyContactInfo],
DocURL: settings[SettingKeyDocURL], DocURL: settings[SettingKeyDocURL],
HomeContent: settings[SettingKeyHomeContent], HomeContent: settings[SettingKeyHomeContent],
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
} }
// 解析整数类型 // 解析整数类型
......
...@@ -25,13 +25,14 @@ type SystemSettings struct { ...@@ -25,13 +25,14 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured bool LinuxDoConnectClientSecretConfigured bool
LinuxDoConnectRedirectURL string LinuxDoConnectRedirectURL string
SiteName string SiteName string
SiteLogo string SiteLogo string
SiteSubtitle string SiteSubtitle string
APIBaseURL string APIBaseURL string
ContactInfo string ContactInfo string
DocURL string DocURL string
HomeContent string HomeContent string
HideCcsImportButton bool
DefaultConcurrency int DefaultConcurrency int
DefaultBalance float64 DefaultBalance float64
...@@ -66,6 +67,7 @@ type PublicSettings struct { ...@@ -66,6 +67,7 @@ type PublicSettings struct {
ContactInfo string ContactInfo string
DocURL string DocURL string
HomeContent string HomeContent string
HideCcsImportButton bool
LinuxDoOAuthEnabled bool LinuxDoOAuthEnabled bool
Version string Version string
} }
......
...@@ -27,6 +27,7 @@ var ( ...@@ -27,6 +27,7 @@ var (
ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded") ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded")
ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded") ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded")
ErrSubscriptionNilInput = infraerrors.BadRequest("SUBSCRIPTION_NIL_INPUT", "subscription input cannot be nil") ErrSubscriptionNilInput = infraerrors.BadRequest("SUBSCRIPTION_NIL_INPUT", "subscription input cannot be nil")
ErrAdjustWouldExpire = infraerrors.BadRequest("ADJUST_WOULD_EXPIRE", "adjustment would result in expired subscription (remaining days must be > 0)")
) )
// SubscriptionService 订阅服务 // SubscriptionService 订阅服务
...@@ -308,17 +309,20 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti ...@@ -308,17 +309,20 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
return nil return nil
} }
// ExtendSubscription 延长订阅 // ExtendSubscription 调整订阅时长(正数延长,负数缩短)
func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscriptionID int64, days int) (*UserSubscription, error) { func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscriptionID int64, days int) (*UserSubscription, error) {
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID) sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
if err != nil { if err != nil {
return nil, ErrSubscriptionNotFound return nil, ErrSubscriptionNotFound
} }
// 限制延长天数 // 限制调整天数范围
if days > MaxValidityDays { if days > MaxValidityDays {
days = MaxValidityDays days = MaxValidityDays
} }
if days < -MaxValidityDays {
days = -MaxValidityDays
}
// 计算新的过期时间 // 计算新的过期时间
newExpiresAt := sub.ExpiresAt.AddDate(0, 0, days) newExpiresAt := sub.ExpiresAt.AddDate(0, 0, days)
...@@ -326,6 +330,14 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti ...@@ -326,6 +330,14 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
newExpiresAt = MaxExpiresAt newExpiresAt = MaxExpiresAt
} }
// 如果是缩短(负数),检查新的过期时间必须大于当前时间
if days < 0 {
now := time.Now()
if !newExpiresAt.After(now) {
return nil, ErrAdjustWouldExpire
}
}
if err := s.userSubRepo.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil { if err := s.userSubRepo.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil {
return nil, err return nil, err
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment