Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
dd7f2124
Commit
dd7f2124
authored
Jan 19, 2026
by
cyhhao
Browse files
merge: resolve conflicts with main
parents
49be9d08
bba5b3c0
Changes
42
Expand all
Show whitespace changes
Inline
Side-by-side
backend/cmd/server/wire_gen.go
View file @
dd7f2124
...
@@ -118,7 +118,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
...
@@ -118,7 +118,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
concurrencyCache
:=
repository
.
ProvideConcurrencyCache
(
redisClient
,
configConfig
)
concurrencyCache
:=
repository
.
ProvideConcurrencyCache
(
redisClient
,
configConfig
)
concurrencyService
:=
service
.
ProvideConcurrencyService
(
concurrencyCache
,
accountRepository
,
configConfig
)
concurrencyService
:=
service
.
ProvideConcurrencyService
(
concurrencyCache
,
accountRepository
,
configConfig
)
crsSyncService
:=
service
.
NewCRSSyncService
(
accountRepository
,
proxyRepository
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
configConfig
)
crsSyncService
:=
service
.
NewCRSSyncService
(
accountRepository
,
proxyRepository
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
configConfig
)
accountHandler
:=
admin
.
NewAccountHandler
(
adminService
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
antigravityOAuthService
,
rateLimitService
,
accountUsageService
,
accountTestService
,
concurrencyService
,
crsSyncService
)
sessionLimitCache
:=
repository
.
ProvideSessionLimitCache
(
redisClient
,
configConfig
)
accountHandler
:=
admin
.
NewAccountHandler
(
adminService
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
antigravityOAuthService
,
rateLimitService
,
accountUsageService
,
accountTestService
,
concurrencyService
,
crsSyncService
,
sessionLimitCache
)
oAuthHandler
:=
admin
.
NewOAuthHandler
(
oAuthService
)
oAuthHandler
:=
admin
.
NewOAuthHandler
(
oAuthService
)
openAIOAuthHandler
:=
admin
.
NewOpenAIOAuthHandler
(
openAIOAuthService
,
adminService
)
openAIOAuthHandler
:=
admin
.
NewOpenAIOAuthHandler
(
openAIOAuthService
,
adminService
)
geminiOAuthHandler
:=
admin
.
NewGeminiOAuthHandler
(
geminiOAuthService
)
geminiOAuthHandler
:=
admin
.
NewGeminiOAuthHandler
(
geminiOAuthService
)
...
@@ -140,7 +141,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
...
@@ -140,7 +141,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
identityService
:=
service
.
NewIdentityService
(
identityCache
)
identityService
:=
service
.
NewIdentityService
(
identityCache
)
deferredService
:=
service
.
ProvideDeferredService
(
accountRepository
,
timingWheelService
)
deferredService
:=
service
.
ProvideDeferredService
(
accountRepository
,
timingWheelService
)
claudeTokenProvider
:=
service
.
NewClaudeTokenProvider
(
accountRepository
,
geminiTokenCache
,
oAuthService
)
claudeTokenProvider
:=
service
.
NewClaudeTokenProvider
(
accountRepository
,
geminiTokenCache
,
oAuthService
)
gatewayService
:=
service
.
NewGatewayService
(
accountRepository
,
groupRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
gatewayCache
,
configConfig
,
schedulerSnapshotService
,
concurrencyService
,
billingService
,
rateLimitService
,
billingCacheService
,
identityService
,
httpUpstream
,
deferredService
,
claudeTokenProvider
)
gatewayService
:=
service
.
NewGatewayService
(
accountRepository
,
groupRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
gatewayCache
,
configConfig
,
schedulerSnapshotService
,
concurrencyService
,
billingService
,
rateLimitService
,
billingCacheService
,
identityService
,
httpUpstream
,
deferredService
,
claudeTokenProvider
,
sessionLimitCache
)
openAITokenProvider
:=
service
.
NewOpenAITokenProvider
(
accountRepository
,
geminiTokenCache
,
openAIOAuthService
)
openAITokenProvider
:=
service
.
NewOpenAITokenProvider
(
accountRepository
,
geminiTokenCache
,
openAIOAuthService
)
openAIGatewayService
:=
service
.
NewOpenAIGatewayService
(
accountRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
gatewayCache
,
configConfig
,
schedulerSnapshotService
,
concurrencyService
,
billingService
,
rateLimitService
,
billingCacheService
,
httpUpstream
,
deferredService
,
openAITokenProvider
)
openAIGatewayService
:=
service
.
NewOpenAIGatewayService
(
accountRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
gatewayCache
,
configConfig
,
schedulerSnapshotService
,
concurrencyService
,
billingService
,
rateLimitService
,
billingCacheService
,
httpUpstream
,
deferredService
,
openAITokenProvider
)
geminiMessagesCompatService
:=
service
.
NewGeminiMessagesCompatService
(
accountRepository
,
groupRepository
,
gatewayCache
,
schedulerSnapshotService
,
geminiTokenProvider
,
rateLimitService
,
httpUpstream
,
antigravityGatewayService
,
configConfig
)
geminiMessagesCompatService
:=
service
.
NewGeminiMessagesCompatService
(
accountRepository
,
groupRepository
,
gatewayCache
,
schedulerSnapshotService
,
geminiTokenProvider
,
rateLimitService
,
httpUpstream
,
antigravityGatewayService
,
configConfig
)
...
...
backend/internal/config/config.go
View file @
dd7f2124
...
@@ -234,6 +234,10 @@ type GatewayConfig struct {
...
@@ -234,6 +234,10 @@ type GatewayConfig struct {
// ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟)
// ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟)
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期
ConcurrencySlotTTLMinutes
int
`mapstructure:"concurrency_slot_ttl_minutes"`
ConcurrencySlotTTLMinutes
int
`mapstructure:"concurrency_slot_ttl_minutes"`
// SessionIdleTimeoutMinutes: 会话空闲超时时间(分钟),默认 5 分钟
// 用于 Anthropic OAuth/SetupToken 账号的会话数量限制功能
// 空闲超过此时间的会话将被自动释放
SessionIdleTimeoutMinutes
int
`mapstructure:"session_idle_timeout_minutes"`
// StreamDataIntervalTimeout: 流数据间隔超时(秒),0表示禁用
// StreamDataIntervalTimeout: 流数据间隔超时(秒),0表示禁用
StreamDataIntervalTimeout
int
`mapstructure:"stream_data_interval_timeout"`
StreamDataIntervalTimeout
int
`mapstructure:"stream_data_interval_timeout"`
...
...
backend/internal/handler/admin/account_handler.go
View file @
dd7f2124
...
@@ -44,6 +44,7 @@ type AccountHandler struct {
...
@@ -44,6 +44,7 @@ type AccountHandler struct {
accountTestService
*
service
.
AccountTestService
accountTestService
*
service
.
AccountTestService
concurrencyService
*
service
.
ConcurrencyService
concurrencyService
*
service
.
ConcurrencyService
crsSyncService
*
service
.
CRSSyncService
crsSyncService
*
service
.
CRSSyncService
sessionLimitCache
service
.
SessionLimitCache
}
}
// NewAccountHandler creates a new admin account handler
// NewAccountHandler creates a new admin account handler
...
@@ -58,6 +59,7 @@ func NewAccountHandler(
...
@@ -58,6 +59,7 @@ func NewAccountHandler(
accountTestService
*
service
.
AccountTestService
,
accountTestService
*
service
.
AccountTestService
,
concurrencyService
*
service
.
ConcurrencyService
,
concurrencyService
*
service
.
ConcurrencyService
,
crsSyncService
*
service
.
CRSSyncService
,
crsSyncService
*
service
.
CRSSyncService
,
sessionLimitCache
service
.
SessionLimitCache
,
)
*
AccountHandler
{
)
*
AccountHandler
{
return
&
AccountHandler
{
return
&
AccountHandler
{
adminService
:
adminService
,
adminService
:
adminService
,
...
@@ -70,6 +72,7 @@ func NewAccountHandler(
...
@@ -70,6 +72,7 @@ func NewAccountHandler(
accountTestService
:
accountTestService
,
accountTestService
:
accountTestService
,
concurrencyService
:
concurrencyService
,
concurrencyService
:
concurrencyService
,
crsSyncService
:
crsSyncService
,
crsSyncService
:
crsSyncService
,
sessionLimitCache
:
sessionLimitCache
,
}
}
}
}
...
@@ -130,6 +133,9 @@ type BulkUpdateAccountsRequest struct {
...
@@ -130,6 +133,9 @@ type BulkUpdateAccountsRequest struct {
type
AccountWithConcurrency
struct
{
type
AccountWithConcurrency
struct
{
*
dto
.
Account
*
dto
.
Account
CurrentConcurrency
int
`json:"current_concurrency"`
CurrentConcurrency
int
`json:"current_concurrency"`
// 以下字段仅对 Anthropic OAuth/SetupToken 账号有效,且仅在启用相应功能时返回
CurrentWindowCost
*
float64
`json:"current_window_cost,omitempty"`
// 当前窗口费用
ActiveSessions
*
int
`json:"active_sessions,omitempty"`
// 当前活跃会话数
}
}
// List handles listing all accounts with pagination
// List handles listing all accounts with pagination
...
@@ -164,13 +170,89 @@ func (h *AccountHandler) List(c *gin.Context) {
...
@@ -164,13 +170,89 @@ func (h *AccountHandler) List(c *gin.Context) {
concurrencyCounts
=
make
(
map
[
int64
]
int
)
concurrencyCounts
=
make
(
map
[
int64
]
int
)
}
}
// 识别需要查询窗口费用和会话数的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
windowCostAccountIDs
:=
make
([]
int64
,
0
)
sessionLimitAccountIDs
:=
make
([]
int64
,
0
)
for
i
:=
range
accounts
{
acc
:=
&
accounts
[
i
]
if
acc
.
IsAnthropicOAuthOrSetupToken
()
{
if
acc
.
GetWindowCostLimit
()
>
0
{
windowCostAccountIDs
=
append
(
windowCostAccountIDs
,
acc
.
ID
)
}
if
acc
.
GetMaxSessions
()
>
0
{
sessionLimitAccountIDs
=
append
(
sessionLimitAccountIDs
,
acc
.
ID
)
}
}
}
// 并行获取窗口费用和活跃会话数
var
windowCosts
map
[
int64
]
float64
var
activeSessions
map
[
int64
]
int
// 获取活跃会话数(批量查询)
if
len
(
sessionLimitAccountIDs
)
>
0
&&
h
.
sessionLimitCache
!=
nil
{
activeSessions
,
_
=
h
.
sessionLimitCache
.
GetActiveSessionCountBatch
(
c
.
Request
.
Context
(),
sessionLimitAccountIDs
)
if
activeSessions
==
nil
{
activeSessions
=
make
(
map
[
int64
]
int
)
}
}
// 获取窗口费用(并行查询)
if
len
(
windowCostAccountIDs
)
>
0
{
windowCosts
=
make
(
map
[
int64
]
float64
)
var
mu
sync
.
Mutex
g
,
gctx
:=
errgroup
.
WithContext
(
c
.
Request
.
Context
())
g
.
SetLimit
(
10
)
// 限制并发数
for
i
:=
range
accounts
{
acc
:=
&
accounts
[
i
]
if
!
acc
.
IsAnthropicOAuthOrSetupToken
()
||
acc
.
GetWindowCostLimit
()
<=
0
{
continue
}
accCopy
:=
acc
// 闭包捕获
g
.
Go
(
func
()
error
{
var
startTime
time
.
Time
if
accCopy
.
SessionWindowStart
!=
nil
{
startTime
=
*
accCopy
.
SessionWindowStart
}
else
{
startTime
=
time
.
Now
()
.
Add
(
-
5
*
time
.
Hour
)
}
stats
,
err
:=
h
.
accountUsageService
.
GetAccountWindowStats
(
gctx
,
accCopy
.
ID
,
startTime
)
if
err
==
nil
&&
stats
!=
nil
{
mu
.
Lock
()
windowCosts
[
accCopy
.
ID
]
=
stats
.
StandardCost
// 使用标准费用
mu
.
Unlock
()
}
return
nil
// 不返回错误,允许部分失败
})
}
_
=
g
.
Wait
()
}
// Build response with concurrency info
// Build response with concurrency info
result
:=
make
([]
AccountWithConcurrency
,
len
(
accounts
))
result
:=
make
([]
AccountWithConcurrency
,
len
(
accounts
))
for
i
:=
range
accounts
{
for
i
:=
range
accounts
{
result
[
i
]
=
AccountWithConcurrency
{
acc
:=
&
accounts
[
i
]
Account
:
dto
.
AccountFromService
(
&
accounts
[
i
]),
item
:=
AccountWithConcurrency
{
CurrentConcurrency
:
concurrencyCounts
[
accounts
[
i
]
.
ID
],
Account
:
dto
.
AccountFromService
(
acc
),
CurrentConcurrency
:
concurrencyCounts
[
acc
.
ID
],
}
// 添加窗口费用(仅当启用时)
if
windowCosts
!=
nil
{
if
cost
,
ok
:=
windowCosts
[
acc
.
ID
];
ok
{
item
.
CurrentWindowCost
=
&
cost
}
}
}
// 添加活跃会话数(仅当启用时)
if
activeSessions
!=
nil
{
if
count
,
ok
:=
activeSessions
[
acc
.
ID
];
ok
{
item
.
ActiveSessions
=
&
count
}
}
result
[
i
]
=
item
}
}
response
.
Paginated
(
c
,
result
,
total
,
page
,
pageSize
)
response
.
Paginated
(
c
,
result
,
total
,
page
,
pageSize
)
...
...
backend/internal/handler/dto/mappers.go
View file @
dd7f2124
...
@@ -116,7 +116,7 @@ func AccountFromServiceShallow(a *service.Account) *Account {
...
@@ -116,7 +116,7 @@ func AccountFromServiceShallow(a *service.Account) *Account {
if
a
==
nil
{
if
a
==
nil
{
return
nil
return
nil
}
}
return
&
Account
{
out
:=
&
Account
{
ID
:
a
.
ID
,
ID
:
a
.
ID
,
Name
:
a
.
Name
,
Name
:
a
.
Name
,
Notes
:
a
.
Notes
,
Notes
:
a
.
Notes
,
...
@@ -146,6 +146,24 @@ func AccountFromServiceShallow(a *service.Account) *Account {
...
@@ -146,6 +146,24 @@ func AccountFromServiceShallow(a *service.Account) *Account {
SessionWindowStatus
:
a
.
SessionWindowStatus
,
SessionWindowStatus
:
a
.
SessionWindowStatus
,
GroupIDs
:
a
.
GroupIDs
,
GroupIDs
:
a
.
GroupIDs
,
}
}
// 提取 5h 窗口费用控制和会话数量控制配置(仅 Anthropic OAuth/SetupToken 账号有效)
if
a
.
IsAnthropicOAuthOrSetupToken
()
{
if
limit
:=
a
.
GetWindowCostLimit
();
limit
>
0
{
out
.
WindowCostLimit
=
&
limit
}
if
reserve
:=
a
.
GetWindowCostStickyReserve
();
reserve
>
0
{
out
.
WindowCostStickyReserve
=
&
reserve
}
if
maxSessions
:=
a
.
GetMaxSessions
();
maxSessions
>
0
{
out
.
MaxSessions
=
&
maxSessions
}
if
idleTimeout
:=
a
.
GetSessionIdleTimeoutMinutes
();
idleTimeout
>
0
{
out
.
SessionIdleTimeoutMin
=
&
idleTimeout
}
}
return
out
}
}
func
AccountFromService
(
a
*
service
.
Account
)
*
Account
{
func
AccountFromService
(
a
*
service
.
Account
)
*
Account
{
...
...
backend/internal/handler/dto/types.go
View file @
dd7f2124
...
@@ -102,6 +102,16 @@ type Account struct {
...
@@ -102,6 +102,16 @@ type Account struct {
SessionWindowEnd
*
time
.
Time
`json:"session_window_end"`
SessionWindowEnd
*
time
.
Time
`json:"session_window_end"`
SessionWindowStatus
string
`json:"session_window_status"`
SessionWindowStatus
string
`json:"session_window_status"`
// 5h窗口费用控制(仅 Anthropic OAuth/SetupToken 账号有效)
// 从 extra 字段提取,方便前端显示和编辑
WindowCostLimit
*
float64
`json:"window_cost_limit,omitempty"`
WindowCostStickyReserve
*
float64
`json:"window_cost_sticky_reserve,omitempty"`
// 会话数量控制(仅 Anthropic OAuth/SetupToken 账号有效)
// 从 extra 字段提取,方便前端显示和编辑
MaxSessions
*
int
`json:"max_sessions,omitempty"`
SessionIdleTimeoutMin
*
int
`json:"session_idle_timeout_minutes,omitempty"`
Proxy
*
Proxy
`json:"proxy,omitempty"`
Proxy
*
Proxy
`json:"proxy,omitempty"`
AccountGroups
[]
AccountGroup
`json:"account_groups,omitempty"`
AccountGroups
[]
AccountGroup
`json:"account_groups,omitempty"`
...
...
backend/internal/handler/gateway_handler.go
View file @
dd7f2124
...
@@ -185,7 +185,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -185,7 +185,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
lastFailoverStatus
:=
0
lastFailoverStatus
:=
0
for
{
for
{
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
reqModel
,
failedAccountIDs
)
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
reqModel
,
failedAccountIDs
,
""
)
// Gemini 不使用会话限制
if
err
!=
nil
{
if
err
!=
nil
{
if
len
(
failedAccountIDs
)
==
0
{
if
len
(
failedAccountIDs
)
==
0
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
...
@@ -320,7 +320,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -320,7 +320,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for
{
for
{
// 选择支持该模型的账号
// 选择支持该模型的账号
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
reqModel
,
failedAccountIDs
)
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
reqModel
,
failedAccountIDs
,
parsedReq
.
MetadataUserID
)
if
err
!=
nil
{
if
err
!=
nil
{
if
len
(
failedAccountIDs
)
==
0
{
if
len
(
failedAccountIDs
)
==
0
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
...
...
backend/internal/handler/gemini_v1beta_handler.go
View file @
dd7f2124
...
@@ -226,7 +226,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
...
@@ -226,7 +226,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
lastFailoverStatus
:=
0
lastFailoverStatus
:=
0
for
{
for
{
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
modelName
,
failedAccountIDs
)
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
modelName
,
failedAccountIDs
,
""
)
// Gemini 不使用会话限制
if
err
!=
nil
{
if
err
!=
nil
{
if
len
(
failedAccountIDs
)
==
0
{
if
len
(
failedAccountIDs
)
==
0
{
googleError
(
c
,
http
.
StatusServiceUnavailable
,
"No available Gemini accounts: "
+
err
.
Error
())
googleError
(
c
,
http
.
StatusServiceUnavailable
,
"No available Gemini accounts: "
+
err
.
Error
())
...
...
backend/internal/repository/session_limit_cache.go
0 → 100644
View file @
dd7f2124
package
repository
import
(
"context"
"fmt"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// Session-limit cache constants.
//
// Design notes:
// A Redis sorted set tracks the active sessions of each account:
//   - Key:    session_limit:account:{accountID}
//   - Member: sessionUUID (extracted from metadata.user_id)
//   - Score:  Unix timestamp of the session's last activity
//
// Expired sessions are pruned with ZREMRANGEBYSCORE, so no per-member TTL
// has to be managed by hand.
const (
	// sessionLimitKeyPrefix is the key prefix of the per-account session set.
	// Format: session_limit:account:{accountID}
	sessionLimitKeyPrefix = "session_limit:account:"

	// windowCostKeyPrefix is the key prefix of the cached window cost.
	// Format: window_cost:account:{accountID}
	windowCostKeyPrefix = "window_cost:account:"

	// windowCostCacheTTL is the TTL of a cached window cost (30 seconds).
	windowCostCacheTTL = 30 * time.Second
)
var (
	// registerSessionScript registers session activity for an account.
	// It reads the clock with the Redis TIME command so that multiple
	// application instances with unsynchronized clocks agree on "now".
	// KEYS[1] = session_limit:account:{accountID}
	// ARGV[1] = maxSessions
	// ARGV[2] = idleTimeout (seconds)
	// ARGV[3] = sessionUUID
	// Returns: 1 = allowed, 0 = rejected.
	// An already-registered session is always allowed and has its timestamp
	// refreshed; a new session is only added while the set holds fewer than
	// maxSessions members. The key expiry is reset to idleTimeout + 60 seconds
	// on every successful registration.
	registerSessionScript = redis.NewScript(`
local key = KEYS[1]
local maxSessions = tonumber(ARGV[1])
local idleTimeout = tonumber(ARGV[2])
local sessionUUID = ARGV[3]
-- 使用 Redis 服务器时间,确保多实例时钟一致
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - idleTimeout
-- 清理过期会话
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
-- 检查会话是否已存在(支持刷新时间戳)
local exists = redis.call('ZSCORE', key, sessionUUID)
if exists ~= false then
-- 会话已存在,刷新时间戳
redis.call('ZADD', key, now, sessionUUID)
redis.call('EXPIRE', key, idleTimeout + 60)
return 1
end
-- 检查是否达到会话数量上限
local count = redis.call('ZCARD', key)
if count < maxSessions then
-- 未达上限,添加新会话
redis.call('ZADD', key, now, sessionUUID)
redis.call('EXPIRE', key, idleTimeout + 60)
return 1
end
-- 达到上限,拒绝新会话
return 0
`)

	// refreshSessionScript refreshes the timestamp of an existing session.
	// A session that is not in the set is left untouched (it is not created).
	// The script always returns 1.
	// KEYS[1] = session_limit:account:{accountID}
	// ARGV[1] = idleTimeout (seconds)
	// ARGV[2] = sessionUUID
	refreshSessionScript = redis.NewScript(`
local key = KEYS[1]
local idleTimeout = tonumber(ARGV[1])
local sessionUUID = ARGV[2]
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
-- 检查会话是否存在
local exists = redis.call('ZSCORE', key, sessionUUID)
if exists ~= false then
redis.call('ZADD', key, now, sessionUUID)
redis.call('EXPIRE', key, idleTimeout + 60)
end
return 1
`)

	// getActiveSessionCountScript prunes sessions idle for longer than
	// idleTimeout and returns the number of remaining members (ZCARD).
	// KEYS[1] = session_limit:account:{accountID}
	// ARGV[1] = idleTimeout (seconds)
	getActiveSessionCountScript = redis.NewScript(`
local key = KEYS[1]
local idleTimeout = tonumber(ARGV[1])
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - idleTimeout
-- 清理过期会话
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
return redis.call('ZCARD', key)
`)

	// isSessionActiveScript reports whether a session is still active.
	// It returns 0 when the session is missing or its score is at or before
	// now - idleTimeout, and 1 otherwise. It does not modify the set.
	// KEYS[1] = session_limit:account:{accountID}
	// ARGV[1] = idleTimeout (seconds)
	// ARGV[2] = sessionUUID
	isSessionActiveScript = redis.NewScript(`
local key = KEYS[1]
local idleTimeout = tonumber(ARGV[1])
local sessionUUID = ARGV[2]
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - idleTimeout
-- 获取会话的时间戳
local score = redis.call('ZSCORE', key, sessionUUID)
if score == false then
return 0
end
-- 检查是否过期
if tonumber(score) <= expireBefore then
return 0
end
return 1
`)
)
// sessionLimitCache is the Redis-backed value that NewSessionLimitCache
// returns as a service.SessionLimitCache.
type sessionLimitCache struct {
	rdb                *redis.Client
	defaultIdleTimeout time.Duration // default idle timeout (used by GetActiveSessionCount)
}
// NewSessionLimitCache 创建会话限制缓存
// defaultIdleTimeoutMinutes: 默认空闲超时时间(分钟),用于无参数查询
func
NewSessionLimitCache
(
rdb
*
redis
.
Client
,
defaultIdleTimeoutMinutes
int
)
service
.
SessionLimitCache
{
if
defaultIdleTimeoutMinutes
<=
0
{
defaultIdleTimeoutMinutes
=
5
// 默认 5 分钟
}
return
&
sessionLimitCache
{
rdb
:
rdb
,
defaultIdleTimeout
:
time
.
Duration
(
defaultIdleTimeoutMinutes
)
*
time
.
Minute
,
}
}
// sessionLimitKey 生成会话限制的 Redis 键
func
sessionLimitKey
(
accountID
int64
)
string
{
return
fmt
.
Sprintf
(
"%s%d"
,
sessionLimitKeyPrefix
,
accountID
)
}
// windowCostKey 生成窗口费用缓存的 Redis 键
func
windowCostKey
(
accountID
int64
)
string
{
return
fmt
.
Sprintf
(
"%s%d"
,
windowCostKeyPrefix
,
accountID
)
}
// RegisterSession 注册会话活动
func
(
c
*
sessionLimitCache
)
RegisterSession
(
ctx
context
.
Context
,
accountID
int64
,
sessionUUID
string
,
maxSessions
int
,
idleTimeout
time
.
Duration
)
(
bool
,
error
)
{
if
sessionUUID
==
""
||
maxSessions
<=
0
{
return
true
,
nil
// 无效参数,默认允许
}
key
:=
sessionLimitKey
(
accountID
)
idleTimeoutSeconds
:=
int
(
idleTimeout
.
Seconds
())
if
idleTimeoutSeconds
<=
0
{
idleTimeoutSeconds
=
int
(
c
.
defaultIdleTimeout
.
Seconds
())
}
result
,
err
:=
registerSessionScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
maxSessions
,
idleTimeoutSeconds
,
sessionUUID
)
.
Int
()
if
err
!=
nil
{
return
true
,
err
// 失败开放:缓存错误时允许请求通过
}
return
result
==
1
,
nil
}
// RefreshSession 刷新会话时间戳
func
(
c
*
sessionLimitCache
)
RefreshSession
(
ctx
context
.
Context
,
accountID
int64
,
sessionUUID
string
,
idleTimeout
time
.
Duration
)
error
{
if
sessionUUID
==
""
{
return
nil
}
key
:=
sessionLimitKey
(
accountID
)
idleTimeoutSeconds
:=
int
(
idleTimeout
.
Seconds
())
if
idleTimeoutSeconds
<=
0
{
idleTimeoutSeconds
=
int
(
c
.
defaultIdleTimeout
.
Seconds
())
}
_
,
err
:=
refreshSessionScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
idleTimeoutSeconds
,
sessionUUID
)
.
Result
()
return
err
}
// GetActiveSessionCount 获取活跃会话数
func
(
c
*
sessionLimitCache
)
GetActiveSessionCount
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
{
key
:=
sessionLimitKey
(
accountID
)
idleTimeoutSeconds
:=
int
(
c
.
defaultIdleTimeout
.
Seconds
())
result
,
err
:=
getActiveSessionCountScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
idleTimeoutSeconds
)
.
Int
()
if
err
!=
nil
{
return
0
,
err
}
return
result
,
nil
}
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
func
(
c
*
sessionLimitCache
)
GetActiveSessionCountBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
int
,
error
)
{
if
len
(
accountIDs
)
==
0
{
return
make
(
map
[
int64
]
int
),
nil
}
results
:=
make
(
map
[
int64
]
int
,
len
(
accountIDs
))
// 使用 pipeline 批量执行
pipe
:=
c
.
rdb
.
Pipeline
()
idleTimeoutSeconds
:=
int
(
c
.
defaultIdleTimeout
.
Seconds
())
cmds
:=
make
(
map
[
int64
]
*
redis
.
Cmd
,
len
(
accountIDs
))
for
_
,
accountID
:=
range
accountIDs
{
key
:=
sessionLimitKey
(
accountID
)
cmds
[
accountID
]
=
getActiveSessionCountScript
.
Run
(
ctx
,
pipe
,
[]
string
{
key
},
idleTimeoutSeconds
)
}
// 执行 pipeline,即使部分失败也尝试获取成功的结果
_
,
_
=
pipe
.
Exec
(
ctx
)
for
accountID
,
cmd
:=
range
cmds
{
if
result
,
err
:=
cmd
.
Int
();
err
==
nil
{
results
[
accountID
]
=
result
}
}
return
results
,
nil
}
// IsSessionActive 检查会话是否活跃
func
(
c
*
sessionLimitCache
)
IsSessionActive
(
ctx
context
.
Context
,
accountID
int64
,
sessionUUID
string
)
(
bool
,
error
)
{
if
sessionUUID
==
""
{
return
false
,
nil
}
key
:=
sessionLimitKey
(
accountID
)
idleTimeoutSeconds
:=
int
(
c
.
defaultIdleTimeout
.
Seconds
())
result
,
err
:=
isSessionActiveScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
idleTimeoutSeconds
,
sessionUUID
)
.
Int
()
if
err
!=
nil
{
return
false
,
err
}
return
result
==
1
,
nil
}
// ========== 5h窗口费用缓存实现 ==========
// GetWindowCost 获取缓存的窗口费用
func
(
c
*
sessionLimitCache
)
GetWindowCost
(
ctx
context
.
Context
,
accountID
int64
)
(
float64
,
bool
,
error
)
{
key
:=
windowCostKey
(
accountID
)
val
,
err
:=
c
.
rdb
.
Get
(
ctx
,
key
)
.
Float64
()
if
err
==
redis
.
Nil
{
return
0
,
false
,
nil
// 缓存未命中
}
if
err
!=
nil
{
return
0
,
false
,
err
}
return
val
,
true
,
nil
}
// SetWindowCost 设置窗口费用缓存
func
(
c
*
sessionLimitCache
)
SetWindowCost
(
ctx
context
.
Context
,
accountID
int64
,
cost
float64
)
error
{
key
:=
windowCostKey
(
accountID
)
return
c
.
rdb
.
Set
(
ctx
,
key
,
cost
,
windowCostCacheTTL
)
.
Err
()
}
// GetWindowCostBatch 批量获取窗口费用缓存
func
(
c
*
sessionLimitCache
)
GetWindowCostBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
float64
,
error
)
{
if
len
(
accountIDs
)
==
0
{
return
make
(
map
[
int64
]
float64
),
nil
}
// 构建批量查询的 keys
keys
:=
make
([]
string
,
len
(
accountIDs
))
for
i
,
accountID
:=
range
accountIDs
{
keys
[
i
]
=
windowCostKey
(
accountID
)
}
// 使用 MGET 批量获取
vals
,
err
:=
c
.
rdb
.
MGet
(
ctx
,
keys
...
)
.
Result
()
if
err
!=
nil
{
return
nil
,
err
}
results
:=
make
(
map
[
int64
]
float64
,
len
(
accountIDs
))
for
i
,
val
:=
range
vals
{
if
val
==
nil
{
continue
// 缓存未命中
}
// 尝试解析为 float64
switch
v
:=
val
.
(
type
)
{
case
string
:
if
cost
,
err
:=
strconv
.
ParseFloat
(
v
,
64
);
err
==
nil
{
results
[
accountIDs
[
i
]]
=
cost
}
case
float64
:
results
[
accountIDs
[
i
]]
=
v
}
}
return
results
,
nil
}
backend/internal/repository/wire.go
View file @
dd7f2124
...
@@ -37,6 +37,16 @@ func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient
...
@@ -37,6 +37,16 @@ func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient
return
NewPricingRemoteClient
(
cfg
.
Update
.
ProxyURL
)
return
NewPricingRemoteClient
(
cfg
.
Update
.
ProxyURL
)
}
}
// ProvideSessionLimitCache 创建会话限制缓存
// 用于 Anthropic OAuth/SetupToken 账号的并发会话数量控制
func
ProvideSessionLimitCache
(
rdb
*
redis
.
Client
,
cfg
*
config
.
Config
)
service
.
SessionLimitCache
{
defaultIdleTimeoutMinutes
:=
5
// 默认 5 分钟空闲超时
if
cfg
!=
nil
&&
cfg
.
Gateway
.
SessionIdleTimeoutMinutes
>
0
{
defaultIdleTimeoutMinutes
=
cfg
.
Gateway
.
SessionIdleTimeoutMinutes
}
return
NewSessionLimitCache
(
rdb
,
defaultIdleTimeoutMinutes
)
}
// ProviderSet is the Wire provider set for all repositories
// ProviderSet is the Wire provider set for all repositories
var
ProviderSet
=
wire
.
NewSet
(
var
ProviderSet
=
wire
.
NewSet
(
NewUserRepository
,
NewUserRepository
,
...
@@ -61,6 +71,7 @@ var ProviderSet = wire.NewSet(
...
@@ -61,6 +71,7 @@ var ProviderSet = wire.NewSet(
NewTempUnschedCache
,
NewTempUnschedCache
,
NewTimeoutCounterCache
,
NewTimeoutCounterCache
,
ProvideConcurrencyCache
,
ProvideConcurrencyCache
,
ProvideSessionLimitCache
,
NewDashboardCache
,
NewDashboardCache
,
NewEmailCache
,
NewEmailCache
,
NewIdentityCache
,
NewIdentityCache
,
...
...
backend/internal/server/api_contract_test.go
View file @
dd7f2124
...
@@ -441,7 +441,7 @@ func newContractDeps(t *testing.T) *contractDeps {
...
@@ -441,7 +441,7 @@ func newContractDeps(t *testing.T) *contractDeps {
apiKeyHandler
:=
handler
.
NewAPIKeyHandler
(
apiKeyService
)
apiKeyHandler
:=
handler
.
NewAPIKeyHandler
(
apiKeyService
)
usageHandler
:=
handler
.
NewUsageHandler
(
usageService
,
apiKeyService
)
usageHandler
:=
handler
.
NewUsageHandler
(
usageService
,
apiKeyService
)
adminSettingHandler
:=
adminhandler
.
NewSettingHandler
(
settingService
,
nil
,
nil
,
nil
)
adminSettingHandler
:=
adminhandler
.
NewSettingHandler
(
settingService
,
nil
,
nil
,
nil
)
adminAccountHandler
:=
adminhandler
.
NewAccountHandler
(
adminService
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
adminAccountHandler
:=
adminhandler
.
NewAccountHandler
(
adminService
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
jwtAuth
:=
func
(
c
*
gin
.
Context
)
{
jwtAuth
:=
func
(
c
*
gin
.
Context
)
{
c
.
Set
(
string
(
middleware
.
ContextKeyUser
),
middleware
.
AuthSubject
{
c
.
Set
(
string
(
middleware
.
ContextKeyUser
),
middleware
.
AuthSubject
{
...
...
backend/internal/service/account.go
View file @
dd7f2124
...
@@ -573,3 +573,141 @@ func (a *Account) IsMixedSchedulingEnabled() bool {
...
@@ -573,3 +573,141 @@ func (a *Account) IsMixedSchedulingEnabled() bool {
}
}
return
false
return
false
}
}
// WindowCostSchedulability is the scheduling state derived from an account's
// current window cost.
type WindowCostSchedulability int

const (
	// WindowCostSchedulable means the account can be scheduled normally.
	WindowCostSchedulable WindowCostSchedulability = iota
	// WindowCostStickyOnly means only sticky sessions are allowed.
	WindowCostStickyOnly
	// WindowCostNotSchedulable means the account must not be scheduled at all.
	WindowCostNotSchedulable
)
// IsAnthropicOAuthOrSetupToken 判断是否为 Anthropic OAuth 或 SetupToken 类型账号
// 仅这两类账号支持 5h 窗口额度控制和会话数量控制
func
(
a
*
Account
)
IsAnthropicOAuthOrSetupToken
()
bool
{
return
a
.
Platform
==
PlatformAnthropic
&&
(
a
.
Type
==
AccountTypeOAuth
||
a
.
Type
==
AccountTypeSetupToken
)
}
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
// 返回 0 表示未启用
func
(
a
*
Account
)
GetWindowCostLimit
()
float64
{
if
a
.
Extra
==
nil
{
return
0
}
if
v
,
ok
:=
a
.
Extra
[
"window_cost_limit"
];
ok
{
return
parseExtraFloat64
(
v
)
}
return
0
}
// GetWindowCostStickyReserve 获取粘性会话预留额度(美元)
// 默认值为 10
func
(
a
*
Account
)
GetWindowCostStickyReserve
()
float64
{
if
a
.
Extra
==
nil
{
return
10.0
}
if
v
,
ok
:=
a
.
Extra
[
"window_cost_sticky_reserve"
];
ok
{
val
:=
parseExtraFloat64
(
v
)
if
val
>
0
{
return
val
}
}
return
10.0
}
// GetMaxSessions 获取最大并发会话数
// 返回 0 表示未启用
func
(
a
*
Account
)
GetMaxSessions
()
int
{
if
a
.
Extra
==
nil
{
return
0
}
if
v
,
ok
:=
a
.
Extra
[
"max_sessions"
];
ok
{
return
parseExtraInt
(
v
)
}
return
0
}
// GetSessionIdleTimeoutMinutes 获取会话空闲超时分钟数
// 默认值为 5 分钟
func
(
a
*
Account
)
GetSessionIdleTimeoutMinutes
()
int
{
if
a
.
Extra
==
nil
{
return
5
}
if
v
,
ok
:=
a
.
Extra
[
"session_idle_timeout_minutes"
];
ok
{
val
:=
parseExtraInt
(
v
)
if
val
>
0
{
return
val
}
}
return
5
}
// CheckWindowCostSchedulability 根据当前窗口费用检查调度状态
// - 费用 < 阈值: WindowCostSchedulable(可正常调度)
// - 费用 >= 阈值 且 < 阈值+预留: WindowCostStickyOnly(仅粘性会话)
// - 费用 >= 阈值+预留: WindowCostNotSchedulable(不可调度)
func
(
a
*
Account
)
CheckWindowCostSchedulability
(
currentWindowCost
float64
)
WindowCostSchedulability
{
limit
:=
a
.
GetWindowCostLimit
()
if
limit
<=
0
{
return
WindowCostSchedulable
}
if
currentWindowCost
<
limit
{
return
WindowCostSchedulable
}
stickyReserve
:=
a
.
GetWindowCostStickyReserve
()
if
currentWindowCost
<
limit
+
stickyReserve
{
return
WindowCostStickyOnly
}
return
WindowCostNotSchedulable
}
// parseExtraFloat64 converts a value taken from the extra map to float64.
// It understands float64, float32, int, int64, json.Number and numeric
// strings (surrounding whitespace is ignored). Any other type, or a value
// that fails to parse, yields 0.
func parseExtraFloat64(value any) float64 {
	var parsed float64
	switch typed := value.(type) {
	case float64:
		parsed = typed
	case float32:
		parsed = float64(typed)
	case int:
		parsed = float64(typed)
	case int64:
		parsed = float64(typed)
	case json.Number:
		if num, err := typed.Float64(); err == nil {
			parsed = num
		}
	case string:
		if num, err := strconv.ParseFloat(strings.TrimSpace(typed), 64); err == nil {
			parsed = num
		}
	}
	return parsed
}
// parseExtraInt converts a value taken from the extra map to int.
//
// It understands int, int64, float64, float32, json.Number and numeric
// strings (surrounding whitespace is ignored). Floating-point input is
// truncated toward zero. json.Number and string values that are not plain
// integer literals (for example "5.0") are parsed as floats and truncated
// the same way, so the result does not depend on how the JSON decoder chose
// to represent the number. Any other type, a value that fails to parse, and
// a NaN or out-of-range float all yield 0.
func parseExtraInt(value any) int {
	const maxInt = int(^uint(0) >> 1)

	// fromFloat truncates f toward zero. The range test is written so that
	// NaN fails it too; converting NaN or an out-of-range float to int would
	// otherwise give an implementation-defined result.
	fromFloat := func(f float64) int {
		if !(f > -float64(maxInt) && f < float64(maxInt)) {
			return 0
		}
		return int(f)
	}

	switch v := value.(type) {
	case int:
		return v
	case int64:
		return int(v)
	case float64:
		return fromFloat(v)
	case float32:
		return fromFloat(float64(v))
	case json.Number:
		if i, err := v.Int64(); err == nil {
			return int(i)
		}
		// Not an integer literal (e.g. "5.0"): fall back to float parsing.
		if f, err := v.Float64(); err == nil {
			return fromFloat(f)
		}
	case string:
		s := strings.TrimSpace(v)
		if i, err := strconv.Atoi(s); err == nil {
			return i
		}
		// Not an integer literal (e.g. "5.0"): fall back to float parsing.
		if f, err := strconv.ParseFloat(s, 64); err == nil {
			return fromFloat(f)
		}
	}
	return 0
}
backend/internal/service/account_usage_service.go
View file @
dd7f2124
...
@@ -575,3 +575,9 @@ func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64
...
@@ -575,3 +575,9 @@ func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64
},
},
}
}
}
}
// GetAccountWindowStats 获取账号在指定时间窗口内的使用统计
// 用于账号列表页面显示当前窗口费用
func
(
s
*
AccountUsageService
)
GetAccountWindowStats
(
ctx
context
.
Context
,
accountID
int64
,
startTime
time
.
Time
)
(
*
usagestats
.
AccountStats
,
error
)
{
return
s
.
usageLogRepo
.
GetAccountWindowStats
(
ctx
,
accountID
,
startTime
)
}
backend/internal/service/gateway_multiplatform_test.go
View file @
dd7f2124
...
@@ -1052,7 +1052,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
...
@@ -1052,7 +1052,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService
:
nil
,
// No concurrency service
concurrencyService
:
nil
,
// No concurrency service
}
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
)
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
NotNil
(
t
,
result
.
Account
)
...
@@ -1105,7 +1105,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
...
@@ -1105,7 +1105,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService
:
nil
,
// legacy path
concurrencyService
:
nil
,
// legacy path
}
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
&
groupID
,
sessionHash
,
"claude-b"
,
nil
)
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
&
groupID
,
sessionHash
,
"claude-b"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
NotNil
(
t
,
result
.
Account
)
...
@@ -1137,7 +1137,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
...
@@ -1137,7 +1137,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService
:
nil
,
concurrencyService
:
nil
,
}
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
)
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
NotNil
(
t
,
result
.
Account
)
...
@@ -1169,7 +1169,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
...
@@ -1169,7 +1169,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
}
}
excludedIDs
:=
map
[
int64
]
struct
{}{
1
:
{}}
excludedIDs
:=
map
[
int64
]
struct
{}{
1
:
{}}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
excludedIDs
)
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
excludedIDs
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
NotNil
(
t
,
result
.
Account
)
...
@@ -1203,7 +1203,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
...
@@ -1203,7 +1203,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
"sticky"
,
"claude-3-5-sonnet-20241022"
,
nil
)
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
"sticky"
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
NotNil
(
t
,
result
.
Account
)
...
@@ -1239,7 +1239,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
...
@@ -1239,7 +1239,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
"sticky"
,
"claude-3-5-sonnet-20241022"
,
nil
)
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
"sticky"
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
NotNil
(
t
,
result
.
Account
)
...
@@ -1266,7 +1266,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
...
@@ -1266,7 +1266,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService
:
nil
,
concurrencyService
:
nil
,
}
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
)
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
Error
(
t
,
err
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
result
)
require
.
Nil
(
t
,
result
)
require
.
Contains
(
t
,
err
.
Error
(),
"no available accounts"
)
require
.
Contains
(
t
,
err
.
Error
(),
"no available accounts"
)
...
@@ -1298,7 +1298,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
...
@@ -1298,7 +1298,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService
:
nil
,
concurrencyService
:
nil
,
}
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
)
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
NotNil
(
t
,
result
.
Account
)
...
@@ -1331,7 +1331,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
...
@@ -1331,7 +1331,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService
:
nil
,
concurrencyService
:
nil
,
}
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
)
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
NotNil
(
t
,
result
.
Account
)
...
...
backend/internal/service/gateway_service.go
View file @
dd7f2124
...
@@ -213,6 +213,7 @@ type GatewayService struct {
...
@@ -213,6 +213,7 @@ type GatewayService struct {
deferredService
*
DeferredService
deferredService
*
DeferredService
concurrencyService
*
ConcurrencyService
concurrencyService
*
ConcurrencyService
claudeTokenProvider
*
ClaudeTokenProvider
claudeTokenProvider
*
ClaudeTokenProvider
sessionLimitCache
SessionLimitCache
// 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
}
}
// NewGatewayService creates a new GatewayService
// NewGatewayService creates a new GatewayService
...
@@ -233,6 +234,7 @@ func NewGatewayService(
...
@@ -233,6 +234,7 @@ func NewGatewayService(
httpUpstream
HTTPUpstream
,
httpUpstream
HTTPUpstream
,
deferredService
*
DeferredService
,
deferredService
*
DeferredService
,
claudeTokenProvider
*
ClaudeTokenProvider
,
claudeTokenProvider
*
ClaudeTokenProvider
,
sessionLimitCache
SessionLimitCache
,
)
*
GatewayService
{
)
*
GatewayService
{
return
&
GatewayService
{
return
&
GatewayService
{
accountRepo
:
accountRepo
,
accountRepo
:
accountRepo
,
...
@@ -251,6 +253,7 @@ func NewGatewayService(
...
@@ -251,6 +253,7 @@ func NewGatewayService(
httpUpstream
:
httpUpstream
,
httpUpstream
:
httpUpstream
,
deferredService
:
deferredService
,
deferredService
:
deferredService
,
claudeTokenProvider
:
claudeTokenProvider
,
claudeTokenProvider
:
claudeTokenProvider
,
sessionLimitCache
:
sessionLimitCache
,
}
}
}
}
...
@@ -816,8 +819,12 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
...
@@ -816,8 +819,12 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
}
}
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
func
(
s
*
GatewayService
)
SelectAccountWithLoadAwareness
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{})
(
*
AccountSelectionResult
,
error
)
{
// metadataUserID: 原始 metadata.user_id 字段(用于提取会话 UUID 进行会话数量限制)
func
(
s
*
GatewayService
)
SelectAccountWithLoadAwareness
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{},
metadataUserID
string
)
(
*
AccountSelectionResult
,
error
)
{
cfg
:=
s
.
schedulingConfig
()
cfg
:=
s
.
schedulingConfig
()
// 提取会话 UUID(用于会话数量限制)
sessionUUID
:=
extractSessionUUID
(
metadataUserID
)
var
stickyAccountID
int64
var
stickyAccountID
int64
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
if
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
);
err
==
nil
{
if
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
);
err
==
nil
{
...
@@ -936,7 +943,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -936,7 +943,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if
len
(
routingAccountIDs
)
>
0
&&
s
.
concurrencyService
!=
nil
{
if
len
(
routingAccountIDs
)
>
0
&&
s
.
concurrencyService
!=
nil
{
// 1. 过滤出路由列表中可调度的账号
// 1. 过滤出路由列表中可调度的账号
var
routingCandidates
[]
*
Account
var
routingCandidates
[]
*
Account
var
filteredExcluded
,
filteredMissing
,
filteredUnsched
,
filteredPlatform
,
filteredModelScope
,
filteredModelMapping
int
var
filteredExcluded
,
filteredMissing
,
filteredUnsched
,
filteredPlatform
,
filteredModelScope
,
filteredModelMapping
,
filteredWindowCost
int
for
_
,
routingAccountID
:=
range
routingAccountIDs
{
for
_
,
routingAccountID
:=
range
routingAccountIDs
{
if
isExcluded
(
routingAccountID
)
{
if
isExcluded
(
routingAccountID
)
{
filteredExcluded
++
filteredExcluded
++
...
@@ -963,13 +970,18 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -963,13 +970,18 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
filteredModelMapping
++
filteredModelMapping
++
continue
continue
}
}
// 窗口费用检查(非粘性会话路径)
if
!
s
.
isAccountSchedulableForWindowCost
(
ctx
,
account
,
false
)
{
filteredWindowCost
++
continue
}
routingCandidates
=
append
(
routingCandidates
,
account
)
routingCandidates
=
append
(
routingCandidates
,
account
)
}
}
if
s
.
debugModelRoutingEnabled
()
{
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d)"
,
log
.
Printf
(
"[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d
window_cost=%d
)"
,
derefGroupID
(
groupID
),
requestedModel
,
len
(
routingAccountIDs
),
len
(
routingCandidates
),
derefGroupID
(
groupID
),
requestedModel
,
len
(
routingAccountIDs
),
len
(
routingCandidates
),
filteredExcluded
,
filteredMissing
,
filteredUnsched
,
filteredPlatform
,
filteredModelScope
,
filteredModelMapping
)
filteredExcluded
,
filteredMissing
,
filteredUnsched
,
filteredPlatform
,
filteredModelScope
,
filteredModelMapping
,
filteredWindowCost
)
}
}
if
len
(
routingCandidates
)
>
0
{
if
len
(
routingCandidates
)
>
0
{
...
@@ -982,9 +994,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -982,9 +994,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if
stickyAccount
.
IsSchedulable
()
&&
if
stickyAccount
.
IsSchedulable
()
&&
s
.
isAccountAllowedForPlatform
(
stickyAccount
,
platform
,
useMixed
)
&&
s
.
isAccountAllowedForPlatform
(
stickyAccount
,
platform
,
useMixed
)
&&
stickyAccount
.
IsSchedulableForModel
(
requestedModel
)
&&
stickyAccount
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
stickyAccount
,
requestedModel
))
{
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
stickyAccount
,
requestedModel
))
&&
s
.
isAccountSchedulableForWindowCost
(
ctx
,
stickyAccount
,
true
)
{
// 粘性会话窗口费用检查
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
stickyAccountID
,
stickyAccount
.
Concurrency
)
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
stickyAccountID
,
stickyAccount
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
if
err
==
nil
&&
result
.
Acquired
{
// 会话数量限制检查
if
!
s
.
checkAndRegisterSession
(
ctx
,
stickyAccount
,
sessionUUID
)
{
result
.
ReleaseFunc
()
// 释放槽位
// 继续到负载感知选择
}
else
{
_
=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
)
_
=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
)
if
s
.
debugModelRoutingEnabled
()
{
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
stickyAccountID
)
log
.
Printf
(
"[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
stickyAccountID
)
...
@@ -995,6 +1013,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -995,6 +1013,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
ReleaseFunc
:
result
.
ReleaseFunc
,
ReleaseFunc
:
result
.
ReleaseFunc
,
},
nil
},
nil
}
}
}
waitingCount
,
_
:=
s
.
concurrencyService
.
GetAccountWaitingCount
(
ctx
,
stickyAccountID
)
waitingCount
,
_
:=
s
.
concurrencyService
.
GetAccountWaitingCount
(
ctx
,
stickyAccountID
)
if
waitingCount
<
cfg
.
StickySessionMaxWaiting
{
if
waitingCount
<
cfg
.
StickySessionMaxWaiting
{
...
@@ -1066,6 +1085,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -1066,6 +1085,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
for
_
,
item
:=
range
routingAvailable
{
for
_
,
item
:=
range
routingAvailable
{
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
item
.
account
.
ID
,
item
.
account
.
Concurrency
)
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
item
.
account
.
ID
,
item
.
account
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
if
err
==
nil
&&
result
.
Acquired
{
// 会话数量限制检查
if
!
s
.
checkAndRegisterSession
(
ctx
,
item
.
account
,
sessionUUID
)
{
result
.
ReleaseFunc
()
// 释放槽位,继续尝试下一个账号
continue
}
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
_
=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
item
.
account
.
ID
,
stickySessionTTL
)
_
=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
item
.
account
.
ID
,
stickySessionTTL
)
}
}
...
@@ -1108,9 +1132,14 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -1108,9 +1132,14 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if
ok
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
if
ok
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
s
.
isAccountAllowedForPlatform
(
account
,
platform
,
useMixed
)
&&
s
.
isAccountAllowedForPlatform
(
account
,
platform
,
useMixed
)
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
&&
s
.
isAccountSchedulableForWindowCost
(
ctx
,
account
,
true
)
{
// 粘性会话窗口费用检查
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
if
err
==
nil
&&
result
.
Acquired
{
// 会话数量限制检查
if
!
s
.
checkAndRegisterSession
(
ctx
,
account
,
sessionUUID
)
{
result
.
ReleaseFunc
()
// 释放槽位,继续到 Layer 2
}
else
{
_
=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
)
_
=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
)
return
&
AccountSelectionResult
{
return
&
AccountSelectionResult
{
Account
:
account
,
Account
:
account
,
...
@@ -1118,6 +1147,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -1118,6 +1147,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
ReleaseFunc
:
result
.
ReleaseFunc
,
ReleaseFunc
:
result
.
ReleaseFunc
,
},
nil
},
nil
}
}
}
waitingCount
,
_
:=
s
.
concurrencyService
.
GetAccountWaitingCount
(
ctx
,
accountID
)
waitingCount
,
_
:=
s
.
concurrencyService
.
GetAccountWaitingCount
(
ctx
,
accountID
)
if
waitingCount
<
cfg
.
StickySessionMaxWaiting
{
if
waitingCount
<
cfg
.
StickySessionMaxWaiting
{
...
@@ -1157,6 +1187,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -1157,6 +1187,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if
requestedModel
!=
""
&&
!
s
.
isModelSupportedByAccount
(
acc
,
requestedModel
)
{
if
requestedModel
!=
""
&&
!
s
.
isModelSupportedByAccount
(
acc
,
requestedModel
)
{
continue
continue
}
}
// 窗口费用检查(非粘性会话路径)
if
!
s
.
isAccountSchedulableForWindowCost
(
ctx
,
acc
,
false
)
{
continue
}
candidates
=
append
(
candidates
,
acc
)
candidates
=
append
(
candidates
,
acc
)
}
}
...
@@ -1174,7 +1208,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -1174,7 +1208,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
loadMap
,
err
:=
s
.
concurrencyService
.
GetAccountsLoadBatch
(
ctx
,
accountLoads
)
loadMap
,
err
:=
s
.
concurrencyService
.
GetAccountsLoadBatch
(
ctx
,
accountLoads
)
if
err
!=
nil
{
if
err
!=
nil
{
if
result
,
ok
:=
s
.
tryAcquireByLegacyOrder
(
ctx
,
candidates
,
groupID
,
sessionHash
,
preferOAuth
);
ok
{
if
result
,
ok
:=
s
.
tryAcquireByLegacyOrder
(
ctx
,
candidates
,
groupID
,
sessionHash
,
preferOAuth
,
sessionUUID
);
ok
{
return
result
,
nil
return
result
,
nil
}
}
}
else
{
}
else
{
...
@@ -1223,6 +1257,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -1223,6 +1257,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
for
_
,
item
:=
range
available
{
for
_
,
item
:=
range
available
{
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
item
.
account
.
ID
,
item
.
account
.
Concurrency
)
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
item
.
account
.
ID
,
item
.
account
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
if
err
==
nil
&&
result
.
Acquired
{
// 会话数量限制检查
if
!
s
.
checkAndRegisterSession
(
ctx
,
item
.
account
,
sessionUUID
)
{
result
.
ReleaseFunc
()
// 释放槽位,继续尝试下一个账号
continue
}
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
_
=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
item
.
account
.
ID
,
stickySessionTTL
)
_
=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
item
.
account
.
ID
,
stickySessionTTL
)
}
}
...
@@ -1252,13 +1291,18 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
...
@@ -1252,13 +1291,18 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return
nil
,
errors
.
New
(
"no available accounts"
)
return
nil
,
errors
.
New
(
"no available accounts"
)
}
}
func
(
s
*
GatewayService
)
tryAcquireByLegacyOrder
(
ctx
context
.
Context
,
candidates
[]
*
Account
,
groupID
*
int64
,
sessionHash
string
,
preferOAuth
bool
)
(
*
AccountSelectionResult
,
bool
)
{
func
(
s
*
GatewayService
)
tryAcquireByLegacyOrder
(
ctx
context
.
Context
,
candidates
[]
*
Account
,
groupID
*
int64
,
sessionHash
string
,
preferOAuth
bool
,
sessionUUID
string
)
(
*
AccountSelectionResult
,
bool
)
{
ordered
:=
append
([]
*
Account
(
nil
),
candidates
...
)
ordered
:=
append
([]
*
Account
(
nil
),
candidates
...
)
sortAccountsByPriorityAndLastUsed
(
ordered
,
preferOAuth
)
sortAccountsByPriorityAndLastUsed
(
ordered
,
preferOAuth
)
for
_
,
acc
:=
range
ordered
{
for
_
,
acc
:=
range
ordered
{
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
acc
.
ID
,
acc
.
Concurrency
)
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
acc
.
ID
,
acc
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
if
err
==
nil
&&
result
.
Acquired
{
// 会话数量限制检查
if
!
s
.
checkAndRegisterSession
(
ctx
,
acc
,
sessionUUID
)
{
result
.
ReleaseFunc
()
// 释放槽位,继续尝试下一个账号
continue
}
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
_
=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
acc
.
ID
,
stickySessionTTL
)
_
=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
acc
.
ID
,
stickySessionTTL
)
}
}
...
@@ -1490,6 +1534,107 @@ func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID in
...
@@ -1490,6 +1534,107 @@ func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID in
return
s
.
concurrencyService
.
AcquireAccountSlot
(
ctx
,
accountID
,
maxConcurrency
)
return
s
.
concurrencyService
.
AcquireAccountSlot
(
ctx
,
accountID
,
maxConcurrency
)
}
}
// isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度
// 仅适用于 Anthropic OAuth/SetupToken 账号
// 返回 true 表示可调度,false 表示不可调度
func
(
s
*
GatewayService
)
isAccountSchedulableForWindowCost
(
ctx
context
.
Context
,
account
*
Account
,
isSticky
bool
)
bool
{
// 只检查 Anthropic OAuth/SetupToken 账号
if
!
account
.
IsAnthropicOAuthOrSetupToken
()
{
return
true
}
limit
:=
account
.
GetWindowCostLimit
()
if
limit
<=
0
{
return
true
// 未启用窗口费用限制
}
// 尝试从缓存获取窗口费用
var
currentCost
float64
if
s
.
sessionLimitCache
!=
nil
{
if
cost
,
hit
,
err
:=
s
.
sessionLimitCache
.
GetWindowCost
(
ctx
,
account
.
ID
);
err
==
nil
&&
hit
{
currentCost
=
cost
goto
checkSchedulability
}
}
// 缓存未命中,从数据库查询
{
var
startTime
time
.
Time
if
account
.
SessionWindowStart
!=
nil
{
startTime
=
*
account
.
SessionWindowStart
}
else
{
startTime
=
time
.
Now
()
.
Add
(
-
5
*
time
.
Hour
)
}
stats
,
err
:=
s
.
usageLogRepo
.
GetAccountWindowStats
(
ctx
,
account
.
ID
,
startTime
)
if
err
!=
nil
{
// 失败开放:查询失败时允许调度
return
true
}
// 使用标准费用(不含账号倍率)
currentCost
=
stats
.
StandardCost
// 设置缓存(忽略错误)
if
s
.
sessionLimitCache
!=
nil
{
_
=
s
.
sessionLimitCache
.
SetWindowCost
(
ctx
,
account
.
ID
,
currentCost
)
}
}
checkSchedulability
:
schedulability
:=
account
.
CheckWindowCostSchedulability
(
currentCost
)
switch
schedulability
{
case
WindowCostSchedulable
:
return
true
case
WindowCostStickyOnly
:
return
isSticky
case
WindowCostNotSchedulable
:
return
false
}
return
true
}
// checkAndRegisterSession 检查并注册会话,用于会话数量限制
// 仅适用于 Anthropic OAuth/SetupToken 账号
// 返回 true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话)
func
(
s
*
GatewayService
)
checkAndRegisterSession
(
ctx
context
.
Context
,
account
*
Account
,
sessionUUID
string
)
bool
{
// 只检查 Anthropic OAuth/SetupToken 账号
if
!
account
.
IsAnthropicOAuthOrSetupToken
()
{
return
true
}
maxSessions
:=
account
.
GetMaxSessions
()
if
maxSessions
<=
0
||
sessionUUID
==
""
{
return
true
// 未启用会话限制或无会话ID
}
if
s
.
sessionLimitCache
==
nil
{
return
true
// 缓存不可用时允许通过
}
idleTimeout
:=
time
.
Duration
(
account
.
GetSessionIdleTimeoutMinutes
())
*
time
.
Minute
allowed
,
err
:=
s
.
sessionLimitCache
.
RegisterSession
(
ctx
,
account
.
ID
,
sessionUUID
,
maxSessions
,
idleTimeout
)
if
err
!=
nil
{
// 失败开放:缓存错误时允许通过
return
true
}
return
allowed
}
// extractSessionUUID 从 metadata.user_id 中提取会话 UUID
// 格式: user_{64位hex}_account__session_{uuid}
func
extractSessionUUID
(
metadataUserID
string
)
string
{
if
metadataUserID
==
""
{
return
""
}
if
match
:=
sessionIDRegex
.
FindStringSubmatch
(
metadataUserID
);
len
(
match
)
>
1
{
return
match
[
1
]
}
return
""
}
func
(
s
*
GatewayService
)
getSchedulableAccount
(
ctx
context
.
Context
,
accountID
int64
)
(
*
Account
,
error
)
{
func
(
s
*
GatewayService
)
getSchedulableAccount
(
ctx
context
.
Context
,
accountID
int64
)
(
*
Account
,
error
)
{
if
s
.
schedulerSnapshot
!=
nil
{
if
s
.
schedulerSnapshot
!=
nil
{
return
s
.
schedulerSnapshot
.
GetAccount
(
ctx
,
accountID
)
return
s
.
schedulerSnapshot
.
GetAccount
(
ctx
,
accountID
)
...
@@ -2384,9 +2529,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
...
@@ -2384,9 +2529,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
retryStart
:=
time
.
Now
()
retryStart
:=
time
.
Now
()
for
attempt
:=
1
;
attempt
<=
maxRetryAttempts
;
attempt
++
{
for
attempt
:=
1
;
attempt
<=
maxRetryAttempts
;
attempt
++
{
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
upstreamReq
,
err
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
body
,
token
,
tokenType
,
reqModel
,
reqStream
,
shouldMimicClaudeCode
)
// Capture upstream request body for ops retry of this attempt.
// Capture upstream request body for ops retry of this attempt.
c
.
Set
(
OpsUpstreamRequestBodyKey
,
string
(
body
))
c
.
Set
(
OpsUpstreamRequestBodyKey
,
string
(
body
))
upstreamReq
,
err
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
body
,
token
,
tokenType
,
reqModel
,
reqStream
,
shouldMimicClaudeCode
)
if
err
!=
nil
{
if
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
...
...
backend/internal/service/openai_gateway_service.go
View file @
dd7f2124
...
@@ -1067,16 +1067,30 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
...
@@ -1067,16 +1067,30 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
lastDataAt
:=
time
.
Now
()
lastDataAt
:=
time
.
Now
()
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
// 仅发送一次错误事件,避免多次写入导致协议混乱。
// 注意:OpenAI `/v1/responses` streaming 事件必须符合 OpenAI Responses schema;
// 否则下游 SDK(例如 OpenCode)会因为类型校验失败而报错。
errorEventSent
:=
false
errorEventSent
:=
false
clientDisconnected
:=
false
// 客户端断开后继续 drain 上游以收集 usage
sendErrorEvent
:=
func
(
reason
string
)
{
sendErrorEvent
:=
func
(
reason
string
)
{
if
errorEventSent
{
if
errorEventSent
||
clientDisconnected
{
return
return
}
}
errorEventSent
=
true
errorEventSent
=
true
_
,
_
=
fmt
.
Fprintf
(
w
,
"event: error
\n
data: {
\"
error
\"
:
\"
%s
\"
}
\n\n
"
,
reason
)
payload
:=
map
[
string
]
any
{
"type"
:
"error"
,
"sequence_number"
:
0
,
"error"
:
map
[
string
]
any
{
"type"
:
"upstream_error"
,
"message"
:
reason
,
"code"
:
reason
,
},
}
if
b
,
err
:=
json
.
Marshal
(
payload
);
err
==
nil
{
_
,
_
=
fmt
.
Fprintf
(
w
,
"data: %s
\n\n
"
,
b
)
flusher
.
Flush
()
flusher
.
Flush
()
}
}
}
needModelReplace
:=
originalModel
!=
mappedModel
needModelReplace
:=
originalModel
!=
mappedModel
...
@@ -1087,6 +1101,17 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
...
@@ -1087,6 +1101,17 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
}
if
ev
.
err
!=
nil
{
if
ev
.
err
!=
nil
{
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
if
errors
.
Is
(
ev
.
err
,
context
.
Canceled
)
||
errors
.
Is
(
ev
.
err
,
context
.
DeadlineExceeded
)
{
log
.
Printf
(
"Context canceled during streaming, returning collected usage"
)
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
if
clientDisconnected
{
log
.
Printf
(
"Upstream read error after client disconnect: %v, returning collected usage"
,
ev
.
err
)
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
if
errors
.
Is
(
ev
.
err
,
bufio
.
ErrTooLong
)
{
if
errors
.
Is
(
ev
.
err
,
bufio
.
ErrTooLong
)
{
log
.
Printf
(
"SSE line too long: account=%d max_size=%d error=%v"
,
account
.
ID
,
maxLineSize
,
ev
.
err
)
log
.
Printf
(
"SSE line too long: account=%d max_size=%d error=%v"
,
account
.
ID
,
maxLineSize
,
ev
.
err
)
sendErrorEvent
(
"response_too_large"
)
sendErrorEvent
(
"response_too_large"
)
...
@@ -1110,15 +1135,19 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
...
@@ -1110,15 +1135,19 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
if
correctedData
,
corrected
:=
s
.
toolCorrector
.
CorrectToolCallsInSSEData
(
data
);
corrected
{
if
correctedData
,
corrected
:=
s
.
toolCorrector
.
CorrectToolCallsInSSEData
(
data
);
corrected
{
data
=
correctedData
line
=
"data: "
+
correctedData
line
=
"data: "
+
correctedData
}
}
// Forward line
// 写入客户端(客户端断开后继续 drain 上游)
if
!
clientDisconnected
{
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
sendErrorEvent
(
"write_failed"
)
clientDisconnected
=
true
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
log
.
Printf
(
"Client disconnected during streaming, continuing to drain upstream for billing"
)
}
}
else
{
flusher
.
Flush
()
flusher
.
Flush
()
}
}
// Record first token time
// Record first token time
if
firstTokenMs
==
nil
&&
data
!=
""
&&
data
!=
"[DONE]"
{
if
firstTokenMs
==
nil
&&
data
!=
""
&&
data
!=
"[DONE]"
{
...
@@ -1128,18 +1157,25 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
...
@@ -1128,18 +1157,25 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
s
.
parseSSEUsage
(
data
,
usage
)
s
.
parseSSEUsage
(
data
,
usage
)
}
else
{
}
else
{
// Forward non-data lines as-is
// Forward non-data lines as-is
if
!
clientDisconnected
{
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
if
_
,
err
:=
fmt
.
Fprintf
(
w
,
"%s
\n
"
,
line
);
err
!=
nil
{
sendErrorEvent
(
"write_failed"
)
clientDisconnected
=
true
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
log
.
Printf
(
"Client disconnected during streaming, continuing to drain upstream for billing"
)
}
}
else
{
flusher
.
Flush
()
flusher
.
Flush
()
}
}
}
}
case
<-
intervalCh
:
case
<-
intervalCh
:
lastRead
:=
time
.
Unix
(
0
,
atomic
.
LoadInt64
(
&
lastReadAt
))
lastRead
:=
time
.
Unix
(
0
,
atomic
.
LoadInt64
(
&
lastReadAt
))
if
time
.
Since
(
lastRead
)
<
streamInterval
{
if
time
.
Since
(
lastRead
)
<
streamInterval
{
continue
continue
}
}
if
clientDisconnected
{
log
.
Printf
(
"Upstream timeout after client disconnect, returning collected usage"
)
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
nil
}
log
.
Printf
(
"Stream data interval timeout: account=%d model=%s interval=%s"
,
account
.
ID
,
originalModel
,
streamInterval
)
log
.
Printf
(
"Stream data interval timeout: account=%d model=%s interval=%s"
,
account
.
ID
,
originalModel
,
streamInterval
)
// 处理流超时,可能标记账户为临时不可调度或错误状态
// 处理流超时,可能标记账户为临时不可调度或错误状态
if
s
.
rateLimitService
!=
nil
{
if
s
.
rateLimitService
!=
nil
{
...
@@ -1149,11 +1185,16 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
...
@@ -1149,11 +1185,16 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream data interval timeout"
)
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream data interval timeout"
)
case
<-
keepaliveCh
:
case
<-
keepaliveCh
:
if
clientDisconnected
{
continue
}
if
time
.
Since
(
lastDataAt
)
<
keepaliveInterval
{
if
time
.
Since
(
lastDataAt
)
<
keepaliveInterval
{
continue
continue
}
}
if
_
,
err
:=
fmt
.
Fprint
(
w
,
":
\n\n
"
);
err
!=
nil
{
if
_
,
err
:=
fmt
.
Fprint
(
w
,
":
\n\n
"
);
err
!=
nil
{
return
&
openaiStreamingResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
err
clientDisconnected
=
true
log
.
Printf
(
"Client disconnected during streaming, continuing to drain upstream for billing"
)
continue
}
}
flusher
.
Flush
()
flusher
.
Flush
()
}
}
...
...
backend/internal/service/openai_gateway_service_test.go
View file @
dd7f2124
...
@@ -33,6 +33,25 @@ type stubConcurrencyCache struct {
...
@@ -33,6 +33,25 @@ type stubConcurrencyCache struct {
ConcurrencyCache
ConcurrencyCache
}
}
// cancelReadCloser is a stub body whose reads always fail with
// context.Canceled and whose Close always succeeds.
type cancelReadCloser struct{}

// Read ignores the supplied buffer and reports zero bytes read together with
// context.Canceled.
func (cancelReadCloser) Read(_ []byte) (int, error) {
	return 0, context.Canceled
}

// Close is a no-op that always returns nil.
func (cancelReadCloser) Close() error {
	return nil
}
type
failingGinWriter
struct
{
gin
.
ResponseWriter
failAfter
int
writes
int
}
func
(
w
*
failingGinWriter
)
Write
(
p
[]
byte
)
(
int
,
error
)
{
if
w
.
writes
>=
w
.
failAfter
{
return
0
,
errors
.
New
(
"write failed"
)
}
w
.
writes
++
return
w
.
ResponseWriter
.
Write
(
p
)
}
func
(
c
stubConcurrencyCache
)
AcquireAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
{
func
(
c
stubConcurrencyCache
)
AcquireAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
{
return
true
,
nil
return
true
,
nil
}
}
...
@@ -169,8 +188,85 @@ func TestOpenAIStreamingTimeout(t *testing.T) {
...
@@ -169,8 +188,85 @@ func TestOpenAIStreamingTimeout(t *testing.T) {
if
err
==
nil
||
!
strings
.
Contains
(
err
.
Error
(),
"stream data interval timeout"
)
{
if
err
==
nil
||
!
strings
.
Contains
(
err
.
Error
(),
"stream data interval timeout"
)
{
t
.
Fatalf
(
"expected stream timeout error, got %v"
,
err
)
t
.
Fatalf
(
"expected stream timeout error, got %v"
,
err
)
}
}
if
!
strings
.
Contains
(
rec
.
Body
.
String
(),
"stream_timeout"
)
{
if
!
strings
.
Contains
(
rec
.
Body
.
String
(),
"
\"
type
\"
:
\"
error
\"
"
)
||
!
strings
.
Contains
(
rec
.
Body
.
String
(),
"stream_timeout"
)
{
t
.
Fatalf
(
"expected stream_timeout SSE error, got %q"
,
rec
.
Body
.
String
())
t
.
Fatalf
(
"expected OpenAI-compatible error SSE event, got %q"
,
rec
.
Body
.
String
())
}
}
func
TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
cfg
:=
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
StreamDataIntervalTimeout
:
0
,
StreamKeepaliveInterval
:
0
,
MaxLineSize
:
defaultMaxLineSize
,
},
}
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
}
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
cancel
()
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
.
WithContext
(
ctx
)
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
cancelReadCloser
{},
Header
:
http
.
Header
{},
}
_
,
err
:=
svc
.
handleStreamingResponse
(
c
.
Request
.
Context
(),
resp
,
c
,
&
Account
{
ID
:
1
},
time
.
Now
(),
"model"
,
"model"
)
if
err
!=
nil
{
t
.
Fatalf
(
"expected nil error, got %v"
,
err
)
}
if
strings
.
Contains
(
rec
.
Body
.
String
(),
"event: error"
)
||
strings
.
Contains
(
rec
.
Body
.
String
(),
"stream_read_error"
)
{
t
.
Fatalf
(
"expected no injected SSE error event, got %q"
,
rec
.
Body
.
String
())
}
}
func
TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
cfg
:=
&
config
.
Config
{
Gateway
:
config
.
GatewayConfig
{
StreamDataIntervalTimeout
:
0
,
StreamKeepaliveInterval
:
0
,
MaxLineSize
:
defaultMaxLineSize
,
},
}
svc
:=
&
OpenAIGatewayService
{
cfg
:
cfg
}
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/"
,
nil
)
c
.
Writer
=
&
failingGinWriter
{
ResponseWriter
:
c
.
Writer
,
failAfter
:
0
}
pr
,
pw
:=
io
.
Pipe
()
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusOK
,
Body
:
pr
,
Header
:
http
.
Header
{},
}
go
func
()
{
defer
func
()
{
_
=
pw
.
Close
()
}()
_
,
_
=
pw
.
Write
([]
byte
(
"data: {
\"
type
\"
:
\"
response.in_progress
\"
,
\"
response
\"
:{}}
\n\n
"
))
_
,
_
=
pw
.
Write
([]
byte
(
"data: {
\"
type
\"
:
\"
response.completed
\"
,
\"
response
\"
:{
\"
usage
\"
:{
\"
input_tokens
\"
:3,
\"
output_tokens
\"
:5,
\"
input_tokens_details
\"
:{
\"
cached_tokens
\"
:1}}}}
\n\n
"
))
}()
result
,
err
:=
svc
.
handleStreamingResponse
(
c
.
Request
.
Context
(),
resp
,
c
,
&
Account
{
ID
:
1
},
time
.
Now
(),
"model"
,
"model"
)
_
=
pr
.
Close
()
if
err
!=
nil
{
t
.
Fatalf
(
"expected nil error, got %v"
,
err
)
}
if
result
==
nil
||
result
.
usage
==
nil
{
t
.
Fatalf
(
"expected usage result"
)
}
if
result
.
usage
.
InputTokens
!=
3
||
result
.
usage
.
OutputTokens
!=
5
||
result
.
usage
.
CacheReadInputTokens
!=
1
{
t
.
Fatalf
(
"unexpected usage: %+v"
,
*
result
.
usage
)
}
if
strings
.
Contains
(
rec
.
Body
.
String
(),
"event: error"
)
||
strings
.
Contains
(
rec
.
Body
.
String
(),
"write_failed"
)
{
t
.
Fatalf
(
"expected no injected SSE error event, got %q"
,
rec
.
Body
.
String
())
}
}
}
}
...
@@ -209,8 +305,8 @@ func TestOpenAIStreamingTooLong(t *testing.T) {
...
@@ -209,8 +305,8 @@ func TestOpenAIStreamingTooLong(t *testing.T) {
if
!
errors
.
Is
(
err
,
bufio
.
ErrTooLong
)
{
if
!
errors
.
Is
(
err
,
bufio
.
ErrTooLong
)
{
t
.
Fatalf
(
"expected ErrTooLong, got %v"
,
err
)
t
.
Fatalf
(
"expected ErrTooLong, got %v"
,
err
)
}
}
if
!
strings
.
Contains
(
rec
.
Body
.
String
(),
"response_too_large"
)
{
if
!
strings
.
Contains
(
rec
.
Body
.
String
(),
"
\"
type
\"
:
\"
error
\"
"
)
||
!
strings
.
Contains
(
rec
.
Body
.
String
(),
"response_too_large"
)
{
t
.
Fatalf
(
"expected
response_too_large
SSE e
rror
, got %q"
,
rec
.
Body
.
String
())
t
.
Fatalf
(
"expected
OpenAI-compatible error
SSE e
vent
, got %q"
,
rec
.
Body
.
String
())
}
}
}
}
...
...
backend/internal/service/ops_retry.go
View file @
dd7f2124
...
@@ -514,7 +514,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry
...
@@ -514,7 +514,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry
if
s
.
gatewayService
==
nil
{
if
s
.
gatewayService
==
nil
{
return
nil
,
fmt
.
Errorf
(
"gateway service not available"
)
return
nil
,
fmt
.
Errorf
(
"gateway service not available"
)
}
}
return
s
.
gatewayService
.
SelectAccountWithLoadAwareness
(
ctx
,
groupID
,
""
,
model
,
excludedIDs
)
return
s
.
gatewayService
.
SelectAccountWithLoadAwareness
(
ctx
,
groupID
,
""
,
model
,
excludedIDs
,
""
)
// 重试不使用会话限制
default
:
default
:
return
nil
,
fmt
.
Errorf
(
"unsupported retry type: %s"
,
reqType
)
return
nil
,
fmt
.
Errorf
(
"unsupported retry type: %s"
,
reqType
)
}
}
...
...
backend/internal/service/session_limit_cache.go
0 → 100644
View file @
dd7f2124
package
service
import
(
"context"
"time"
)
// SessionLimitCache manages account-level tracking of active sessions.
// It is used to limit the number of sessions for Anthropic OAuth/SetupToken
// accounts.
//
// Key format: session_limit:account:{accountID}
// Data structure: Sorted Set (member=sessionUUID, score=timestamp)
//
// Sessions expire automatically after the idle timeout; no manual cleanup is
// required.
type SessionLimitCache interface {
	// RegisterSession registers session activity.
	//   - If the session already exists, its timestamp is refreshed and true is returned.
	//   - If the session does not exist and the active session count is < maxSessions,
	//     the new session is added and true is returned.
	//   - If the session does not exist and the active session count is >= maxSessions,
	//     false is returned (rejected).
	//
	// Parameters:
	//   accountID: account ID
	//   sessionUUID: session UUID extracted from metadata.user_id
	//   maxSessions: maximum number of concurrent sessions
	//   idleTimeout: session idle timeout
	//
	// Returns:
	//   allowed: true means allowed (within the limit, or the session already exists);
	//            false means rejected (limit exceeded and the session is new)
	//   error: operation error
	RegisterSession(ctx context.Context, accountID int64, sessionUUID string, maxSessions int, idleTimeout time.Duration) (allowed bool, err error)

	// RefreshSession refreshes the timestamp of an existing session.
	// It is used to keep an active session alive.
	RefreshSession(ctx context.Context, accountID int64, sessionUUID string, idleTimeout time.Duration) error

	// GetActiveSessionCount returns the current number of active sessions,
	// i.e. the number of sessions that have not expired.
	GetActiveSessionCount(ctx context.Context, accountID int64) (int, error)

	// GetActiveSessionCountBatch returns the active session counts for multiple
	// accounts in one call.
	// The result is map[accountID]count; accounts whose lookup failed are absent
	// from the map.
	GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)

	// IsSessionActive reports whether a specific session is active (not expired).
	IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error)

	// ========== 5h window cost cache ==========
	// Key format: window_cost:account:{accountID}
	// Caches the account's standard cost within the current 5h window to reduce
	// the load of database aggregation queries.

	// GetWindowCost returns the cached window cost.
	// Returns (cost, true, nil) on a cache hit.
	// Returns (0, false, nil) on a cache miss.
	// Returns (0, false, err) if an error occurred.
	GetWindowCost(ctx context.Context, accountID int64) (cost float64, hit bool, err error)

	// SetWindowCost stores the window cost in the cache.
	SetWindowCost(ctx context.Context, accountID int64, cost float64) error

	// GetWindowCostBatch returns the cached window costs for multiple accounts.
	// The result is map[accountID]cost; accounts with a cache miss are absent
	// from the map.
	GetWindowCostBatch(ctx context.Context, accountIDs []int64) (map[int64]float64, error)
}
deploy/Caddyfile
View file @
dd7f2124
...
@@ -33,6 +33,22 @@
...
@@ -33,6 +33,22 @@
# 修改为你的域名
# 修改为你的域名
example.com {
example.com {
# =========================================================================
# 静态资源长期缓存(高优先级,放在最前面)
# 带 hash 的文件可以永久缓存,浏览器和 CDN 都会缓存
# =========================================================================
@static {
path /assets/*
path /logo.png
path /favicon.ico
}
header @static {
Cache-Control "public, max-age=31536000, immutable"
# 移除可能干扰缓存的头
-Pragma
-Expires
}
# =========================================================================
# =========================================================================
# TLS 安全配置
# TLS 安全配置
# =========================================================================
# =========================================================================
...
...
frontend/package-lock.json
View file @
dd7f2124
This diff is collapsed.
Click to expand it.
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment