Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
c8e2f614
Commit
c8e2f614
authored
Jan 20, 2026
by
cyhhao
Browse files
Merge branch 'main' of github.com:Wei-Shaw/sub2api
parents
c0347cde
c95a8649
Changes
167
Show whitespace changes
Inline
Side-by-side
backend/internal/service/dashboard_service_test.go
View file @
c8e2f614
...
...
@@ -101,6 +101,10 @@ func (s *dashboardAggregationRepoStub) AggregateRange(ctx context.Context, start
return
nil
}
func
(
s
*
dashboardAggregationRepoStub
)
RecomputeRange
(
ctx
context
.
Context
,
start
,
end
time
.
Time
)
error
{
return
nil
}
func
(
s
*
dashboardAggregationRepoStub
)
GetAggregationWatermark
(
ctx
context
.
Context
)
(
time
.
Time
,
error
)
{
if
s
.
err
!=
nil
{
return
time
.
Time
{},
s
.
err
...
...
backend/internal/service/domain_constants.go
View file @
c8e2f614
...
...
@@ -100,6 +100,7 @@ const (
SettingKeyContactInfo
=
"contact_info"
// 客服联系方式
SettingKeyDocURL
=
"doc_url"
// 文档链接
SettingKeyHomeContent
=
"home_content"
// 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src)
SettingKeyHideCcsImportButton
=
"hide_ccs_import_button"
// 是否隐藏 API Keys 页面的导入 CCS 按钮
// 默认配置
SettingKeyDefaultConcurrency
=
"default_concurrency"
// 新用户默认并发量
...
...
backend/internal/service/gateway_multiplatform_test.go
View file @
c8e2f614
...
...
@@ -105,6 +105,9 @@ func (m *mockAccountRepoForPlatform) BatchUpdateLastUsed(ctx context.Context, up
func
(
m
*
mockAccountRepoForPlatform
)
SetError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
ClearError
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
SetSchedulable
(
ctx
context
.
Context
,
id
int64
,
schedulable
bool
)
error
{
return
nil
}
...
...
backend/internal/service/gateway_service.go
View file @
c8e2f614
...
...
@@ -11,6 +11,8 @@ import (
"fmt"
"io"
"log"
"log/slog"
mathrand
"math/rand"
"net/http"
"os"
"regexp"
...
...
@@ -819,11 +821,20 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
}
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
// metadataUserID:
原始 metadata.user_id 字段(用于提取会话 UUID 进行会话数量限制)
// metadataUserID:
已废弃参数,会话限制现在统一使用 sessionHash
func
(
s
*
GatewayService
)
SelectAccountWithLoadAwareness
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{},
metadataUserID
string
)
(
*
AccountSelectionResult
,
error
)
{
// 调试日志:记录调度入口参数
excludedIDsList
:=
make
([]
int64
,
0
,
len
(
excludedIDs
))
for
id
:=
range
excludedIDs
{
excludedIDsList
=
append
(
excludedIDsList
,
id
)
}
slog
.
Debug
(
"account_scheduling_starting"
,
"group_id"
,
derefGroupID
(
groupID
),
"model"
,
requestedModel
,
"session"
,
shortSessionHash
(
sessionHash
),
"excluded_ids"
,
excludedIDsList
)
cfg
:=
s
.
schedulingConfig
()
// 提取会话 UUID(用于会话数量限制)
sessionUUID
:=
extractSessionUUID
(
metadataUserID
)
var
stickyAccountID
int64
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
...
...
@@ -849,18 +860,39 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
if
s
.
concurrencyService
==
nil
||
!
cfg
.
LoadBatchEnabled
{
account
,
err
:=
s
.
SelectAccountForModelWithExclusions
(
ctx
,
groupID
,
sessionHash
,
requestedModel
,
excludedIDs
)
// 复制排除列表,用于会话限制拒绝时的重试
localExcluded
:=
make
(
map
[
int64
]
struct
{})
for
k
,
v
:=
range
excludedIDs
{
localExcluded
[
k
]
=
v
}
for
{
account
,
err
:=
s
.
SelectAccountForModelWithExclusions
(
ctx
,
groupID
,
sessionHash
,
requestedModel
,
localExcluded
)
if
err
!=
nil
{
return
nil
,
err
}
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
account
.
ID
,
account
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
// 获取槽位后检查会话限制(使用 sessionHash 作为会话标识符)
if
!
s
.
checkAndRegisterSession
(
ctx
,
account
,
sessionHash
)
{
result
.
ReleaseFunc
()
// 释放槽位
localExcluded
[
account
.
ID
]
=
struct
{}{}
// 排除此账号
continue
// 重新选择
}
return
&
AccountSelectionResult
{
Account
:
account
,
Acquired
:
true
,
ReleaseFunc
:
result
.
ReleaseFunc
,
},
nil
}
// 对于等待计划的情况,也需要先检查会话限制
if
!
s
.
checkAndRegisterSession
(
ctx
,
account
,
sessionHash
)
{
localExcluded
[
account
.
ID
]
=
struct
{}{}
continue
}
if
stickyAccountID
>
0
&&
stickyAccountID
==
account
.
ID
&&
s
.
concurrencyService
!=
nil
{
waitingCount
,
_
:=
s
.
concurrencyService
.
GetAccountWaitingCount
(
ctx
,
account
.
ID
)
if
waitingCount
<
cfg
.
StickySessionMaxWaiting
{
...
...
@@ -885,6 +917,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
},
},
nil
}
}
platform
,
hasForcePlatform
,
err
:=
s
.
resolvePlatform
(
ctx
,
groupID
,
group
)
if
err
!=
nil
{
...
...
@@ -999,7 +1032,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
stickyAccountID
,
stickyAccount
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
// 会话数量限制检查
if
!
s
.
checkAndRegisterSession
(
ctx
,
stickyAccount
,
session
UUID
)
{
if
!
s
.
checkAndRegisterSession
(
ctx
,
stickyAccount
,
session
Hash
)
{
result
.
ReleaseFunc
()
// 释放槽位
// 继续到负载感知选择
}
else
{
...
...
@@ -1017,6 +1050,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
waitingCount
,
_
:=
s
.
concurrencyService
.
GetAccountWaitingCount
(
ctx
,
stickyAccountID
)
if
waitingCount
<
cfg
.
StickySessionMaxWaiting
{
// 会话数量限制检查(等待计划也需要占用会话配额)
if
!
s
.
checkAndRegisterSession
(
ctx
,
stickyAccount
,
sessionHash
)
{
// 会话限制已满,继续到负载感知选择
}
else
{
return
&
AccountSelectionResult
{
Account
:
stickyAccount
,
WaitPlan
:
&
AccountWaitPlan
{
...
...
@@ -1027,6 +1064,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
},
},
nil
}
}
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
}
}
...
...
@@ -1086,7 +1124,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
item
.
account
.
ID
,
item
.
account
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
// 会话数量限制检查
if
!
s
.
checkAndRegisterSession
(
ctx
,
item
.
account
,
session
UUID
)
{
if
!
s
.
checkAndRegisterSession
(
ctx
,
item
.
account
,
session
Hash
)
{
result
.
ReleaseFunc
()
// 释放槽位,继续尝试下一个账号
continue
}
...
...
@@ -1104,21 +1142,27 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
}
// 5. 所有路由账号槽位满,返回等待计划(选择负载最低的)
acc
:=
routingAvailable
[
0
]
.
account
// 5. 所有路由账号槽位满,尝试返回等待计划(选择负载最低的)
// 遍历找到第一个满足会话限制的账号
for
_
,
item
:=
range
routingAvailable
{
if
!
s
.
checkAndRegisterSession
(
ctx
,
item
.
account
,
sessionHash
)
{
continue
// 会话限制已满,尝试下一个
}
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
acc
.
ID
)
log
.
Printf
(
"[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
item
.
account
.
ID
)
}
return
&
AccountSelectionResult
{
Account
:
acc
,
Account
:
item
.
account
,
WaitPlan
:
&
AccountWaitPlan
{
AccountID
:
acc
.
ID
,
MaxConcurrency
:
acc
.
Concurrency
,
AccountID
:
item
.
account
.
ID
,
MaxConcurrency
:
item
.
account
.
Concurrency
,
Timeout
:
cfg
.
StickySessionWaitTimeout
,
MaxWaiting
:
cfg
.
StickySessionMaxWaiting
,
},
},
nil
}
// 所有路由账号会话限制都已满,继续到 Layer 2 回退
}
// 路由列表中的账号都不可用(负载率 >= 100),继续到 Layer 2 回退
log
.
Printf
(
"[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection"
,
requestedModel
)
}
...
...
@@ -1137,7 +1181,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
// 会话数量限制检查
if
!
s
.
checkAndRegisterSession
(
ctx
,
account
,
session
UUID
)
{
if
!
s
.
checkAndRegisterSession
(
ctx
,
account
,
session
Hash
)
{
result
.
ReleaseFunc
()
// 释放槽位,继续到 Layer 2
}
else
{
_
=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
)
...
...
@@ -1151,6 +1195,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
waitingCount
,
_
:=
s
.
concurrencyService
.
GetAccountWaitingCount
(
ctx
,
accountID
)
if
waitingCount
<
cfg
.
StickySessionMaxWaiting
{
// 会话数量限制检查(等待计划也需要占用会话配额)
if
!
s
.
checkAndRegisterSession
(
ctx
,
account
,
sessionHash
)
{
// 会话限制已满,继续到 Layer 2
}
else
{
return
&
AccountSelectionResult
{
Account
:
account
,
WaitPlan
:
&
AccountWaitPlan
{
...
...
@@ -1164,6 +1212,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
}
}
}
// ============ Layer 2: 负载感知选择 ============
candidates
:=
make
([]
*
Account
,
0
,
len
(
accounts
))
...
...
@@ -1208,7 +1257,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
loadMap
,
err
:=
s
.
concurrencyService
.
GetAccountsLoadBatch
(
ctx
,
accountLoads
)
if
err
!=
nil
{
if
result
,
ok
:=
s
.
tryAcquireByLegacyOrder
(
ctx
,
candidates
,
groupID
,
sessionHash
,
preferOAuth
,
sessionUUID
);
ok
{
if
result
,
ok
:=
s
.
tryAcquireByLegacyOrder
(
ctx
,
candidates
,
groupID
,
sessionHash
,
preferOAuth
);
ok
{
return
result
,
nil
}
}
else
{
...
...
@@ -1258,7 +1307,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
item
.
account
.
ID
,
item
.
account
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
// 会话数量限制检查
if
!
s
.
checkAndRegisterSession
(
ctx
,
item
.
account
,
session
UUID
)
{
if
!
s
.
checkAndRegisterSession
(
ctx
,
item
.
account
,
session
Hash
)
{
result
.
ReleaseFunc
()
// 释放槽位,继续尝试下一个账号
continue
}
...
...
@@ -1276,8 +1325,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
// ============ Layer 3: 兜底排队 ============
sort
AccountsByPriorityAndLastUsed
(
candidates
,
preferOAuth
)
s
.
sort
CandidatesForFallback
(
candidates
,
preferOAuth
,
cfg
.
FallbackSelectionMode
)
for
_
,
acc
:=
range
candidates
{
// 会话数量限制检查(等待计划也需要占用会话配额)
if
!
s
.
checkAndRegisterSession
(
ctx
,
acc
,
sessionHash
)
{
continue
// 会话限制已满,尝试下一个账号
}
return
&
AccountSelectionResult
{
Account
:
acc
,
WaitPlan
:
&
AccountWaitPlan
{
...
...
@@ -1291,7 +1344,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return
nil
,
errors
.
New
(
"no available accounts"
)
}
func
(
s
*
GatewayService
)
tryAcquireByLegacyOrder
(
ctx
context
.
Context
,
candidates
[]
*
Account
,
groupID
*
int64
,
sessionHash
string
,
preferOAuth
bool
,
sessionUUID
string
)
(
*
AccountSelectionResult
,
bool
)
{
func
(
s
*
GatewayService
)
tryAcquireByLegacyOrder
(
ctx
context
.
Context
,
candidates
[]
*
Account
,
groupID
*
int64
,
sessionHash
string
,
preferOAuth
bool
)
(
*
AccountSelectionResult
,
bool
)
{
ordered
:=
append
([]
*
Account
(
nil
),
candidates
...
)
sortAccountsByPriorityAndLastUsed
(
ordered
,
preferOAuth
)
...
...
@@ -1299,7 +1352,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
acc
.
ID
,
acc
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
// 会话数量限制检查
if
!
s
.
checkAndRegisterSession
(
ctx
,
acc
,
session
UUID
)
{
if
!
s
.
checkAndRegisterSession
(
ctx
,
acc
,
session
Hash
)
{
result
.
ReleaseFunc
()
// 释放槽位,继续尝试下一个账号
continue
}
...
...
@@ -1456,7 +1509,24 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr
func
(
s
*
GatewayService
)
listSchedulableAccounts
(
ctx
context
.
Context
,
groupID
*
int64
,
platform
string
,
hasForcePlatform
bool
)
([]
Account
,
bool
,
error
)
{
if
s
.
schedulerSnapshot
!=
nil
{
return
s
.
schedulerSnapshot
.
ListSchedulableAccounts
(
ctx
,
groupID
,
platform
,
hasForcePlatform
)
accounts
,
useMixed
,
err
:=
s
.
schedulerSnapshot
.
ListSchedulableAccounts
(
ctx
,
groupID
,
platform
,
hasForcePlatform
)
if
err
==
nil
{
slog
.
Debug
(
"account_scheduling_list_snapshot"
,
"group_id"
,
derefGroupID
(
groupID
),
"platform"
,
platform
,
"use_mixed"
,
useMixed
,
"count"
,
len
(
accounts
))
for
_
,
acc
:=
range
accounts
{
slog
.
Debug
(
"account_scheduling_account_detail"
,
"account_id"
,
acc
.
ID
,
"name"
,
acc
.
Name
,
"platform"
,
acc
.
Platform
,
"type"
,
acc
.
Type
,
"status"
,
acc
.
Status
,
"tls_fingerprint"
,
acc
.
IsTLSFingerprintEnabled
())
}
}
return
accounts
,
useMixed
,
err
}
useMixed
:=
(
platform
==
PlatformAnthropic
||
platform
==
PlatformGemini
)
&&
!
hasForcePlatform
if
useMixed
{
...
...
@@ -1469,6 +1539,10 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatforms
(
ctx
,
platforms
)
}
if
err
!=
nil
{
slog
.
Debug
(
"account_scheduling_list_failed"
,
"group_id"
,
derefGroupID
(
groupID
),
"platform"
,
platform
,
"error"
,
err
)
return
nil
,
useMixed
,
err
}
filtered
:=
make
([]
Account
,
0
,
len
(
accounts
))
...
...
@@ -1478,6 +1552,20 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
}
filtered
=
append
(
filtered
,
acc
)
}
slog
.
Debug
(
"account_scheduling_list_mixed"
,
"group_id"
,
derefGroupID
(
groupID
),
"platform"
,
platform
,
"raw_count"
,
len
(
accounts
),
"filtered_count"
,
len
(
filtered
))
for
_
,
acc
:=
range
filtered
{
slog
.
Debug
(
"account_scheduling_account_detail"
,
"account_id"
,
acc
.
ID
,
"name"
,
acc
.
Name
,
"platform"
,
acc
.
Platform
,
"type"
,
acc
.
Type
,
"status"
,
acc
.
Status
,
"tls_fingerprint"
,
acc
.
IsTLSFingerprintEnabled
())
}
return
filtered
,
useMixed
,
nil
}
...
...
@@ -1492,8 +1580,25 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatform
(
ctx
,
platform
)
}
if
err
!=
nil
{
slog
.
Debug
(
"account_scheduling_list_failed"
,
"group_id"
,
derefGroupID
(
groupID
),
"platform"
,
platform
,
"error"
,
err
)
return
nil
,
useMixed
,
err
}
slog
.
Debug
(
"account_scheduling_list_single"
,
"group_id"
,
derefGroupID
(
groupID
),
"platform"
,
platform
,
"count"
,
len
(
accounts
))
for
_
,
acc
:=
range
accounts
{
slog
.
Debug
(
"account_scheduling_account_detail"
,
"account_id"
,
acc
.
ID
,
"name"
,
acc
.
Name
,
"platform"
,
acc
.
Platform
,
"type"
,
acc
.
Type
,
"status"
,
acc
.
Status
,
"tls_fingerprint"
,
acc
.
IsTLSFingerprintEnabled
())
}
return
accounts
,
useMixed
,
nil
}
...
...
@@ -1559,12 +1664,8 @@ func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context,
// 缓存未命中,从数据库查询
{
var
startTime
time
.
Time
if
account
.
SessionWindowStart
!=
nil
{
startTime
=
*
account
.
SessionWindowStart
}
else
{
startTime
=
time
.
Now
()
.
Add
(
-
5
*
time
.
Hour
)
}
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
startTime
:=
account
.
GetCurrentWindowStartTime
()
stats
,
err
:=
s
.
usageLogRepo
.
GetAccountWindowStats
(
ctx
,
account
.
ID
,
startTime
)
if
err
!=
nil
{
...
...
@@ -1597,15 +1698,16 @@ checkSchedulability:
// checkAndRegisterSession 检查并注册会话,用于会话数量限制
// 仅适用于 Anthropic OAuth/SetupToken 账号
// sessionID: 会话标识符(使用粘性会话的 hash)
// 返回 true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话)
func
(
s
*
GatewayService
)
checkAndRegisterSession
(
ctx
context
.
Context
,
account
*
Account
,
session
UU
ID
string
)
bool
{
func
(
s
*
GatewayService
)
checkAndRegisterSession
(
ctx
context
.
Context
,
account
*
Account
,
sessionID
string
)
bool
{
// 只检查 Anthropic OAuth/SetupToken 账号
if
!
account
.
IsAnthropicOAuthOrSetupToken
()
{
return
true
}
maxSessions
:=
account
.
GetMaxSessions
()
if
maxSessions
<=
0
||
session
UU
ID
==
""
{
if
maxSessions
<=
0
||
sessionID
==
""
{
return
true
// 未启用会话限制或无会话ID
}
...
...
@@ -1615,7 +1717,7 @@ func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *A
idleTimeout
:=
time
.
Duration
(
account
.
GetSessionIdleTimeoutMinutes
())
*
time
.
Minute
allowed
,
err
:=
s
.
sessionLimitCache
.
RegisterSession
(
ctx
,
account
.
ID
,
session
UU
ID
,
maxSessions
,
idleTimeout
)
allowed
,
err
:=
s
.
sessionLimitCache
.
RegisterSession
(
ctx
,
account
.
ID
,
sessionID
,
maxSessions
,
idleTimeout
)
if
err
!=
nil
{
// 失败开放:缓存错误时允许通过
return
true
...
...
@@ -1623,18 +1725,6 @@ func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *A
return
allowed
}
// extractSessionUUID 从 metadata.user_id 中提取会话 UUID
// 格式: user_{64位hex}_account__session_{uuid}
func
extractSessionUUID
(
metadataUserID
string
)
string
{
if
metadataUserID
==
""
{
return
""
}
if
match
:=
sessionIDRegex
.
FindStringSubmatch
(
metadataUserID
);
len
(
match
)
>
1
{
return
match
[
1
]
}
return
""
}
func
(
s
*
GatewayService
)
getSchedulableAccount
(
ctx
context
.
Context
,
accountID
int64
)
(
*
Account
,
error
)
{
if
s
.
schedulerSnapshot
!=
nil
{
return
s
.
schedulerSnapshot
.
GetAccount
(
ctx
,
accountID
)
...
...
@@ -1664,6 +1754,56 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
})
}
// sortCandidatesForFallback 根据配置选择排序策略
// mode: "last_used"(按最后使用时间) 或 "random"(随机)
func
(
s
*
GatewayService
)
sortCandidatesForFallback
(
accounts
[]
*
Account
,
preferOAuth
bool
,
mode
string
)
{
if
mode
==
"random"
{
// 先按优先级排序,然后在同优先级内随机打乱
sortAccountsByPriorityOnly
(
accounts
,
preferOAuth
)
shuffleWithinPriority
(
accounts
)
}
else
{
// 默认按最后使用时间排序
sortAccountsByPriorityAndLastUsed
(
accounts
,
preferOAuth
)
}
}
// sortAccountsByPriorityOnly 仅按优先级排序
func
sortAccountsByPriorityOnly
(
accounts
[]
*
Account
,
preferOAuth
bool
)
{
sort
.
SliceStable
(
accounts
,
func
(
i
,
j
int
)
bool
{
a
,
b
:=
accounts
[
i
],
accounts
[
j
]
if
a
.
Priority
!=
b
.
Priority
{
return
a
.
Priority
<
b
.
Priority
}
if
preferOAuth
&&
a
.
Type
!=
b
.
Type
{
return
a
.
Type
==
AccountTypeOAuth
}
return
false
})
}
// shuffleWithinPriority 在同优先级内随机打乱顺序
func
shuffleWithinPriority
(
accounts
[]
*
Account
)
{
if
len
(
accounts
)
<=
1
{
return
}
r
:=
mathrand
.
New
(
mathrand
.
NewSource
(
time
.
Now
()
.
UnixNano
()))
start
:=
0
for
start
<
len
(
accounts
)
{
priority
:=
accounts
[
start
]
.
Priority
end
:=
start
+
1
for
end
<
len
(
accounts
)
&&
accounts
[
end
]
.
Priority
==
priority
{
end
++
}
// 对 [start, end) 范围内的账户随机打乱
if
end
-
start
>
1
{
r
.
Shuffle
(
end
-
start
,
func
(
i
,
j
int
)
{
accounts
[
start
+
i
],
accounts
[
start
+
j
]
=
accounts
[
start
+
j
],
accounts
[
start
+
i
]
})
}
start
=
end
}
}
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func
(
s
*
GatewayService
)
selectAccountForModelWithPlatform
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{},
platform
string
)
(
*
Account
,
error
)
{
preferOAuth
:=
platform
==
PlatformGemini
...
...
@@ -2524,6 +2664,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
proxyURL
=
account
.
Proxy
.
URL
()
}
// 调试日志:记录即将转发的账号信息
log
.
Printf
(
"[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s"
,
account
.
ID
,
account
.
Name
,
account
.
Platform
,
account
.
Type
,
account
.
IsTLSFingerprintEnabled
(),
proxyURL
)
// 重试循环
var
resp
*
http
.
Response
retryStart
:=
time
.
Now
()
...
...
@@ -2537,7 +2681,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
// 发送请求
resp
,
err
=
s
.
httpUpstream
.
Do
(
upstreamReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
resp
,
err
=
s
.
httpUpstream
.
Do
WithTLS
(
upstreamReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
()
)
if
err
!=
nil
{
if
resp
!=
nil
&&
resp
.
Body
!=
nil
{
_
=
resp
.
Body
.
Close
()
...
...
@@ -2611,7 +2755,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
filteredBody
:=
FilterThinkingBlocksForRetry
(
body
)
retryReq
,
buildErr
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
filteredBody
,
token
,
tokenType
,
reqModel
,
reqStream
,
shouldMimicClaudeCode
)
if
buildErr
==
nil
{
retryResp
,
retryErr
:=
s
.
httpUpstream
.
Do
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
retryResp
,
retryErr
:=
s
.
httpUpstream
.
Do
WithTLS
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
()
)
if
retryErr
==
nil
{
if
retryResp
.
StatusCode
<
400
{
log
.
Printf
(
"Account %d: signature error retry succeeded (thinking downgraded)"
,
account
.
ID
)
...
...
@@ -2643,7 +2787,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
filteredBody2
:=
FilterSignatureSensitiveBlocksForRetry
(
body
)
retryReq2
,
buildErr2
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
filteredBody2
,
token
,
tokenType
,
reqModel
,
reqStream
,
shouldMimicClaudeCode
)
if
buildErr2
==
nil
{
retryResp2
,
retryErr2
:=
s
.
httpUpstream
.
Do
(
retryReq2
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
retryResp2
,
retryErr2
:=
s
.
httpUpstream
.
Do
WithTLS
(
retryReq2
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
()
)
if
retryErr2
==
nil
{
resp
=
retryResp2
break
...
...
@@ -2758,6 +2902,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
_
=
resp
.
Body
.
Close
()
resp
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
))
// 调试日志:打印重试耗尽后的错误响应
log
.
Printf
(
"[Forward] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s"
,
account
.
ID
,
account
.
Name
,
resp
.
StatusCode
,
resp
.
Header
.
Get
(
"x-request-id"
),
truncateString
(
string
(
respBody
),
1000
))
s
.
handleRetryExhaustedSideEffects
(
ctx
,
resp
,
account
)
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
...
...
@@ -2785,6 +2933,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
_
=
resp
.
Body
.
Close
()
resp
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
))
// 调试日志:打印上游错误响应
log
.
Printf
(
"[Forward] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s"
,
account
.
ID
,
account
.
Name
,
resp
.
StatusCode
,
resp
.
Header
.
Get
(
"x-request-id"
),
truncateString
(
string
(
respBody
),
1000
))
s
.
handleFailoverSideEffects
(
ctx
,
resp
,
account
)
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
...
...
@@ -2914,9 +3066,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
fingerprint
=
fp
// 2. 重写metadata.user_id(需要指纹中的ClientID和账号的account_uuid)
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
accountUUID
:=
account
.
GetExtraString
(
"account_uuid"
)
if
accountUUID
!=
""
&&
fp
.
ClientID
!=
""
{
if
newBody
,
err
:=
s
.
identityService
.
RewriteUserID
(
body
,
account
.
ID
,
accountUUID
,
fp
.
ClientID
);
err
==
nil
&&
len
(
newBody
)
>
0
{
if
newBody
,
err
:=
s
.
identityService
.
RewriteUserID
WithMasking
(
ctx
,
body
,
account
,
accountUUID
,
fp
.
ClientID
);
err
==
nil
&&
len
(
newBody
)
>
0
{
body
=
newBody
}
}
...
...
@@ -3183,6 +3336,10 @@ func extractUpstreamErrorMessage(body []byte) string {
func
(
s
*
GatewayService
)
handleErrorResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
)
(
*
ForwardResult
,
error
)
{
body
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
// 调试日志:打印上游错误响应
log
.
Printf
(
"[Forward] Upstream error (non-retryable): Account=%d(%s) Status=%d RequestID=%s Body=%s"
,
account
.
ID
,
account
.
Name
,
resp
.
StatusCode
,
resp
.
Header
.
Get
(
"x-request-id"
),
truncateString
(
string
(
body
),
1000
))
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
body
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
...
...
@@ -4171,7 +4328,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// 发送请求
resp
,
err
:=
s
.
httpUpstream
.
Do
(
upstreamReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
resp
,
err
:=
s
.
httpUpstream
.
Do
WithTLS
(
upstreamReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
()
)
if
err
!=
nil
{
setOpsUpstreamError
(
c
,
0
,
sanitizeUpstreamErrorMessage
(
err
.
Error
()),
""
)
s
.
countTokensError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Request failed"
)
...
...
@@ -4193,7 +4350,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
filteredBody
:=
FilterThinkingBlocksForRetry
(
body
)
retryReq
,
buildErr
:=
s
.
buildCountTokensRequest
(
ctx
,
c
,
account
,
filteredBody
,
token
,
tokenType
,
reqModel
,
shouldMimicClaudeCode
)
if
buildErr
==
nil
{
retryResp
,
retryErr
:=
s
.
httpUpstream
.
Do
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
retryResp
,
retryErr
:=
s
.
httpUpstream
.
Do
WithTLS
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
()
)
if
retryErr
==
nil
{
resp
=
retryResp
respBody
,
err
=
io
.
ReadAll
(
resp
.
Body
)
...
...
@@ -4271,12 +4428,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
// OAuth 账号:应用统一指纹和重写 userID
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
if
account
.
IsOAuth
()
&&
s
.
identityService
!=
nil
{
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
.
Request
.
Header
)
if
err
==
nil
{
accountUUID
:=
account
.
GetExtraString
(
"account_uuid"
)
if
accountUUID
!=
""
&&
fp
.
ClientID
!=
""
{
if
newBody
,
err
:=
s
.
identityService
.
RewriteUserID
(
body
,
account
.
ID
,
accountUUID
,
fp
.
ClientID
);
err
==
nil
&&
len
(
newBody
)
>
0
{
if
newBody
,
err
:=
s
.
identityService
.
RewriteUserID
WithMasking
(
ctx
,
body
,
account
,
accountUUID
,
fp
.
ClientID
);
err
==
nil
&&
len
(
newBody
)
>
0
{
body
=
newBody
}
}
...
...
backend/internal/service/gemini_multiplatform_test.go
View file @
c8e2f614
...
...
@@ -88,6 +88,9 @@ func (m *mockAccountRepoForGemini) BatchUpdateLastUsed(ctx context.Context, upda
func
(
m
*
mockAccountRepoForGemini
)
SetError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
ClearError
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
SetSchedulable
(
ctx
context
.
Context
,
id
int64
,
schedulable
bool
)
error
{
return
nil
}
...
...
@@ -599,7 +602,7 @@ func TestGeminiMessagesCompatService_isModelSupportedByAccount(t *testing.T) {
name
:
"Gemini平台-有映射配置-只支持配置的模型"
,
account
:
&
Account
{
Platform
:
PlatformGemini
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"gemini-
1
.5-pro"
:
"x"
}},
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"gemini-
2
.5-pro"
:
"x"
}},
},
model
:
"gemini-2.5-flash"
,
expected
:
false
,
...
...
backend/internal/service/http_upstream_port.go
View file @
c8e2f614
...
...
@@ -10,6 +10,7 @@ import "net/http"
// - 支持可选代理配置
// - 支持账户级连接池隔离
// - 实现类负责连接池管理和复用
// - 支持可选的 TLS 指纹伪装
type
HTTPUpstream
interface
{
// Do 执行 HTTP 请求
//
...
...
@@ -27,4 +28,28 @@ type HTTPUpstream interface {
// - 调用方必须关闭 resp.Body,否则会导致连接泄漏
// - 响应体可能已被包装以跟踪请求生命周期
Do
(
req
*
http
.
Request
,
proxyURL
string
,
accountID
int64
,
accountConcurrency
int
)
(
*
http
.
Response
,
error
)
// DoWithTLS 执行带 TLS 指纹伪装的 HTTP 请求
//
// 参数:
// - req: HTTP 请求对象,由调用方构建
// - proxyURL: 代理服务器地址,空字符串表示直连
// - accountID: 账户 ID,用于连接池隔离和 TLS 指纹模板选择
// - accountConcurrency: 账户并发限制,用于动态调整连接池大小
// - enableTLSFingerprint: 是否启用 TLS 指纹伪装
//
// 返回:
// - *http.Response: HTTP 响应,调用方必须关闭 Body
// - error: 请求错误(网络错误、超时等)
//
// TLS 指纹说明:
// - 当 enableTLSFingerprint=true 时,使用 utls 库模拟 Claude CLI 的 TLS 指纹
// - TLS 指纹模板根据 accountID % len(profiles) 自动选择
// - 支持直连、HTTP/HTTPS 代理、SOCKS5 代理三种场景
// - 如果 enableTLSFingerprint=false,行为与 Do 方法相同
//
// 注意:
// - 调用方必须关闭 resp.Body,否则会导致连接泄漏
// - TLS 指纹客户端与普通客户端使用不同的缓存键,互不影响
DoWithTLS
(
req
*
http
.
Request
,
proxyURL
string
,
accountID
int64
,
accountConcurrency
int
,
enableTLSFingerprint
bool
)
(
*
http
.
Response
,
error
)
}
backend/internal/service/identity_service.go
View file @
c8e2f614
...
...
@@ -8,9 +8,11 @@ import (
"encoding/json"
"fmt"
"log"
"log/slog"
"net/http"
"regexp"
"strconv"
"strings"
"time"
)
...
...
@@ -49,6 +51,13 @@ type Fingerprint struct {
type
IdentityCache
interface
{
GetFingerprint
(
ctx
context
.
Context
,
accountID
int64
)
(
*
Fingerprint
,
error
)
SetFingerprint
(
ctx
context
.
Context
,
accountID
int64
,
fp
*
Fingerprint
)
error
// GetMaskedSessionID 获取固定的会话ID(用于会话ID伪装功能)
// 返回的 sessionID 是一个 UUID 格式的字符串
// 如果不存在或已过期(15分钟无请求),返回空字符串
GetMaskedSessionID
(
ctx
context
.
Context
,
accountID
int64
)
(
string
,
error
)
// SetMaskedSessionID 设置固定的会话ID,TTL 为 15 分钟
// 每次调用都会刷新 TTL
SetMaskedSessionID
(
ctx
context
.
Context
,
accountID
int64
,
sessionID
string
)
error
}
// IdentityService 管理OAuth账号的请求身份指纹
...
...
@@ -203,6 +212,94 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
return
json
.
Marshal
(
reqMap
)
}
// RewriteUserIDWithMasking 重写body中的metadata.user_id,支持会话ID伪装
// 如果账号启用了会话ID伪装(session_id_masking_enabled),
// 则在完成常规重写后,将 session 部分替换为固定的伪装ID(15分钟内保持不变)
func
(
s
*
IdentityService
)
RewriteUserIDWithMasking
(
ctx
context
.
Context
,
body
[]
byte
,
account
*
Account
,
accountUUID
,
cachedClientID
string
)
([]
byte
,
error
)
{
// 先执行常规的 RewriteUserID 逻辑
newBody
,
err
:=
s
.
RewriteUserID
(
body
,
account
.
ID
,
accountUUID
,
cachedClientID
)
if
err
!=
nil
{
return
newBody
,
err
}
// 检查是否启用会话ID伪装
if
!
account
.
IsSessionIDMaskingEnabled
()
{
return
newBody
,
nil
}
// 解析重写后的 body,提取 user_id
var
reqMap
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
newBody
,
&
reqMap
);
err
!=
nil
{
return
newBody
,
nil
}
metadata
,
ok
:=
reqMap
[
"metadata"
]
.
(
map
[
string
]
any
)
if
!
ok
{
return
newBody
,
nil
}
userID
,
ok
:=
metadata
[
"user_id"
]
.
(
string
)
if
!
ok
||
userID
==
""
{
return
newBody
,
nil
}
// 查找 _session_ 的位置,替换其后的内容
const
sessionMarker
=
"_session_"
idx
:=
strings
.
LastIndex
(
userID
,
sessionMarker
)
if
idx
==
-
1
{
return
newBody
,
nil
}
// 获取或生成固定的伪装 session ID
maskedSessionID
,
err
:=
s
.
cache
.
GetMaskedSessionID
(
ctx
,
account
.
ID
)
if
err
!=
nil
{
log
.
Printf
(
"Warning: failed to get masked session ID for account %d: %v"
,
account
.
ID
,
err
)
return
newBody
,
nil
}
if
maskedSessionID
==
""
{
// 首次或已过期,生成新的伪装 session ID
maskedSessionID
=
generateRandomUUID
()
log
.
Printf
(
"Generated new masked session ID for account %d: %s"
,
account
.
ID
,
maskedSessionID
)
}
// 刷新 TTL(每次请求都刷新,保持 15 分钟有效期)
if
err
:=
s
.
cache
.
SetMaskedSessionID
(
ctx
,
account
.
ID
,
maskedSessionID
);
err
!=
nil
{
log
.
Printf
(
"Warning: failed to set masked session ID for account %d: %v"
,
account
.
ID
,
err
)
}
// 替换 session 部分:保留 _session_ 之前的内容,替换之后的内容
newUserID
:=
userID
[
:
idx
+
len
(
sessionMarker
)]
+
maskedSessionID
slog
.
Debug
(
"session_id_masking_applied"
,
"account_id"
,
account
.
ID
,
"before"
,
userID
,
"after"
,
newUserID
,
)
metadata
[
"user_id"
]
=
newUserID
reqMap
[
"metadata"
]
=
metadata
return
json
.
Marshal
(
reqMap
)
}
// generateRandomUUID 生成随机 UUID v4 格式字符串
func
generateRandomUUID
()
string
{
b
:=
make
([]
byte
,
16
)
if
_
,
err
:=
rand
.
Read
(
b
);
err
!=
nil
{
// fallback: 使用时间戳生成
h
:=
sha256
.
Sum256
([]
byte
(
fmt
.
Sprintf
(
"%d"
,
time
.
Now
()
.
UnixNano
())))
b
=
h
[
:
16
]
}
// 设置 UUID v4 版本和变体位
b
[
6
]
=
(
b
[
6
]
&
0x0f
)
|
0x40
b
[
8
]
=
(
b
[
8
]
&
0x3f
)
|
0x80
return
fmt
.
Sprintf
(
"%x-%x-%x-%x-%x"
,
b
[
0
:
4
],
b
[
4
:
6
],
b
[
6
:
8
],
b
[
8
:
10
],
b
[
10
:
16
])
}
// generateClientID 生成64位十六进制客户端ID(32字节随机数)
func
generateClientID
()
string
{
b
:=
make
([]
byte
,
32
)
...
...
backend/internal/service/oauth_service.go
View file @
c8e2f614
...
...
@@ -48,8 +48,7 @@ type GenerateAuthURLResult struct {
// GenerateAuthURL generates an OAuth authorization URL with full scope
func
(
s
*
OAuthService
)
GenerateAuthURL
(
ctx
context
.
Context
,
proxyID
*
int64
)
(
*
GenerateAuthURLResult
,
error
)
{
scope
:=
fmt
.
Sprintf
(
"%s %s"
,
oauth
.
ScopeProfile
,
oauth
.
ScopeInference
)
return
s
.
generateAuthURLWithScope
(
ctx
,
scope
,
proxyID
)
return
s
.
generateAuthURLWithScope
(
ctx
,
oauth
.
ScopeOAuth
,
proxyID
)
}
// GenerateSetupTokenURL generates an OAuth authorization URL for setup token (inference only)
...
...
@@ -176,7 +175,8 @@ func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) (
}
// Determine scope and if this is a setup token
scope
:=
fmt
.
Sprintf
(
"%s %s"
,
oauth
.
ScopeProfile
,
oauth
.
ScopeInference
)
// Internal API call uses ScopeAPI (org:create_api_key not supported)
scope
:=
oauth
.
ScopeAPI
isSetupToken
:=
false
if
input
.
Scope
==
"inference"
{
scope
=
oauth
.
ScopeInference
...
...
backend/internal/service/openai_codex_transform.go
View file @
c8e2f614
...
...
@@ -394,19 +394,35 @@ func normalizeCodexTools(reqBody map[string]any) bool {
}
modified
:=
false
for
idx
,
tool
:=
range
tools
{
validTools
:=
make
([]
any
,
0
,
len
(
tools
))
for
_
,
tool
:=
range
tools
{
toolMap
,
ok
:=
tool
.
(
map
[
string
]
any
)
if
!
ok
{
// Keep unknown structure as-is to avoid breaking upstream behavior.
validTools
=
append
(
validTools
,
tool
)
continue
}
toolType
,
_
:=
toolMap
[
"type"
]
.
(
string
)
if
strings
.
TrimSpace
(
toolType
)
!=
"function"
{
toolType
=
strings
.
TrimSpace
(
toolType
)
if
toolType
!=
"function"
{
validTools
=
append
(
validTools
,
toolMap
)
continue
}
function
,
ok
:=
toolMap
[
"function"
]
.
(
map
[
string
]
any
)
if
!
ok
{
// OpenAI Responses-style tools use top-level name/parameters.
if
name
,
ok
:=
toolMap
[
"name"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
name
)
!=
""
{
validTools
=
append
(
validTools
,
toolMap
)
continue
}
// ChatCompletions-style tools use {type:"function", function:{...}}.
functionValue
,
hasFunction
:=
toolMap
[
"function"
]
function
,
ok
:=
functionValue
.
(
map
[
string
]
any
)
if
!
hasFunction
||
functionValue
==
nil
||
!
ok
||
function
==
nil
{
// Drop invalid function tools.
modified
=
true
continue
}
...
...
@@ -435,11 +451,11 @@ func normalizeCodexTools(reqBody map[string]any) bool {
}
}
tools
[
idx
]
=
toolMap
validTools
=
append
(
validTools
,
toolMap
)
}
if
modified
{
reqBody
[
"tools"
]
=
t
ools
reqBody
[
"tools"
]
=
validT
ools
}
return
modified
...
...
backend/internal/service/openai_codex_transform_test.go
View file @
c8e2f614
...
...
@@ -129,6 +129,37 @@ func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) {
require
.
False
(
t
,
hasID
)
}
func
TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunctionTools
(
t
*
testing
.
T
)
{
setupCodexCache
(
t
)
reqBody
:=
map
[
string
]
any
{
"model"
:
"gpt-5.1"
,
"tools"
:
[]
any
{
map
[
string
]
any
{
"type"
:
"function"
,
"name"
:
"bash"
,
"description"
:
"desc"
,
"parameters"
:
map
[
string
]
any
{
"type"
:
"object"
},
},
map
[
string
]
any
{
"type"
:
"function"
,
"function"
:
nil
,
},
},
}
applyCodexOAuthTransform
(
reqBody
)
tools
,
ok
:=
reqBody
[
"tools"
]
.
([]
any
)
require
.
True
(
t
,
ok
)
require
.
Len
(
t
,
tools
,
1
)
first
,
ok
:=
tools
[
0
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"function"
,
first
[
"type"
])
require
.
Equal
(
t
,
"bash"
,
first
[
"name"
])
}
func
TestApplyCodexOAuthTransform_EmptyInput
(
t
*
testing
.
T
)
{
// 空 input 应保持为空且不触发异常。
setupCodexCache
(
t
)
...
...
backend/internal/service/openai_gateway_service.go
View file @
c8e2f614
...
...
@@ -133,12 +133,30 @@ func NewOpenAIGatewayService(
}
}
// GenerateSessionHash generates session hash from header (OpenAI uses session_id header)
func
(
s
*
OpenAIGatewayService
)
GenerateSessionHash
(
c
*
gin
.
Context
)
string
{
sessionID
:=
c
.
GetHeader
(
"session_id"
)
// GenerateSessionHash generates a sticky-session hash for OpenAI requests.
//
// Priority:
// 1. Header: session_id
// 2. Header: conversation_id
// 3. Body: prompt_cache_key (opencode)
func
(
s
*
OpenAIGatewayService
)
GenerateSessionHash
(
c
*
gin
.
Context
,
reqBody
map
[
string
]
any
)
string
{
if
c
==
nil
{
return
""
}
sessionID
:=
strings
.
TrimSpace
(
c
.
GetHeader
(
"session_id"
))
if
sessionID
==
""
{
sessionID
=
strings
.
TrimSpace
(
c
.
GetHeader
(
"conversation_id"
))
}
if
sessionID
==
""
&&
reqBody
!=
nil
{
if
v
,
ok
:=
reqBody
[
"prompt_cache_key"
]
.
(
string
);
ok
{
sessionID
=
strings
.
TrimSpace
(
v
)
}
}
if
sessionID
==
""
{
return
""
}
hash
:=
sha256
.
Sum256
([]
byte
(
sessionID
))
return
hex
.
EncodeToString
(
hash
[
:
])
}
...
...
backend/internal/service/openai_gateway_service_test.go
View file @
c8e2f614
...
...
@@ -68,6 +68,49 @@ func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts
return
out
,
nil
}
func
TestOpenAIGatewayService_GenerateSessionHash_Priority
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/openai/v1/responses"
,
nil
)
svc
:=
&
OpenAIGatewayService
{}
// 1) session_id header wins
c
.
Request
.
Header
.
Set
(
"session_id"
,
"sess-123"
)
c
.
Request
.
Header
.
Set
(
"conversation_id"
,
"conv-456"
)
h1
:=
svc
.
GenerateSessionHash
(
c
,
map
[
string
]
any
{
"prompt_cache_key"
:
"ses_aaa"
})
if
h1
==
""
{
t
.
Fatalf
(
"expected non-empty hash"
)
}
// 2) conversation_id used when session_id absent
c
.
Request
.
Header
.
Del
(
"session_id"
)
h2
:=
svc
.
GenerateSessionHash
(
c
,
map
[
string
]
any
{
"prompt_cache_key"
:
"ses_aaa"
})
if
h2
==
""
{
t
.
Fatalf
(
"expected non-empty hash"
)
}
if
h1
==
h2
{
t
.
Fatalf
(
"expected different hashes for different keys"
)
}
// 3) prompt_cache_key used when both headers absent
c
.
Request
.
Header
.
Del
(
"conversation_id"
)
h3
:=
svc
.
GenerateSessionHash
(
c
,
map
[
string
]
any
{
"prompt_cache_key"
:
"ses_aaa"
})
if
h3
==
""
{
t
.
Fatalf
(
"expected non-empty hash"
)
}
if
h2
==
h3
{
t
.
Fatalf
(
"expected different hashes for different keys"
)
}
// 4) empty when no signals
h4
:=
svc
.
GenerateSessionHash
(
c
,
map
[
string
]
any
{})
if
h4
!=
""
{
t
.
Fatalf
(
"expected empty hash when no signals"
)
}
}
func
TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable
(
t
*
testing
.
T
)
{
now
:=
time
.
Now
()
resetAt
:=
now
.
Add
(
10
*
time
.
Minute
)
...
...
backend/internal/service/openai_tool_corrector.go
View file @
c8e2f614
...
...
@@ -27,6 +27,11 @@ var codexToolNameMapping = map[string]string{
"executeBash"
:
"bash"
,
"exec_bash"
:
"bash"
,
"execBash"
:
"bash"
,
// Some clients output generic fetch names.
"fetch"
:
"webfetch"
,
"web_fetch"
:
"webfetch"
,
"webFetch"
:
"webfetch"
,
}
// ToolCorrectionStats 记录工具修正的统计信息(导出用于 JSON 序列化)
...
...
@@ -208,27 +213,67 @@ func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall
// 根据工具名称应用特定的参数修正规则
switch
toolName
{
case
"bash"
:
// 移除 workdir 参数(OpenCode 不支持)
if
_
,
exists
:=
argsMap
[
"workdir"
];
exists
{
delete
(
argsMap
,
"workdir"
)
// OpenCode bash 支持 workdir;有些来源会输出 work_dir。
if
_
,
hasWorkdir
:=
argsMap
[
"workdir"
];
!
hasWorkdir
{
if
workDir
,
exists
:=
argsMap
[
"work_dir"
];
exists
{
argsMap
[
"workdir"
]
=
workDir
delete
(
argsMap
,
"work_dir"
)
corrected
=
true
log
.
Printf
(
"[CodexToolCorrector] Re
mov
ed 'workdir'
parameter from
bash tool"
)
log
.
Printf
(
"[CodexToolCorrector] Re
nam
ed 'work
_
dir'
to 'workdir' in
bash tool"
)
}
}
else
{
if
_
,
exists
:=
argsMap
[
"work_dir"
];
exists
{
delete
(
argsMap
,
"work_dir"
)
corrected
=
true
log
.
Printf
(
"[CodexToolCorrector] Removed 'work_dir' parameter from bash tool"
)
log
.
Printf
(
"[CodexToolCorrector] Removed duplicate 'work_dir' parameter from bash tool"
)
}
}
case
"edit"
:
// OpenCode edit 使用 old_string/new_string,Codex 可能使用其他名称
// 这里可以添加参数名称的映射逻辑
if
_
,
exists
:=
argsMap
[
"file_path"
];
!
exists
{
if
path
,
exists
:=
argsMap
[
"path"
];
exists
{
argsMap
[
"file_path"
]
=
path
// OpenCode edit 参数为 filePath/oldString/newString(camelCase)。
if
_
,
exists
:=
argsMap
[
"filePath"
];
!
exists
{
if
filePath
,
exists
:=
argsMap
[
"file_path"
];
exists
{
argsMap
[
"filePath"
]
=
filePath
delete
(
argsMap
,
"file_path"
)
corrected
=
true
log
.
Printf
(
"[CodexToolCorrector] Renamed 'file_path' to 'filePath' in edit tool"
)
}
else
if
filePath
,
exists
:=
argsMap
[
"path"
];
exists
{
argsMap
[
"filePath"
]
=
filePath
delete
(
argsMap
,
"path"
)
corrected
=
true
log
.
Printf
(
"[CodexToolCorrector] Renamed 'path' to 'file_path' in edit tool"
)
log
.
Printf
(
"[CodexToolCorrector] Renamed 'path' to 'filePath' in edit tool"
)
}
else
if
filePath
,
exists
:=
argsMap
[
"file"
];
exists
{
argsMap
[
"filePath"
]
=
filePath
delete
(
argsMap
,
"file"
)
corrected
=
true
log
.
Printf
(
"[CodexToolCorrector] Renamed 'file' to 'filePath' in edit tool"
)
}
}
if
_
,
exists
:=
argsMap
[
"oldString"
];
!
exists
{
if
oldString
,
exists
:=
argsMap
[
"old_string"
];
exists
{
argsMap
[
"oldString"
]
=
oldString
delete
(
argsMap
,
"old_string"
)
corrected
=
true
log
.
Printf
(
"[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool"
)
}
}
if
_
,
exists
:=
argsMap
[
"newString"
];
!
exists
{
if
newString
,
exists
:=
argsMap
[
"new_string"
];
exists
{
argsMap
[
"newString"
]
=
newString
delete
(
argsMap
,
"new_string"
)
corrected
=
true
log
.
Printf
(
"[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit tool"
)
}
}
if
_
,
exists
:=
argsMap
[
"replaceAll"
];
!
exists
{
if
replaceAll
,
exists
:=
argsMap
[
"replace_all"
];
exists
{
argsMap
[
"replaceAll"
]
=
replaceAll
delete
(
argsMap
,
"replace_all"
)
corrected
=
true
log
.
Printf
(
"[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool"
)
}
}
}
...
...
backend/internal/service/openai_tool_corrector_test.go
View file @
c8e2f614
...
...
@@ -416,22 +416,23 @@ func TestCorrectToolParameters(t *testing.T) {
expected
map
[
string
]
bool
// key: 期待存在的参数, value: true表示应该存在
}{
{
name
:
"re
mov
e workdir
from
bash tool"
,
name
:
"re
nam
e work
_
dir
to workdir in
bash tool"
,
input
:
`{
"tool_calls": [{
"function": {
"name": "bash",
"arguments": "{\"command\":\"ls\",\"workdir\":\"/tmp\"}"
"arguments": "{\"command\":\"ls\",\"work
_
dir\":\"/tmp\"}"
}
}]
}`
,
expected
:
map
[
string
]
bool
{
"command"
:
true
,
"workdir"
:
false
,
"workdir"
:
true
,
"work_dir"
:
false
,
},
},
{
name
:
"rename
path to file_path in edit tool
"
,
name
:
"rename
snake_case edit params to camelCase
"
,
input
:
`{
"tool_calls": [{
"function": {
...
...
@@ -441,10 +442,12 @@ func TestCorrectToolParameters(t *testing.T) {
}]
}`
,
expected
:
map
[
string
]
bool
{
"file
_p
ath"
:
true
,
"file
P
ath"
:
true
,
"path"
:
false
,
"old_string"
:
true
,
"new_string"
:
true
,
"oldString"
:
true
,
"old_string"
:
false
,
"newString"
:
true
,
"new_string"
:
false
,
},
},
}
...
...
backend/internal/service/pricing_service.go
View file @
c8e2f614
...
...
@@ -531,8 +531,8 @@ func (s *PricingService) buildModelLookupCandidates(modelLower string) []string
func
normalizeModelNameForPricing
(
model
string
)
string
{
// Common Gemini/VertexAI forms:
// - models/gemini-2.0-flash-exp
// - publishers/google/models/gemini-
1
.5-pro
// - projects/.../locations/.../publishers/google/models/gemini-
1
.5-pro
// - publishers/google/models/gemini-
2
.5-pro
// - projects/.../locations/.../publishers/google/models/gemini-
2
.5-pro
model
=
strings
.
TrimSpace
(
model
)
model
=
strings
.
TrimLeft
(
model
,
"/"
)
model
=
strings
.
TrimPrefix
(
model
,
"models/"
)
...
...
backend/internal/service/ratelimit_service.go
View file @
c8e2f614
...
...
@@ -73,10 +73,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
return
false
}
tempMatched
:=
false
// 先尝试临时不可调度规则(401除外)
// 如果匹配成功,直接返回,不执行后续禁用逻辑
if
statusCode
!=
401
{
tempMatched
=
s
.
tryTempUnschedulable
(
ctx
,
account
,
statusCode
,
responseBody
)
if
s
.
tryTempUnschedulable
(
ctx
,
account
,
statusCode
,
responseBody
)
{
return
true
}
}
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
responseBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
if
upstreamMsg
!=
""
{
...
...
@@ -84,6 +88,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
}
switch
statusCode
{
case
400
:
// 只有当错误信息包含 "organization has been disabled" 时才禁用
if
strings
.
Contains
(
strings
.
ToLower
(
upstreamMsg
),
"organization has been disabled"
)
{
msg
:=
"Organization disabled (400): "
+
upstreamMsg
s
.
handleAuthError
(
ctx
,
account
,
msg
)
shouldDisable
=
true
}
// 其他 400 错误(如参数问题)不处理,不禁用账号
case
401
:
// 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新
if
account
.
Type
==
AccountTypeOAuth
{
...
...
@@ -148,9 +160,6 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
}
}
if
tempMatched
{
return
true
}
return
shouldDisable
}
...
...
@@ -190,7 +199,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
start
:=
geminiDailyWindowStart
(
now
)
totals
,
ok
:=
s
.
getGeminiUsageTotals
(
account
.
ID
,
start
,
now
)
if
!
ok
{
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
start
,
now
,
0
,
0
,
account
.
ID
,
0
,
nil
)
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
start
,
now
,
0
,
0
,
account
.
ID
,
0
,
nil
,
nil
)
if
err
!=
nil
{
return
true
,
err
}
...
...
@@ -237,7 +246,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
if
limit
>
0
{
start
:=
now
.
Truncate
(
time
.
Minute
)
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
start
,
now
,
0
,
0
,
account
.
ID
,
0
,
nil
)
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
start
,
now
,
0
,
0
,
account
.
ID
,
0
,
nil
,
nil
)
if
err
!=
nil
{
return
true
,
err
}
...
...
backend/internal/service/session_limit_cache.go
View file @
c8e2f614
...
...
@@ -38,8 +38,9 @@ type SessionLimitCache interface {
GetActiveSessionCount
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
// idleTimeouts: 每个账号的空闲超时时间配置,key 为 accountID;若为 nil 或某账号不在其中,则使用默认超时
// 返回 map[accountID]count,查询失败的账号不在 map 中
GetActiveSessionCountBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
int
,
error
)
GetActiveSessionCountBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
,
idleTimeouts
map
[
int64
]
time
.
Duration
)
(
map
[
int64
]
int
,
error
)
// IsSessionActive 检查特定会话是否活跃(未过期)
IsSessionActive
(
ctx
context
.
Context
,
accountID
int64
,
sessionUUID
string
)
(
bool
,
error
)
...
...
backend/internal/service/setting_service.go
View file @
c8e2f614
...
...
@@ -69,6 +69,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyContactInfo
,
SettingKeyDocURL
,
SettingKeyHomeContent
,
SettingKeyHideCcsImportButton
,
SettingKeyLinuxDoConnectEnabled
,
}
...
...
@@ -96,6 +97,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
ContactInfo
:
settings
[
SettingKeyContactInfo
],
DocURL
:
settings
[
SettingKeyDocURL
],
HomeContent
:
settings
[
SettingKeyHomeContent
],
HideCcsImportButton
:
settings
[
SettingKeyHideCcsImportButton
]
==
"true"
,
LinuxDoOAuthEnabled
:
linuxDoEnabled
,
},
nil
}
...
...
@@ -132,6 +134,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
ContactInfo
string
`json:"contact_info,omitempty"`
DocURL
string
`json:"doc_url,omitempty"`
HomeContent
string
`json:"home_content,omitempty"`
HideCcsImportButton
bool
`json:"hide_ccs_import_button"`
LinuxDoOAuthEnabled
bool
`json:"linuxdo_oauth_enabled"`
Version
string
`json:"version,omitempty"`
}{
...
...
@@ -146,6 +149,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
ContactInfo
:
settings
.
ContactInfo
,
DocURL
:
settings
.
DocURL
,
HomeContent
:
settings
.
HomeContent
,
HideCcsImportButton
:
settings
.
HideCcsImportButton
,
LinuxDoOAuthEnabled
:
settings
.
LinuxDoOAuthEnabled
,
Version
:
s
.
version
,
},
nil
...
...
@@ -193,6 +197,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates
[
SettingKeyContactInfo
]
=
settings
.
ContactInfo
updates
[
SettingKeyDocURL
]
=
settings
.
DocURL
updates
[
SettingKeyHomeContent
]
=
settings
.
HomeContent
updates
[
SettingKeyHideCcsImportButton
]
=
strconv
.
FormatBool
(
settings
.
HideCcsImportButton
)
// 默认配置
updates
[
SettingKeyDefaultConcurrency
]
=
strconv
.
Itoa
(
settings
.
DefaultConcurrency
)
...
...
@@ -339,6 +344,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
ContactInfo
:
settings
[
SettingKeyContactInfo
],
DocURL
:
settings
[
SettingKeyDocURL
],
HomeContent
:
settings
[
SettingKeyHomeContent
],
HideCcsImportButton
:
settings
[
SettingKeyHideCcsImportButton
]
==
"true"
,
}
// 解析整数类型
...
...
backend/internal/service/settings_view.go
View file @
c8e2f614
...
...
@@ -32,6 +32,7 @@ type SystemSettings struct {
ContactInfo
string
DocURL
string
HomeContent
string
HideCcsImportButton
bool
DefaultConcurrency
int
DefaultBalance
float64
...
...
@@ -66,6 +67,7 @@ type PublicSettings struct {
ContactInfo
string
DocURL
string
HomeContent
string
HideCcsImportButton
bool
LinuxDoOAuthEnabled
bool
Version
string
}
...
...
backend/internal/service/subscription_service.go
View file @
c8e2f614
...
...
@@ -27,6 +27,7 @@ var (
ErrWeeklyLimitExceeded
=
infraerrors
.
TooManyRequests
(
"WEEKLY_LIMIT_EXCEEDED"
,
"weekly usage limit exceeded"
)
ErrMonthlyLimitExceeded
=
infraerrors
.
TooManyRequests
(
"MONTHLY_LIMIT_EXCEEDED"
,
"monthly usage limit exceeded"
)
ErrSubscriptionNilInput
=
infraerrors
.
BadRequest
(
"SUBSCRIPTION_NIL_INPUT"
,
"subscription input cannot be nil"
)
ErrAdjustWouldExpire
=
infraerrors
.
BadRequest
(
"ADJUST_WOULD_EXPIRE"
,
"adjustment would result in expired subscription (remaining days must be > 0)"
)
)
// SubscriptionService 订阅服务
...
...
@@ -308,17 +309,20 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
return
nil
}
// ExtendSubscription
延长订阅
// ExtendSubscription
调整订阅时长(正数延长,负数缩短)
func
(
s
*
SubscriptionService
)
ExtendSubscription
(
ctx
context
.
Context
,
subscriptionID
int64
,
days
int
)
(
*
UserSubscription
,
error
)
{
sub
,
err
:=
s
.
userSubRepo
.
GetByID
(
ctx
,
subscriptionID
)
if
err
!=
nil
{
return
nil
,
ErrSubscriptionNotFound
}
// 限制
延长天数
// 限制
调整天数范围
if
days
>
MaxValidityDays
{
days
=
MaxValidityDays
}
if
days
<
-
MaxValidityDays
{
days
=
-
MaxValidityDays
}
// 计算新的过期时间
newExpiresAt
:=
sub
.
ExpiresAt
.
AddDate
(
0
,
0
,
days
)
...
...
@@ -326,6 +330,14 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
newExpiresAt
=
MaxExpiresAt
}
// 如果是缩短(负数),检查新的过期时间必须大于当前时间
if
days
<
0
{
now
:=
time
.
Now
()
if
!
newExpiresAt
.
After
(
now
)
{
return
nil
,
ErrAdjustWouldExpire
}
}
if
err
:=
s
.
userSubRepo
.
ExtendExpiry
(
ctx
,
subscriptionID
,
newExpiresAt
);
err
!=
nil
{
return
nil
,
err
}
...
...
Prev
1
2
3
4
5
6
7
8
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment