"frontend/src/git@web.lueluesay.top:chenxi/sub2api.git" did not exist on "c5781c69bb8490e8a55a078df73dfd8865307acf"
Commit dd7f2124 authored by cyhhao's avatar cyhhao
Browse files

merge: resolve conflicts with main

parents 49be9d08 bba5b3c0
...@@ -118,7 +118,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -118,7 +118,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService) sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache)
oAuthHandler := admin.NewOAuthHandler(oAuthService) oAuthHandler := admin.NewOAuthHandler(oAuthService)
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService) openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService) geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
...@@ -140,7 +141,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -140,7 +141,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
identityService := service.NewIdentityService(identityCache) identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService) claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider) gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache)
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService) openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
......
...@@ -234,6 +234,10 @@ type GatewayConfig struct { ...@@ -234,6 +234,10 @@ type GatewayConfig struct {
// ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟) // ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟)
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期 // 应大于最长 LLM 请求时间,防止请求完成前槽位过期
ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"` ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"`
// SessionIdleTimeoutMinutes: 会话空闲超时时间(分钟),默认 5 分钟
// 用于 Anthropic OAuth/SetupToken 账号的会话数量限制功能
// 空闲超过此时间的会话将被自动释放
SessionIdleTimeoutMinutes int `mapstructure:"session_idle_timeout_minutes"`
// StreamDataIntervalTimeout: 流数据间隔超时(秒),0表示禁用 // StreamDataIntervalTimeout: 流数据间隔超时(秒),0表示禁用
StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"` StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"`
......
...@@ -44,6 +44,7 @@ type AccountHandler struct { ...@@ -44,6 +44,7 @@ type AccountHandler struct {
accountTestService *service.AccountTestService accountTestService *service.AccountTestService
concurrencyService *service.ConcurrencyService concurrencyService *service.ConcurrencyService
crsSyncService *service.CRSSyncService crsSyncService *service.CRSSyncService
sessionLimitCache service.SessionLimitCache
} }
// NewAccountHandler creates a new admin account handler // NewAccountHandler creates a new admin account handler
...@@ -58,6 +59,7 @@ func NewAccountHandler( ...@@ -58,6 +59,7 @@ func NewAccountHandler(
accountTestService *service.AccountTestService, accountTestService *service.AccountTestService,
concurrencyService *service.ConcurrencyService, concurrencyService *service.ConcurrencyService,
crsSyncService *service.CRSSyncService, crsSyncService *service.CRSSyncService,
sessionLimitCache service.SessionLimitCache,
) *AccountHandler { ) *AccountHandler {
return &AccountHandler{ return &AccountHandler{
adminService: adminService, adminService: adminService,
...@@ -70,6 +72,7 @@ func NewAccountHandler( ...@@ -70,6 +72,7 @@ func NewAccountHandler(
accountTestService: accountTestService, accountTestService: accountTestService,
concurrencyService: concurrencyService, concurrencyService: concurrencyService,
crsSyncService: crsSyncService, crsSyncService: crsSyncService,
sessionLimitCache: sessionLimitCache,
} }
} }
...@@ -130,6 +133,9 @@ type BulkUpdateAccountsRequest struct { ...@@ -130,6 +133,9 @@ type BulkUpdateAccountsRequest struct {
type AccountWithConcurrency struct { type AccountWithConcurrency struct {
*dto.Account *dto.Account
CurrentConcurrency int `json:"current_concurrency"` CurrentConcurrency int `json:"current_concurrency"`
// 以下字段仅对 Anthropic OAuth/SetupToken 账号有效,且仅在启用相应功能时返回
CurrentWindowCost *float64 `json:"current_window_cost,omitempty"` // 当前窗口费用
ActiveSessions *int `json:"active_sessions,omitempty"` // 当前活跃会话数
} }
// List handles listing all accounts with pagination // List handles listing all accounts with pagination
...@@ -164,13 +170,89 @@ func (h *AccountHandler) List(c *gin.Context) { ...@@ -164,13 +170,89 @@ func (h *AccountHandler) List(c *gin.Context) {
concurrencyCounts = make(map[int64]int) concurrencyCounts = make(map[int64]int)
} }
// 识别需要查询窗口费用和会话数的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
windowCostAccountIDs := make([]int64, 0)
sessionLimitAccountIDs := make([]int64, 0)
for i := range accounts {
acc := &accounts[i]
if acc.IsAnthropicOAuthOrSetupToken() {
if acc.GetWindowCostLimit() > 0 {
windowCostAccountIDs = append(windowCostAccountIDs, acc.ID)
}
if acc.GetMaxSessions() > 0 {
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
}
}
}
// 并行获取窗口费用和活跃会话数
var windowCosts map[int64]float64
var activeSessions map[int64]int
// 获取活跃会话数(批量查询)
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs)
if activeSessions == nil {
activeSessions = make(map[int64]int)
}
}
// 获取窗口费用(并行查询)
if len(windowCostAccountIDs) > 0 {
windowCosts = make(map[int64]float64)
var mu sync.Mutex
g, gctx := errgroup.WithContext(c.Request.Context())
g.SetLimit(10) // 限制并发数
for i := range accounts {
acc := &accounts[i]
if !acc.IsAnthropicOAuthOrSetupToken() || acc.GetWindowCostLimit() <= 0 {
continue
}
accCopy := acc // 闭包捕获
g.Go(func() error {
var startTime time.Time
if accCopy.SessionWindowStart != nil {
startTime = *accCopy.SessionWindowStart
} else {
startTime = time.Now().Add(-5 * time.Hour)
}
stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime)
if err == nil && stats != nil {
mu.Lock()
windowCosts[accCopy.ID] = stats.StandardCost // 使用标准费用
mu.Unlock()
}
return nil // 不返回错误,允许部分失败
})
}
_ = g.Wait()
}
// Build response with concurrency info // Build response with concurrency info
result := make([]AccountWithConcurrency, len(accounts)) result := make([]AccountWithConcurrency, len(accounts))
for i := range accounts { for i := range accounts {
result[i] = AccountWithConcurrency{ acc := &accounts[i]
Account: dto.AccountFromService(&accounts[i]), item := AccountWithConcurrency{
CurrentConcurrency: concurrencyCounts[accounts[i].ID], Account: dto.AccountFromService(acc),
CurrentConcurrency: concurrencyCounts[acc.ID],
} }
// 添加窗口费用(仅当启用时)
if windowCosts != nil {
if cost, ok := windowCosts[acc.ID]; ok {
item.CurrentWindowCost = &cost
}
}
// 添加活跃会话数(仅当启用时)
if activeSessions != nil {
if count, ok := activeSessions[acc.ID]; ok {
item.ActiveSessions = &count
}
}
result[i] = item
} }
response.Paginated(c, result, total, page, pageSize) response.Paginated(c, result, total, page, pageSize)
......
...@@ -116,7 +116,7 @@ func AccountFromServiceShallow(a *service.Account) *Account { ...@@ -116,7 +116,7 @@ func AccountFromServiceShallow(a *service.Account) *Account {
if a == nil { if a == nil {
return nil return nil
} }
return &Account{ out := &Account{
ID: a.ID, ID: a.ID,
Name: a.Name, Name: a.Name,
Notes: a.Notes, Notes: a.Notes,
...@@ -146,6 +146,24 @@ func AccountFromServiceShallow(a *service.Account) *Account { ...@@ -146,6 +146,24 @@ func AccountFromServiceShallow(a *service.Account) *Account {
SessionWindowStatus: a.SessionWindowStatus, SessionWindowStatus: a.SessionWindowStatus,
GroupIDs: a.GroupIDs, GroupIDs: a.GroupIDs,
} }
// 提取 5h 窗口费用控制和会话数量控制配置(仅 Anthropic OAuth/SetupToken 账号有效)
if a.IsAnthropicOAuthOrSetupToken() {
if limit := a.GetWindowCostLimit(); limit > 0 {
out.WindowCostLimit = &limit
}
if reserve := a.GetWindowCostStickyReserve(); reserve > 0 {
out.WindowCostStickyReserve = &reserve
}
if maxSessions := a.GetMaxSessions(); maxSessions > 0 {
out.MaxSessions = &maxSessions
}
if idleTimeout := a.GetSessionIdleTimeoutMinutes(); idleTimeout > 0 {
out.SessionIdleTimeoutMin = &idleTimeout
}
}
return out
} }
func AccountFromService(a *service.Account) *Account { func AccountFromService(a *service.Account) *Account {
......
...@@ -102,6 +102,16 @@ type Account struct { ...@@ -102,6 +102,16 @@ type Account struct {
SessionWindowEnd *time.Time `json:"session_window_end"` SessionWindowEnd *time.Time `json:"session_window_end"`
SessionWindowStatus string `json:"session_window_status"` SessionWindowStatus string `json:"session_window_status"`
// 5h窗口费用控制(仅 Anthropic OAuth/SetupToken 账号有效)
// 从 extra 字段提取,方便前端显示和编辑
WindowCostLimit *float64 `json:"window_cost_limit,omitempty"`
WindowCostStickyReserve *float64 `json:"window_cost_sticky_reserve,omitempty"`
// 会话数量控制(仅 Anthropic OAuth/SetupToken 账号有效)
// 从 extra 字段提取,方便前端显示和编辑
MaxSessions *int `json:"max_sessions,omitempty"`
SessionIdleTimeoutMin *int `json:"session_idle_timeout_minutes,omitempty"`
Proxy *Proxy `json:"proxy,omitempty"` Proxy *Proxy `json:"proxy,omitempty"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"` AccountGroups []AccountGroup `json:"account_groups,omitempty"`
......
...@@ -185,7 +185,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -185,7 +185,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
lastFailoverStatus := 0 lastFailoverStatus := 0
for { for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs) selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
if err != nil { if err != nil {
if len(failedAccountIDs) == 0 { if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
...@@ -320,7 +320,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -320,7 +320,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for { for {
// 选择支持该模型的账号 // 选择支持该模型的账号
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs) selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
if err != nil { if err != nil {
if len(failedAccountIDs) == 0 { if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
......
...@@ -226,7 +226,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -226,7 +226,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
lastFailoverStatus := 0 lastFailoverStatus := 0
for { for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs) selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制
if err != nil { if err != nil {
if len(failedAccountIDs) == 0 { if len(failedAccountIDs) == 0 {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
......
package repository
import (
"context"
"fmt"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// 会话限制缓存常量定义
//
// 设计说明:
// 使用 Redis 有序集合(Sorted Set)跟踪每个账号的活跃会话:
// - Key: session_limit:account:{accountID}
// - Member: sessionUUID(从 metadata.user_id 中提取)
// - Score: Unix 时间戳(会话最后活跃时间)
//
// 通过 ZREMRANGEBYSCORE 自动清理过期会话,无需手动管理 TTL
const (
// 会话限制键前缀
// 格式: session_limit:account:{accountID}
sessionLimitKeyPrefix = "session_limit:account:"
// 窗口费用缓存键前缀
// 格式: window_cost:account:{accountID}
windowCostKeyPrefix = "window_cost:account:"
// 窗口费用缓存 TTL(30秒)
windowCostCacheTTL = 30 * time.Second
)
var (
// registerSessionScript 注册会话活动
// 使用 Redis TIME 命令获取服务器时间,避免多实例时钟不同步
// KEYS[1] = session_limit:account:{accountID}
// ARGV[1] = maxSessions
// ARGV[2] = idleTimeout(秒)
// ARGV[3] = sessionUUID
// 返回: 1 = 允许, 0 = 拒绝
registerSessionScript = redis.NewScript(`
local key = KEYS[1]
local maxSessions = tonumber(ARGV[1])
local idleTimeout = tonumber(ARGV[2])
local sessionUUID = ARGV[3]
-- 使用 Redis 服务器时间,确保多实例时钟一致
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - idleTimeout
-- 清理过期会话
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
-- 检查会话是否已存在(支持刷新时间戳)
local exists = redis.call('ZSCORE', key, sessionUUID)
if exists ~= false then
-- 会话已存在,刷新时间戳
redis.call('ZADD', key, now, sessionUUID)
redis.call('EXPIRE', key, idleTimeout + 60)
return 1
end
-- 检查是否达到会话数量上限
local count = redis.call('ZCARD', key)
if count < maxSessions then
-- 未达上限,添加新会话
redis.call('ZADD', key, now, sessionUUID)
redis.call('EXPIRE', key, idleTimeout + 60)
return 1
end
-- 达到上限,拒绝新会话
return 0
`)
// refreshSessionScript 刷新会话时间戳
// KEYS[1] = session_limit:account:{accountID}
// ARGV[1] = idleTimeout(秒)
// ARGV[2] = sessionUUID
refreshSessionScript = redis.NewScript(`
local key = KEYS[1]
local idleTimeout = tonumber(ARGV[1])
local sessionUUID = ARGV[2]
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
-- 检查会话是否存在
local exists = redis.call('ZSCORE', key, sessionUUID)
if exists ~= false then
redis.call('ZADD', key, now, sessionUUID)
redis.call('EXPIRE', key, idleTimeout + 60)
end
return 1
`)
// getActiveSessionCountScript 获取活跃会话数
// KEYS[1] = session_limit:account:{accountID}
// ARGV[1] = idleTimeout(秒)
getActiveSessionCountScript = redis.NewScript(`
local key = KEYS[1]
local idleTimeout = tonumber(ARGV[1])
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - idleTimeout
-- 清理过期会话
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
return redis.call('ZCARD', key)
`)
// isSessionActiveScript 检查会话是否活跃
// KEYS[1] = session_limit:account:{accountID}
// ARGV[1] = idleTimeout(秒)
// ARGV[2] = sessionUUID
isSessionActiveScript = redis.NewScript(`
local key = KEYS[1]
local idleTimeout = tonumber(ARGV[1])
local sessionUUID = ARGV[2]
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - idleTimeout
-- 获取会话的时间戳
local score = redis.call('ZSCORE', key, sessionUUID)
if score == false then
return 0
end
-- 检查是否过期
if tonumber(score) <= expireBefore then
return 0
end
return 1
`)
)
type sessionLimitCache struct {
rdb *redis.Client
defaultIdleTimeout time.Duration // 默认空闲超时(用于 GetActiveSessionCount)
}
// NewSessionLimitCache 创建会话限制缓存
// defaultIdleTimeoutMinutes: 默认空闲超时时间(分钟),用于无参数查询
func NewSessionLimitCache(rdb *redis.Client, defaultIdleTimeoutMinutes int) service.SessionLimitCache {
if defaultIdleTimeoutMinutes <= 0 {
defaultIdleTimeoutMinutes = 5 // 默认 5 分钟
}
return &sessionLimitCache{
rdb: rdb,
defaultIdleTimeout: time.Duration(defaultIdleTimeoutMinutes) * time.Minute,
}
}
// sessionLimitKey 生成会话限制的 Redis 键
func sessionLimitKey(accountID int64) string {
return fmt.Sprintf("%s%d", sessionLimitKeyPrefix, accountID)
}
// windowCostKey 生成窗口费用缓存的 Redis 键
func windowCostKey(accountID int64) string {
return fmt.Sprintf("%s%d", windowCostKeyPrefix, accountID)
}
// RegisterSession 注册会话活动
func (c *sessionLimitCache) RegisterSession(ctx context.Context, accountID int64, sessionUUID string, maxSessions int, idleTimeout time.Duration) (bool, error) {
if sessionUUID == "" || maxSessions <= 0 {
return true, nil // 无效参数,默认允许
}
key := sessionLimitKey(accountID)
idleTimeoutSeconds := int(idleTimeout.Seconds())
if idleTimeoutSeconds <= 0 {
idleTimeoutSeconds = int(c.defaultIdleTimeout.Seconds())
}
result, err := registerSessionScript.Run(ctx, c.rdb, []string{key}, maxSessions, idleTimeoutSeconds, sessionUUID).Int()
if err != nil {
return true, err // 失败开放:缓存错误时允许请求通过
}
return result == 1, nil
}
// RefreshSession 刷新会话时间戳
func (c *sessionLimitCache) RefreshSession(ctx context.Context, accountID int64, sessionUUID string, idleTimeout time.Duration) error {
if sessionUUID == "" {
return nil
}
key := sessionLimitKey(accountID)
idleTimeoutSeconds := int(idleTimeout.Seconds())
if idleTimeoutSeconds <= 0 {
idleTimeoutSeconds = int(c.defaultIdleTimeout.Seconds())
}
_, err := refreshSessionScript.Run(ctx, c.rdb, []string{key}, idleTimeoutSeconds, sessionUUID).Result()
return err
}
// GetActiveSessionCount 获取活跃会话数
func (c *sessionLimitCache) GetActiveSessionCount(ctx context.Context, accountID int64) (int, error) {
key := sessionLimitKey(accountID)
idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
result, err := getActiveSessionCountScript.Run(ctx, c.rdb, []string{key}, idleTimeoutSeconds).Int()
if err != nil {
return 0, err
}
return result, nil
}
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
if len(accountIDs) == 0 {
return make(map[int64]int), nil
}
results := make(map[int64]int, len(accountIDs))
// 使用 pipeline 批量执行
pipe := c.rdb.Pipeline()
idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
cmds := make(map[int64]*redis.Cmd, len(accountIDs))
for _, accountID := range accountIDs {
key := sessionLimitKey(accountID)
cmds[accountID] = getActiveSessionCountScript.Run(ctx, pipe, []string{key}, idleTimeoutSeconds)
}
// 执行 pipeline,即使部分失败也尝试获取成功的结果
_, _ = pipe.Exec(ctx)
for accountID, cmd := range cmds {
if result, err := cmd.Int(); err == nil {
results[accountID] = result
}
}
return results, nil
}
// IsSessionActive 检查会话是否活跃
func (c *sessionLimitCache) IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error) {
if sessionUUID == "" {
return false, nil
}
key := sessionLimitKey(accountID)
idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
result, err := isSessionActiveScript.Run(ctx, c.rdb, []string{key}, idleTimeoutSeconds, sessionUUID).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
// ========== 5h窗口费用缓存实现 ==========
// GetWindowCost 获取缓存的窗口费用
func (c *sessionLimitCache) GetWindowCost(ctx context.Context, accountID int64) (float64, bool, error) {
key := windowCostKey(accountID)
val, err := c.rdb.Get(ctx, key).Float64()
if err == redis.Nil {
return 0, false, nil // 缓存未命中
}
if err != nil {
return 0, false, err
}
return val, true, nil
}
// SetWindowCost 设置窗口费用缓存
func (c *sessionLimitCache) SetWindowCost(ctx context.Context, accountID int64, cost float64) error {
key := windowCostKey(accountID)
return c.rdb.Set(ctx, key, cost, windowCostCacheTTL).Err()
}
// GetWindowCostBatch 批量获取窗口费用缓存
func (c *sessionLimitCache) GetWindowCostBatch(ctx context.Context, accountIDs []int64) (map[int64]float64, error) {
if len(accountIDs) == 0 {
return make(map[int64]float64), nil
}
// 构建批量查询的 keys
keys := make([]string, len(accountIDs))
for i, accountID := range accountIDs {
keys[i] = windowCostKey(accountID)
}
// 使用 MGET 批量获取
vals, err := c.rdb.MGet(ctx, keys...).Result()
if err != nil {
return nil, err
}
results := make(map[int64]float64, len(accountIDs))
for i, val := range vals {
if val == nil {
continue // 缓存未命中
}
// 尝试解析为 float64
switch v := val.(type) {
case string:
if cost, err := strconv.ParseFloat(v, 64); err == nil {
results[accountIDs[i]] = cost
}
case float64:
results[accountIDs[i]] = v
}
}
return results, nil
}
...@@ -37,6 +37,16 @@ func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient ...@@ -37,6 +37,16 @@ func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient
return NewPricingRemoteClient(cfg.Update.ProxyURL) return NewPricingRemoteClient(cfg.Update.ProxyURL)
} }
// ProvideSessionLimitCache 创建会话限制缓存
// 用于 Anthropic OAuth/SetupToken 账号的并发会话数量控制
func ProvideSessionLimitCache(rdb *redis.Client, cfg *config.Config) service.SessionLimitCache {
defaultIdleTimeoutMinutes := 5 // 默认 5 分钟空闲超时
if cfg != nil && cfg.Gateway.SessionIdleTimeoutMinutes > 0 {
defaultIdleTimeoutMinutes = cfg.Gateway.SessionIdleTimeoutMinutes
}
return NewSessionLimitCache(rdb, defaultIdleTimeoutMinutes)
}
// ProviderSet is the Wire provider set for all repositories // ProviderSet is the Wire provider set for all repositories
var ProviderSet = wire.NewSet( var ProviderSet = wire.NewSet(
NewUserRepository, NewUserRepository,
...@@ -61,6 +71,7 @@ var ProviderSet = wire.NewSet( ...@@ -61,6 +71,7 @@ var ProviderSet = wire.NewSet(
NewTempUnschedCache, NewTempUnschedCache,
NewTimeoutCounterCache, NewTimeoutCounterCache,
ProvideConcurrencyCache, ProvideConcurrencyCache,
ProvideSessionLimitCache,
NewDashboardCache, NewDashboardCache,
NewEmailCache, NewEmailCache,
NewIdentityCache, NewIdentityCache,
......
...@@ -441,7 +441,7 @@ func newContractDeps(t *testing.T) *contractDeps { ...@@ -441,7 +441,7 @@ func newContractDeps(t *testing.T) *contractDeps {
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil) adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
jwtAuth := func(c *gin.Context) { jwtAuth := func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
......
...@@ -573,3 +573,141 @@ func (a *Account) IsMixedSchedulingEnabled() bool { ...@@ -573,3 +573,141 @@ func (a *Account) IsMixedSchedulingEnabled() bool {
} }
return false return false
} }
// WindowCostSchedulability 窗口费用调度状态
type WindowCostSchedulability int
const (
// WindowCostSchedulable 可正常调度
WindowCostSchedulable WindowCostSchedulability = iota
// WindowCostStickyOnly 仅允许粘性会话
WindowCostStickyOnly
// WindowCostNotSchedulable 完全不可调度
WindowCostNotSchedulable
)
// IsAnthropicOAuthOrSetupToken 判断是否为 Anthropic OAuth 或 SetupToken 类型账号
// 仅这两类账号支持 5h 窗口额度控制和会话数量控制
func (a *Account) IsAnthropicOAuthOrSetupToken() bool {
return a.Platform == PlatformAnthropic && (a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken)
}
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
// 返回 0 表示未启用
func (a *Account) GetWindowCostLimit() float64 {
if a.Extra == nil {
return 0
}
if v, ok := a.Extra["window_cost_limit"]; ok {
return parseExtraFloat64(v)
}
return 0
}
// GetWindowCostStickyReserve 获取粘性会话预留额度(美元)
// 默认值为 10
func (a *Account) GetWindowCostStickyReserve() float64 {
if a.Extra == nil {
return 10.0
}
if v, ok := a.Extra["window_cost_sticky_reserve"]; ok {
val := parseExtraFloat64(v)
if val > 0 {
return val
}
}
return 10.0
}
// GetMaxSessions 获取最大并发会话数
// 返回 0 表示未启用
func (a *Account) GetMaxSessions() int {
if a.Extra == nil {
return 0
}
if v, ok := a.Extra["max_sessions"]; ok {
return parseExtraInt(v)
}
return 0
}
// GetSessionIdleTimeoutMinutes 获取会话空闲超时分钟数
// 默认值为 5 分钟
func (a *Account) GetSessionIdleTimeoutMinutes() int {
if a.Extra == nil {
return 5
}
if v, ok := a.Extra["session_idle_timeout_minutes"]; ok {
val := parseExtraInt(v)
if val > 0 {
return val
}
}
return 5
}
// CheckWindowCostSchedulability 根据当前窗口费用检查调度状态
// - 费用 < 阈值: WindowCostSchedulable(可正常调度)
// - 费用 >= 阈值 且 < 阈值+预留: WindowCostStickyOnly(仅粘性会话)
// - 费用 >= 阈值+预留: WindowCostNotSchedulable(不可调度)
func (a *Account) CheckWindowCostSchedulability(currentWindowCost float64) WindowCostSchedulability {
limit := a.GetWindowCostLimit()
if limit <= 0 {
return WindowCostSchedulable
}
if currentWindowCost < limit {
return WindowCostSchedulable
}
stickyReserve := a.GetWindowCostStickyReserve()
if currentWindowCost < limit+stickyReserve {
return WindowCostStickyOnly
}
return WindowCostNotSchedulable
}
// parseExtraFloat64 从 extra 字段解析 float64 值
func parseExtraFloat64(value any) float64 {
switch v := value.(type) {
case float64:
return v
case float32:
return float64(v)
case int:
return float64(v)
case int64:
return float64(v)
case json.Number:
if f, err := v.Float64(); err == nil {
return f
}
case string:
if f, err := strconv.ParseFloat(strings.TrimSpace(v), 64); err == nil {
return f
}
}
return 0
}
// parseExtraInt 从 extra 字段解析 int 值
func parseExtraInt(value any) int {
switch v := value.(type) {
case int:
return v
case int64:
return int(v)
case float64:
return int(v)
case json.Number:
if i, err := v.Int64(); err == nil {
return int(i)
}
case string:
if i, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
return i
}
}
return 0
}
...@@ -575,3 +575,9 @@ func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64 ...@@ -575,3 +575,9 @@ func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64
}, },
} }
} }
// GetAccountWindowStats 获取账号在指定时间窗口内的使用统计
// 用于账号列表页面显示当前窗口费用
func (s *AccountUsageService) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
return s.usageLogRepo.GetAccountWindowStats(ctx, accountID, startTime)
}
...@@ -1052,7 +1052,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { ...@@ -1052,7 +1052,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil, // No concurrency service concurrencyService: nil, // No concurrency service
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil) result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
...@@ -1105,7 +1105,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { ...@@ -1105,7 +1105,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil, // legacy path concurrencyService: nil, // legacy path
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil) result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil, "")
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
...@@ -1137,7 +1137,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { ...@@ -1137,7 +1137,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil, concurrencyService: nil,
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil) result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
...@@ -1169,7 +1169,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { ...@@ -1169,7 +1169,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
} }
excludedIDs := map[int64]struct{}{1: {}} excludedIDs := map[int64]struct{}{1: {}}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs) result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs, "")
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
...@@ -1203,7 +1203,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { ...@@ -1203,7 +1203,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache), concurrencyService: NewConcurrencyService(concurrencyCache),
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil) result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
...@@ -1239,7 +1239,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { ...@@ -1239,7 +1239,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache), concurrencyService: NewConcurrencyService(concurrencyCache),
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil) result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
...@@ -1266,7 +1266,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { ...@@ -1266,7 +1266,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil, concurrencyService: nil,
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil) result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.Error(t, err) require.Error(t, err)
require.Nil(t, result) require.Nil(t, result)
require.Contains(t, err.Error(), "no available accounts") require.Contains(t, err.Error(), "no available accounts")
...@@ -1298,7 +1298,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { ...@@ -1298,7 +1298,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil, concurrencyService: nil,
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil) result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
...@@ -1331,7 +1331,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { ...@@ -1331,7 +1331,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil, concurrencyService: nil,
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil) result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
......
...@@ -213,6 +213,7 @@ type GatewayService struct { ...@@ -213,6 +213,7 @@ type GatewayService struct {
deferredService *DeferredService deferredService *DeferredService
concurrencyService *ConcurrencyService concurrencyService *ConcurrencyService
claudeTokenProvider *ClaudeTokenProvider claudeTokenProvider *ClaudeTokenProvider
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
} }
// NewGatewayService creates a new GatewayService // NewGatewayService creates a new GatewayService
...@@ -233,6 +234,7 @@ func NewGatewayService( ...@@ -233,6 +234,7 @@ func NewGatewayService(
httpUpstream HTTPUpstream, httpUpstream HTTPUpstream,
deferredService *DeferredService, deferredService *DeferredService,
claudeTokenProvider *ClaudeTokenProvider, claudeTokenProvider *ClaudeTokenProvider,
sessionLimitCache SessionLimitCache,
) *GatewayService { ) *GatewayService {
return &GatewayService{ return &GatewayService{
accountRepo: accountRepo, accountRepo: accountRepo,
...@@ -251,6 +253,7 @@ func NewGatewayService( ...@@ -251,6 +253,7 @@ func NewGatewayService(
httpUpstream: httpUpstream, httpUpstream: httpUpstream,
deferredService: deferredService, deferredService: deferredService,
claudeTokenProvider: claudeTokenProvider, claudeTokenProvider: claudeTokenProvider,
sessionLimitCache: sessionLimitCache,
} }
} }
...@@ -816,8 +819,12 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context ...@@ -816,8 +819,12 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
} }
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan. // SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) { // metadataUserID: 原始 metadata.user_id 字段(用于提取会话 UUID 进行会话数量限制)
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) {
cfg := s.schedulingConfig() cfg := s.schedulingConfig()
// 提取会话 UUID(用于会话数量限制)
sessionUUID := extractSessionUUID(metadataUserID)
var stickyAccountID int64 var stickyAccountID int64
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil { if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil {
...@@ -936,7 +943,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -936,7 +943,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if len(routingAccountIDs) > 0 && s.concurrencyService != nil { if len(routingAccountIDs) > 0 && s.concurrencyService != nil {
// 1. 过滤出路由列表中可调度的账号 // 1. 过滤出路由列表中可调度的账号
var routingCandidates []*Account var routingCandidates []*Account
var filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping int var filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost int
for _, routingAccountID := range routingAccountIDs { for _, routingAccountID := range routingAccountIDs {
if isExcluded(routingAccountID) { if isExcluded(routingAccountID) {
filteredExcluded++ filteredExcluded++
...@@ -963,13 +970,18 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -963,13 +970,18 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
filteredModelMapping++ filteredModelMapping++
continue continue
} }
// 窗口费用检查(非粘性会话路径)
if !s.isAccountSchedulableForWindowCost(ctx, account, false) {
filteredWindowCost++
continue
}
routingCandidates = append(routingCandidates, account) routingCandidates = append(routingCandidates, account)
} }
if s.debugModelRoutingEnabled() { if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d)", log.Printf("[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d window_cost=%d)",
derefGroupID(groupID), requestedModel, len(routingAccountIDs), len(routingCandidates), derefGroupID(groupID), requestedModel, len(routingAccountIDs), len(routingCandidates),
filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping) filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost)
} }
if len(routingCandidates) > 0 { if len(routingCandidates) > 0 {
...@@ -982,18 +994,25 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -982,18 +994,25 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if stickyAccount.IsSchedulable() && if stickyAccount.IsSchedulable() &&
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) && s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
stickyAccount.IsSchedulableForModel(requestedModel) && stickyAccount.IsSchedulableForModel(requestedModel) &&
(requestedModel == "" || s.isModelSupportedByAccount(stickyAccount, requestedModel)) { (requestedModel == "" || s.isModelSupportedByAccount(stickyAccount, requestedModel)) &&
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) { // 粘性会话窗口费用检查
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) // 会话数量限制检查
if s.debugModelRoutingEnabled() { if !s.checkAndRegisterSession(ctx, stickyAccount, sessionUUID) {
log.Printf("[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID) result.ReleaseFunc() // 释放槽位
// 继续到负载感知选择
} else {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
}
return &AccountSelectionResult{
Account: stickyAccount,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
} }
return &AccountSelectionResult{
Account: stickyAccount,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
} }
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID) waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
...@@ -1066,6 +1085,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1066,6 +1085,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
for _, item := range routingAvailable { for _, item := range routingAvailable {
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
continue
}
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL) _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
} }
...@@ -1108,15 +1132,21 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1108,15 +1132,21 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if ok && s.isAccountInGroup(account, groupID) && if ok && s.isAccountInGroup(account, groupID) &&
s.isAccountAllowedForPlatform(account, platform, useMixed) && s.isAccountAllowedForPlatform(account, platform, useMixed) &&
account.IsSchedulableForModel(requestedModel) && account.IsSchedulableForModel(requestedModel) &&
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) &&
s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) // 会话数量限制检查
return &AccountSelectionResult{ if !s.checkAndRegisterSession(ctx, account, sessionUUID) {
Account: account, result.ReleaseFunc() // 释放槽位,继续到 Layer 2
Acquired: true, } else {
ReleaseFunc: result.ReleaseFunc, _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
}, nil return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
} }
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
...@@ -1157,6 +1187,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1157,6 +1187,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue continue
} }
// 窗口费用检查(非粘性会话路径)
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
continue
}
candidates = append(candidates, acc) candidates = append(candidates, acc)
} }
...@@ -1174,7 +1208,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1174,7 +1208,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil { if err != nil {
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok { if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth, sessionUUID); ok {
return result, nil return result, nil
} }
} else { } else {
...@@ -1223,6 +1257,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1223,6 +1257,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
for _, item := range available { for _, item := range available {
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
continue
}
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL) _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
} }
...@@ -1252,13 +1291,18 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1252,13 +1291,18 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return nil, errors.New("no available accounts") return nil, errors.New("no available accounts")
} }
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) { func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool, sessionUUID string) (*AccountSelectionResult, bool) {
ordered := append([]*Account(nil), candidates...) ordered := append([]*Account(nil), candidates...)
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth) sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
for _, acc := range ordered { for _, acc := range ordered {
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, acc, sessionUUID) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
continue
}
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL) _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL)
} }
...@@ -1490,6 +1534,107 @@ func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID in ...@@ -1490,6 +1534,107 @@ func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID in
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
} }
// isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度
// 仅适用于 Anthropic OAuth/SetupToken 账号
// 返回 true 表示可调度,false 表示不可调度
func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context, account *Account, isSticky bool) bool {
// 只检查 Anthropic OAuth/SetupToken 账号
if !account.IsAnthropicOAuthOrSetupToken() {
return true
}
limit := account.GetWindowCostLimit()
if limit <= 0 {
return true // 未启用窗口费用限制
}
// 尝试从缓存获取窗口费用
var currentCost float64
if s.sessionLimitCache != nil {
if cost, hit, err := s.sessionLimitCache.GetWindowCost(ctx, account.ID); err == nil && hit {
currentCost = cost
goto checkSchedulability
}
}
// 缓存未命中,从数据库查询
{
var startTime time.Time
if account.SessionWindowStart != nil {
startTime = *account.SessionWindowStart
} else {
startTime = time.Now().Add(-5 * time.Hour)
}
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
if err != nil {
// 失败开放:查询失败时允许调度
return true
}
// 使用标准费用(不含账号倍率)
currentCost = stats.StandardCost
// 设置缓存(忽略错误)
if s.sessionLimitCache != nil {
_ = s.sessionLimitCache.SetWindowCost(ctx, account.ID, currentCost)
}
}
checkSchedulability:
schedulability := account.CheckWindowCostSchedulability(currentCost)
switch schedulability {
case WindowCostSchedulable:
return true
case WindowCostStickyOnly:
return isSticky
case WindowCostNotSchedulable:
return false
}
return true
}
// checkAndRegisterSession 检查并注册会话,用于会话数量限制
// 仅适用于 Anthropic OAuth/SetupToken 账号
// 返回 true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话)
func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionUUID string) bool {
// 只检查 Anthropic OAuth/SetupToken 账号
if !account.IsAnthropicOAuthOrSetupToken() {
return true
}
maxSessions := account.GetMaxSessions()
if maxSessions <= 0 || sessionUUID == "" {
return true // 未启用会话限制或无会话ID
}
if s.sessionLimitCache == nil {
return true // 缓存不可用时允许通过
}
idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute
allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionUUID, maxSessions, idleTimeout)
if err != nil {
// 失败开放:缓存错误时允许通过
return true
}
return allowed
}
// extractSessionUUID 从 metadata.user_id 中提取会话 UUID
// 格式: user_{64位hex}_account__session_{uuid}
func extractSessionUUID(metadataUserID string) string {
if metadataUserID == "" {
return ""
}
if match := sessionIDRegex.FindStringSubmatch(metadataUserID); len(match) > 1 {
return match[1]
}
return ""
}
func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) { func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
if s.schedulerSnapshot != nil { if s.schedulerSnapshot != nil {
return s.schedulerSnapshot.GetAccount(ctx, accountID) return s.schedulerSnapshot.GetAccount(ctx, accountID)
...@@ -2384,9 +2529,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -2384,9 +2529,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
retryStart := time.Now() retryStart := time.Now()
for attempt := 1; attempt <= maxRetryAttempts; attempt++ { for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取) // 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
// Capture upstream request body for ops retry of this attempt. // Capture upstream request body for ops retry of this attempt.
c.Set(OpsUpstreamRequestBodyKey, string(body)) c.Set(OpsUpstreamRequestBodyKey, string(body))
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
...@@ -1067,15 +1067,29 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp ...@@ -1067,15 +1067,29 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率 // 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
lastDataAt := time.Now() lastDataAt := time.Now()
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端) // 仅发送一次错误事件,避免多次写入导致协议混乱。
// 注意:OpenAI `/v1/responses` streaming 事件必须符合 OpenAI Responses schema;
// 否则下游 SDK(例如 OpenCode)会因为类型校验失败而报错。
errorEventSent := false errorEventSent := false
clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage
sendErrorEvent := func(reason string) { sendErrorEvent := func(reason string) {
if errorEventSent { if errorEventSent || clientDisconnected {
return return
} }
errorEventSent = true errorEventSent = true
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason) payload := map[string]any{
flusher.Flush() "type": "error",
"sequence_number": 0,
"error": map[string]any{
"type": "upstream_error",
"message": reason,
"code": reason,
},
}
if b, err := json.Marshal(payload); err == nil {
_, _ = fmt.Fprintf(w, "data: %s\n\n", b)
flusher.Flush()
}
} }
needModelReplace := originalModel != mappedModel needModelReplace := originalModel != mappedModel
...@@ -1087,6 +1101,17 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp ...@@ -1087,6 +1101,17 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
} }
if ev.err != nil { if ev.err != nil {
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
log.Printf("Context canceled during streaming, returning collected usage")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
if clientDisconnected {
log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err)
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if errors.Is(ev.err, bufio.ErrTooLong) { if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
sendErrorEvent("response_too_large") sendErrorEvent("response_too_large")
...@@ -1110,15 +1135,19 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp ...@@ -1110,15 +1135,19 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
// Correct Codex tool calls if needed (apply_patch -> edit, etc.) // Correct Codex tool calls if needed (apply_patch -> edit, etc.)
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected { if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
data = correctedData
line = "data: " + correctedData line = "data: " + correctedData
} }
// Forward line // 写入客户端(客户端断开后继续 drain 上游)
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { if !clientDisconnected {
sendErrorEvent("write_failed") if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err clientDisconnected = true
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
} else {
flusher.Flush()
}
} }
flusher.Flush()
// Record first token time // Record first token time
if firstTokenMs == nil && data != "" && data != "[DONE]" { if firstTokenMs == nil && data != "" && data != "[DONE]" {
...@@ -1128,11 +1157,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp ...@@ -1128,11 +1157,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
s.parseSSEUsage(data, usage) s.parseSSEUsage(data, usage)
} else { } else {
// Forward non-data lines as-is // Forward non-data lines as-is
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { if !clientDisconnected {
sendErrorEvent("write_failed") if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err clientDisconnected = true
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
} else {
flusher.Flush()
}
} }
flusher.Flush()
} }
case <-intervalCh: case <-intervalCh:
...@@ -1140,6 +1172,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp ...@@ -1140,6 +1172,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if time.Since(lastRead) < streamInterval { if time.Since(lastRead) < streamInterval {
continue continue
} }
if clientDisconnected {
log.Printf("Upstream timeout after client disconnect, returning collected usage")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
// 处理流超时,可能标记账户为临时不可调度或错误状态 // 处理流超时,可能标记账户为临时不可调度或错误状态
if s.rateLimitService != nil { if s.rateLimitService != nil {
...@@ -1149,11 +1185,16 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp ...@@ -1149,11 +1185,16 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
case <-keepaliveCh: case <-keepaliveCh:
if clientDisconnected {
continue
}
if time.Since(lastDataAt) < keepaliveInterval { if time.Since(lastDataAt) < keepaliveInterval {
continue continue
} }
if _, err := fmt.Fprint(w, ":\n\n"); err != nil { if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err clientDisconnected = true
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
continue
} }
flusher.Flush() flusher.Flush()
} }
......
...@@ -33,6 +33,25 @@ type stubConcurrencyCache struct { ...@@ -33,6 +33,25 @@ type stubConcurrencyCache struct {
ConcurrencyCache ConcurrencyCache
} }
type cancelReadCloser struct{}
func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled }
func (c cancelReadCloser) Close() error { return nil }
type failingGinWriter struct {
gin.ResponseWriter
failAfter int
writes int
}
func (w *failingGinWriter) Write(p []byte) (int, error) {
if w.writes >= w.failAfter {
return 0, errors.New("write failed")
}
w.writes++
return w.ResponseWriter.Write(p)
}
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
return true, nil return true, nil
} }
...@@ -169,8 +188,85 @@ func TestOpenAIStreamingTimeout(t *testing.T) { ...@@ -169,8 +188,85 @@ func TestOpenAIStreamingTimeout(t *testing.T) {
if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") { if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") {
t.Fatalf("expected stream timeout error, got %v", err) t.Fatalf("expected stream timeout error, got %v", err)
} }
if !strings.Contains(rec.Body.String(), "stream_timeout") { if !strings.Contains(rec.Body.String(), "\"type\":\"error\"") || !strings.Contains(rec.Body.String(), "stream_timeout") {
t.Fatalf("expected stream_timeout SSE error, got %q", rec.Body.String()) t.Fatalf("expected OpenAI-compatible error SSE event, got %q", rec.Body.String())
}
}
func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
ctx, cancel := context.WithCancel(context.Background())
cancel()
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
resp := &http.Response{
StatusCode: http.StatusOK,
Body: cancelReadCloser{},
Header: http.Header{},
}
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "stream_read_error") {
t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
}
}
func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
c.Writer = &failingGinWriter{ResponseWriter: c.Writer, failAfter: 0}
pr, pw := io.Pipe()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: pr,
Header: http.Header{},
}
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":3,\"output_tokens\":5,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
}()
result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
_ = pr.Close()
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
if result == nil || result.usage == nil {
t.Fatalf("expected usage result")
}
if result.usage.InputTokens != 3 || result.usage.OutputTokens != 5 || result.usage.CacheReadInputTokens != 1 {
t.Fatalf("unexpected usage: %+v", *result.usage)
}
if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "write_failed") {
t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
} }
} }
...@@ -209,8 +305,8 @@ func TestOpenAIStreamingTooLong(t *testing.T) { ...@@ -209,8 +305,8 @@ func TestOpenAIStreamingTooLong(t *testing.T) {
if !errors.Is(err, bufio.ErrTooLong) { if !errors.Is(err, bufio.ErrTooLong) {
t.Fatalf("expected ErrTooLong, got %v", err) t.Fatalf("expected ErrTooLong, got %v", err)
} }
if !strings.Contains(rec.Body.String(), "response_too_large") { if !strings.Contains(rec.Body.String(), "\"type\":\"error\"") || !strings.Contains(rec.Body.String(), "response_too_large") {
t.Fatalf("expected response_too_large SSE error, got %q", rec.Body.String()) t.Fatalf("expected OpenAI-compatible error SSE event, got %q", rec.Body.String())
} }
} }
......
...@@ -514,7 +514,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry ...@@ -514,7 +514,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry
if s.gatewayService == nil { if s.gatewayService == nil {
return nil, fmt.Errorf("gateway service not available") return nil, fmt.Errorf("gateway service not available")
} }
return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs) return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs, "") // 重试不使用会话限制
default: default:
return nil, fmt.Errorf("unsupported retry type: %s", reqType) return nil, fmt.Errorf("unsupported retry type: %s", reqType)
} }
......
package service
import (
"context"
"time"
)
// SessionLimitCache 管理账号级别的活跃会话跟踪
// 用于 Anthropic OAuth/SetupToken 账号的会话数量限制
//
// Key 格式: session_limit:account:{accountID}
// 数据结构: Sorted Set (member=sessionUUID, score=timestamp)
//
// 会话在空闲超时后自动过期,无需手动清理
type SessionLimitCache interface {
// RegisterSession 注册会话活动
// - 如果会话已存在,刷新其时间戳并返回 true
// - 如果会话不存在且活跃会话数 < maxSessions,添加新会话并返回 true
// - 如果会话不存在且活跃会话数 >= maxSessions,返回 false(拒绝)
//
// 参数:
// accountID: 账号 ID
// sessionUUID: 从 metadata.user_id 中提取的会话 UUID
// maxSessions: 最大并发会话数限制
// idleTimeout: 会话空闲超时时间
//
// 返回:
// allowed: true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话)
// error: 操作错误
RegisterSession(ctx context.Context, accountID int64, sessionUUID string, maxSessions int, idleTimeout time.Duration) (allowed bool, err error)
// RefreshSession 刷新现有会话的时间戳
// 用于活跃会话保持活动状态
RefreshSession(ctx context.Context, accountID int64, sessionUUID string, idleTimeout time.Duration) error
// GetActiveSessionCount 获取当前活跃会话数
// 返回未过期的会话数量
GetActiveSessionCount(ctx context.Context, accountID int64) (int, error)
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
// 返回 map[accountID]count,查询失败的账号不在 map 中
GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)
// IsSessionActive 检查特定会话是否活跃(未过期)
IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error)
// ========== 5h窗口费用缓存 ==========
// Key 格式: window_cost:account:{accountID}
// 用于缓存账号在当前5h窗口内的标准费用,减少数据库聚合查询压力
// GetWindowCost 获取缓存的窗口费用
// 返回 (cost, true, nil) 如果缓存命中
// 返回 (0, false, nil) 如果缓存未命中
// 返回 (0, false, err) 如果发生错误
GetWindowCost(ctx context.Context, accountID int64) (cost float64, hit bool, err error)
// SetWindowCost 设置窗口费用缓存
SetWindowCost(ctx context.Context, accountID int64, cost float64) error
// GetWindowCostBatch 批量获取窗口费用缓存
// 返回 map[accountID]cost,缓存未命中的账号不在 map 中
GetWindowCostBatch(ctx context.Context, accountIDs []int64) (map[int64]float64, error)
}
...@@ -33,6 +33,22 @@ ...@@ -33,6 +33,22 @@
# 修改为你的域名 # 修改为你的域名
example.com { example.com {
# =========================================================================
# 静态资源长期缓存(高优先级,放在最前面)
# 带 hash 的文件可以永久缓存,浏览器和 CDN 都会缓存
# =========================================================================
@static {
path /assets/*
path /logo.png
path /favicon.ico
}
header @static {
Cache-Control "public, max-age=31536000, immutable"
# 移除可能干扰缓存的头
-Pragma
-Expires
}
# ========================================================================= # =========================================================================
# TLS 安全配置 # TLS 安全配置
# ========================================================================= # =========================================================================
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment