"backend/internal/handler/vscode:/vscode.git/clone" did not exist on "0c7cbe356639665f9f84675518f18efd9c215de5"
Commit 7efa8b54 authored by yangjianbo's avatar yangjianbo
Browse files

perf(后端): 完成性能优化与连接池配置

新增 DB/Redis 连接池配置与校验,并补充单测

网关请求体大小限制与 413 处理

HTTP/req 客户端池化并调整上游连接池默认值

并发槽位改为 ZSET+Lua 与指数退避

用量统计改 SQL 聚合并新增索引迁移

计费缓存写入改工作池并补测试/基准

测试: 在 backend/ 下运行 go test ./...
parent 53767866
......@@ -67,6 +67,7 @@ func provideCleanup(
tokenRefresh *service.TokenRefreshService,
pricing *service.PricingService,
emailQueue *service.EmailQueueService,
billingCache *service.BillingCacheService,
oauth *service.OAuthService,
openaiOAuth *service.OpenAIOAuthService,
geminiOAuth *service.GeminiOAuthService,
......@@ -94,6 +95,10 @@ func provideCleanup(
emailQueue.Stop()
return nil
}},
{"BillingCacheService", func() error {
billingCache.Stop()
return nil
}},
{"OAuthService", func() error {
oauth.Stop()
return nil
......
......@@ -39,11 +39,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
if err != nil {
return nil, err
}
sqlDB, err := infrastructure.ProvideSQLDB(client)
db, err := infrastructure.ProvideSQLDB(client)
if err != nil {
return nil, err
}
userRepository := repository.NewUserRepository(client, sqlDB)
userRepository := repository.NewUserRepository(client, db)
settingRepository := repository.NewSettingRepository(client)
settingService := service.NewSettingService(settingRepository, configConfig)
redisClient := infrastructure.ProvideRedis(configConfig)
......@@ -57,12 +57,12 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
authHandler := handler.NewAuthHandler(configConfig, authService, userService)
userHandler := handler.NewUserHandler(userService)
apiKeyRepository := repository.NewApiKeyRepository(client)
groupRepository := repository.NewGroupRepository(client, sqlDB)
groupRepository := repository.NewGroupRepository(client, db)
userSubscriptionRepository := repository.NewUserSubscriptionRepository(client)
apiKeyCache := repository.NewApiKeyCache(redisClient)
apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, sqlDB)
usageLogRepository := repository.NewUsageLogRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemCodeRepository := repository.NewRedeemCodeRepository(client)
......@@ -75,8 +75,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
dashboardService := service.NewDashboardService(usageLogRepository)
dashboardHandler := admin.NewDashboardHandler(dashboardService)
accountRepository := repository.NewAccountRepository(client, sqlDB)
proxyRepository := repository.NewProxyRepository(client, sqlDB)
accountRepository := repository.NewAccountRepository(client, db)
proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber()
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber)
adminUserHandler := admin.NewUserHandler(adminService)
......@@ -95,7 +95,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
httpUpstream := repository.NewHTTPUpstream(configConfig)
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, httpUpstream)
concurrencyCache := repository.NewConcurrencyCache(redisClient)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.NewConcurrencyService(concurrencyCache)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
......@@ -142,7 +142,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
httpServer := server.ProvideHTTPServer(configConfig, engine)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig)
antigravityQuotaRefresher := service.ProvideAntigravityQuotaRefresher(accountRepository, proxyRepository, antigravityOAuthService, configConfig)
v := provideCleanup(client, redisClient, tokenRefreshService, pricingService, emailQueueService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, antigravityQuotaRefresher)
v := provideCleanup(client, redisClient, tokenRefreshService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, antigravityQuotaRefresher)
application := &Application{
Server: httpServer,
Cleanup: v,
......@@ -170,6 +170,7 @@ func provideCleanup(
tokenRefresh *service.TokenRefreshService,
pricing *service.PricingService,
emailQueue *service.EmailQueueService,
billingCache *service.BillingCacheService,
oauth *service.OAuthService,
openaiOAuth *service.OpenAIOAuthService,
geminiOAuth *service.GeminiOAuthService,
......@@ -196,6 +197,10 @@ func provideCleanup(
emailQueue.Stop()
return nil
}},
{"BillingCacheService", func() error {
billingCache.Stop()
return nil
}},
{"OAuthService", func() error {
oauth.Stop()
return nil
......
......@@ -79,12 +79,29 @@ type GatewayConfig struct {
// 等待上游响应头的超时时间(秒),0表示无超时
// 注意:这不影响流式数据传输,只控制等待响应头的时间
ResponseHeaderTimeout int `mapstructure:"response_header_timeout"`
// 请求体最大字节数,用于网关请求体大小限制
MaxBodySize int64 `mapstructure:"max_body_size"`
// HTTP 上游连接池配置(性能优化:支持高并发场景调优)
// MaxIdleConns: 所有主机的最大空闲连接总数
MaxIdleConns int `mapstructure:"max_idle_conns"`
// MaxIdleConnsPerHost: 每个主机的最大空闲连接数(关键参数,影响连接复用率)
MaxIdleConnsPerHost int `mapstructure:"max_idle_conns_per_host"`
// MaxConnsPerHost: 每个主机的最大连接数(包括活跃+空闲),0表示无限制
MaxConnsPerHost int `mapstructure:"max_conns_per_host"`
// IdleConnTimeoutSeconds: 空闲连接超时时间(秒)
IdleConnTimeoutSeconds int `mapstructure:"idle_conn_timeout_seconds"`
// ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟)
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期
ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"`
}
func (s *ServerConfig) Address() string {
return fmt.Sprintf("%s:%d", s.Host, s.Port)
}
// DatabaseConfig 数据库连接配置
// 性能优化:新增连接池参数,避免频繁创建/销毁连接
type DatabaseConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
......@@ -92,6 +109,15 @@ type DatabaseConfig struct {
Password string `mapstructure:"password"`
DBName string `mapstructure:"dbname"`
SSLMode string `mapstructure:"sslmode"`
// 连接池配置(性能优化:可配置化连接池参数)
// MaxOpenConns: 最大打开连接数,控制数据库连接上限,防止资源耗尽
MaxOpenConns int `mapstructure:"max_open_conns"`
// MaxIdleConns: 最大空闲连接数,保持热连接减少建连延迟
MaxIdleConns int `mapstructure:"max_idle_conns"`
// ConnMaxLifetimeMinutes: 连接最大存活时间,防止长连接导致的资源泄漏
ConnMaxLifetimeMinutes int `mapstructure:"conn_max_lifetime_minutes"`
// ConnMaxIdleTimeMinutes: 空闲连接最大存活时间,及时释放不活跃连接
ConnMaxIdleTimeMinutes int `mapstructure:"conn_max_idle_time_minutes"`
}
func (d *DatabaseConfig) DSN() string {
......@@ -112,11 +138,24 @@ func (d *DatabaseConfig) DSNWithTimezone(tz string) string {
)
}
// RedisConfig Redis 连接配置
// 性能优化:新增连接池和超时参数,提升高并发场景下的吞吐量
type RedisConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Password string `mapstructure:"password"`
DB int `mapstructure:"db"`
// 连接池与超时配置(性能优化:可配置化连接池参数)
// DialTimeoutSeconds: 建立连接超时,防止慢连接阻塞
DialTimeoutSeconds int `mapstructure:"dial_timeout_seconds"`
// ReadTimeoutSeconds: 读取超时,避免慢查询阻塞连接池
ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"`
// WriteTimeoutSeconds: 写入超时,避免慢写入阻塞连接池
WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"`
// PoolSize: 连接池大小,控制最大并发连接数
PoolSize int `mapstructure:"pool_size"`
// MinIdleConns: 最小空闲连接数,保持热连接减少冷启动延迟
MinIdleConns int `mapstructure:"min_idle_conns"`
}
func (r *RedisConfig) Address() string {
......@@ -203,12 +242,21 @@ func setDefaults() {
viper.SetDefault("database.password", "postgres")
viper.SetDefault("database.dbname", "sub2api")
viper.SetDefault("database.sslmode", "disable")
viper.SetDefault("database.max_open_conns", 50)
viper.SetDefault("database.max_idle_conns", 10)
viper.SetDefault("database.conn_max_lifetime_minutes", 30)
viper.SetDefault("database.conn_max_idle_time_minutes", 5)
// Redis
viper.SetDefault("redis.host", "localhost")
viper.SetDefault("redis.port", 6379)
viper.SetDefault("redis.password", "")
viper.SetDefault("redis.db", 0)
viper.SetDefault("redis.dial_timeout_seconds", 5)
viper.SetDefault("redis.read_timeout_seconds", 3)
viper.SetDefault("redis.write_timeout_seconds", 3)
viper.SetDefault("redis.pool_size", 128)
viper.SetDefault("redis.min_idle_conns", 10)
// JWT
viper.SetDefault("jwt.secret", "change-me-in-production")
......@@ -240,6 +288,13 @@ func setDefaults() {
// Gateway
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数(HTTP/2 场景默认)
viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接(HTTP/2 场景默认)
viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数(含活跃,HTTP/2 场景默认)
viper.SetDefault("gateway.idle_conn_timeout_seconds", 300) // 空闲连接超时(秒)
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 15) // 并发槽位过期时间(支持超长请求)
// TokenRefresh
viper.SetDefault("token_refresh.enabled", true)
......@@ -263,6 +318,57 @@ func (c *Config) Validate() error {
if c.JWT.Secret == "change-me-in-production" && c.Server.Mode == "release" {
return fmt.Errorf("jwt.secret must be changed in production")
}
if c.Database.MaxOpenConns <= 0 {
return fmt.Errorf("database.max_open_conns must be positive")
}
if c.Database.MaxIdleConns < 0 {
return fmt.Errorf("database.max_idle_conns must be non-negative")
}
if c.Database.MaxIdleConns > c.Database.MaxOpenConns {
return fmt.Errorf("database.max_idle_conns cannot exceed database.max_open_conns")
}
if c.Database.ConnMaxLifetimeMinutes < 0 {
return fmt.Errorf("database.conn_max_lifetime_minutes must be non-negative")
}
if c.Database.ConnMaxIdleTimeMinutes < 0 {
return fmt.Errorf("database.conn_max_idle_time_minutes must be non-negative")
}
if c.Redis.DialTimeoutSeconds <= 0 {
return fmt.Errorf("redis.dial_timeout_seconds must be positive")
}
if c.Redis.ReadTimeoutSeconds <= 0 {
return fmt.Errorf("redis.read_timeout_seconds must be positive")
}
if c.Redis.WriteTimeoutSeconds <= 0 {
return fmt.Errorf("redis.write_timeout_seconds must be positive")
}
if c.Redis.PoolSize <= 0 {
return fmt.Errorf("redis.pool_size must be positive")
}
if c.Redis.MinIdleConns < 0 {
return fmt.Errorf("redis.min_idle_conns must be non-negative")
}
if c.Redis.MinIdleConns > c.Redis.PoolSize {
return fmt.Errorf("redis.min_idle_conns cannot exceed redis.pool_size")
}
if c.Gateway.MaxBodySize <= 0 {
return fmt.Errorf("gateway.max_body_size must be positive")
}
if c.Gateway.MaxIdleConns <= 0 {
return fmt.Errorf("gateway.max_idle_conns must be positive")
}
if c.Gateway.MaxIdleConnsPerHost <= 0 {
return fmt.Errorf("gateway.max_idle_conns_per_host must be positive")
}
if c.Gateway.MaxConnsPerHost < 0 {
return fmt.Errorf("gateway.max_conns_per_host must be non-negative")
}
if c.Gateway.IdleConnTimeoutSeconds <= 0 {
return fmt.Errorf("gateway.idle_conn_timeout_seconds must be positive")
}
if c.Gateway.ConcurrencySlotTTLMinutes <= 0 {
return fmt.Errorf("gateway.concurrency_slot_ttl_minutes must be positive")
}
return nil
}
......
......@@ -67,6 +67,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
......@@ -76,15 +80,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 解析请求获取模型名和stream
var req struct {
Model string `json:"model"`
Stream bool `json:"stream"`
}
if err := json.Unmarshal(body, &req); err != nil {
parsedReq, err := service.ParseGatewayRequest(body)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
reqModel := parsedReq.Model
reqStream := parsedReq.Stream
// Track if we've started streaming (for error handling)
streamStarted := false
......@@ -106,7 +108,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
// 1. 首先获取用户并发槽位
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, req.Stream, &streamStarted)
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil {
log.Printf("User concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "user", streamStarted)
......@@ -124,7 +126,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 计算粘性会话hash
sessionHash := h.gatewayService.GenerateSessionHash(body)
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台
platform := ""
......@@ -141,7 +143,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
lastFailoverStatus := 0
for {
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs)
account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
......@@ -153,16 +155,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if req.Stream {
sendMockWarmupStream(c, req.Model)
if reqStream {
sendMockWarmupStream(c, reqModel)
} else {
sendMockWarmupResponse(c, req.Model)
sendMockWarmupResponse(c, reqModel)
}
return
}
// 3. 获取账号并发槽位
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
......@@ -172,7 +174,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, req.Model, "generateContent", req.Stream, body)
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, body)
} else {
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
}
......@@ -223,7 +225,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for {
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model, failedAccountIDs)
account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
......@@ -235,16 +237,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if req.Stream {
sendMockWarmupStream(c, req.Model)
if reqStream {
sendMockWarmupStream(c, reqModel)
} else {
sendMockWarmupResponse(c, req.Model)
sendMockWarmupResponse(c, reqModel)
}
return
}
// 3. 获取账号并发槽位
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
......@@ -256,7 +258,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body)
} else {
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, body)
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq)
}
if accountReleaseFunc != nil {
accountReleaseFunc()
......@@ -496,6 +498,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
......@@ -505,11 +511,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return
}
// 解析请求获取模型名
var req struct {
Model string `json:"model"`
}
if err := json.Unmarshal(body, &req); err != nil {
parsedReq, err := service.ParseGatewayRequest(body)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
......@@ -525,17 +528,17 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
}
// 计算粘性会话 hash
sessionHash := h.gatewayService.GenerateSessionHash(body)
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model)
if err != nil {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
return
}
// 转发请求(不记录使用量)
if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, body); err != nil {
if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, parsedReq); err != nil {
log.Printf("Forward count_tokens request failed: %v", err)
// 错误响应已在 ForwardCountTokens 中处理
return
......
......@@ -3,6 +3,7 @@ package handler
import (
"context"
"fmt"
"math/rand"
"net/http"
"time"
......@@ -11,11 +12,28 @@ import (
"github.com/gin-gonic/gin"
)
// 并发槽位等待相关常量
//
// 性能优化说明:
// 原实现使用固定间隔(100ms)轮询并发槽位,存在以下问题:
// 1. 高并发时频繁轮询增加 Redis 压力
// 2. 固定间隔可能导致多个请求同时重试(惊群效应)
//
// 新实现使用指数退避 + 抖动算法:
// 1. 初始退避 100ms,每次乘以 1.5,最大 2s
// 2. 添加 ±20% 的随机抖动,分散重试时间点
// 3. 减少 Redis 压力,避免惊群效应
const (
// maxConcurrencyWait is the maximum time to wait for a concurrency slot
// maxConcurrencyWait 等待并发槽位的最大时间
maxConcurrencyWait = 30 * time.Second
// pingInterval is the interval for sending ping events during slot wait
// pingInterval 流式响应等待时发送 ping 的间隔
pingInterval = 15 * time.Second
// initialBackoff 初始退避时间
initialBackoff = 100 * time.Millisecond
// backoffMultiplier 退避时间乘数(指数退避)
backoffMultiplier = 1.5
// maxBackoff 最大退避时间
maxBackoff = 2 * time.Second
)
// SSEPingFormat defines the format of SSE ping events for different platforms
......@@ -131,8 +149,10 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
pingCh = pingTicker.C
}
pollTicker := time.NewTicker(100 * time.Millisecond)
defer pollTicker.Stop()
backoff := initialBackoff
timer := time.NewTimer(backoff)
defer timer.Stop()
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
for {
select {
......@@ -156,7 +176,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
}
flusher.Flush()
case <-pollTicker.C:
case <-timer.C:
// Try to acquire slot
var result *service.AcquireResult
var err error
......@@ -174,6 +194,35 @@ func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string,
if result.Acquired {
return result.ReleaseFunc, nil
}
backoff = nextBackoff(backoff, rng)
timer.Reset(backoff)
}
}
}
// nextBackoff 计算下一次退避时间
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应
// current: 当前退避时间
// rng: 随机数生成器(可为 nil,此时不添加抖动)
// 返回值:下一次退避时间(100ms ~ 2s 之间)
func nextBackoff(current time.Duration, rng *rand.Rand) time.Duration {
// 指数退避:当前时间 * 1.5
next := time.Duration(float64(current) * backoffMultiplier)
if next > maxBackoff {
next = maxBackoff
}
if rng == nil {
return next
}
// 添加 ±20% 的随机抖动(jitter 范围 0.8 ~ 1.2)
// 抖动可以分散多个请求的重试时间点,避免同时冲击 Redis
jitter := 0.8 + rng.Float64()*0.4
jittered := time.Duration(float64(next) * jitter)
if jittered < initialBackoff {
return initialBackoff
}
if jittered > maxBackoff {
return maxBackoff
}
return jittered
}
......@@ -148,6 +148,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
body, err := io.ReadAll(c.Request.Body)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
googleError(c, http.StatusRequestEntityTooLarge, buildBodyTooLargeMessage(maxErr.Limit))
return
}
googleError(c, http.StatusBadRequest, "Failed to read request body")
return
}
......@@ -191,7 +195,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
// 3) select account (sticky session based on request body)
sessionHash := h.gatewayService.GenerateSessionHash(body)
parsedReq, _ := service.ParseGatewayRequest(body)
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
const maxAccountSwitches = 3
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
......
......@@ -56,6 +56,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Read request body
body, err := io.ReadAll(c.Request.Body)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
......
package handler
import (
"errors"
"fmt"
"net/http"
)
func extractMaxBytesError(err error) (*http.MaxBytesError, bool) {
var maxErr *http.MaxBytesError
if errors.As(err, &maxErr) {
return maxErr, true
}
return nil, false
}
func formatBodyLimit(limit int64) string {
const mb = 1024 * 1024
if limit >= mb {
return fmt.Sprintf("%dMB", limit/mb)
}
return fmt.Sprintf("%dB", limit)
}
func buildBodyTooLargeMessage(limit int64) string {
return fmt.Sprintf("Request body too large, limit is %s", formatBodyLimit(limit))
}
package handler
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestRequestBodyLimitTooLarge(t *testing.T) {
gin.SetMode(gin.TestMode)
limit := int64(16)
router := gin.New()
router.Use(middleware.RequestBodyLimit(limit))
router.POST("/test", func(c *gin.Context) {
_, err := io.ReadAll(c.Request.Body)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
c.JSON(http.StatusRequestEntityTooLarge, gin.H{
"error": buildBodyTooLargeMessage(maxErr.Limit),
})
return
}
c.JSON(http.StatusBadRequest, gin.H{
"error": "read_failed",
})
return
}
c.JSON(http.StatusOK, gin.H{"ok": true})
})
payload := bytes.Repeat([]byte("a"), int(limit+1))
req := httptest.NewRequest(http.MethodPost, "/test", bytes.NewReader(payload))
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusRequestEntityTooLarge, recorder.Code)
require.Contains(t, recorder.Body.String(), buildBodyTooLargeMessage(limit))
}
package infrastructure
import (
"database/sql"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
)
type dbPoolSettings struct {
MaxOpenConns int
MaxIdleConns int
ConnMaxLifetime time.Duration
ConnMaxIdleTime time.Duration
}
func buildDBPoolSettings(cfg *config.Config) dbPoolSettings {
return dbPoolSettings{
MaxOpenConns: cfg.Database.MaxOpenConns,
MaxIdleConns: cfg.Database.MaxIdleConns,
ConnMaxLifetime: time.Duration(cfg.Database.ConnMaxLifetimeMinutes) * time.Minute,
ConnMaxIdleTime: time.Duration(cfg.Database.ConnMaxIdleTimeMinutes) * time.Minute,
}
}
func applyDBPoolSettings(db *sql.DB, cfg *config.Config) {
settings := buildDBPoolSettings(cfg)
db.SetMaxOpenConns(settings.MaxOpenConns)
db.SetMaxIdleConns(settings.MaxIdleConns)
db.SetConnMaxLifetime(settings.ConnMaxLifetime)
db.SetConnMaxIdleTime(settings.ConnMaxIdleTime)
}
package infrastructure
import (
"database/sql"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
_ "github.com/lib/pq"
)
func TestBuildDBPoolSettings(t *testing.T) {
cfg := &config.Config{
Database: config.DatabaseConfig{
MaxOpenConns: 50,
MaxIdleConns: 10,
ConnMaxLifetimeMinutes: 30,
ConnMaxIdleTimeMinutes: 5,
},
}
settings := buildDBPoolSettings(cfg)
require.Equal(t, 50, settings.MaxOpenConns)
require.Equal(t, 10, settings.MaxIdleConns)
require.Equal(t, 30*time.Minute, settings.ConnMaxLifetime)
require.Equal(t, 5*time.Minute, settings.ConnMaxIdleTime)
}
func TestApplyDBPoolSettings(t *testing.T) {
cfg := &config.Config{
Database: config.DatabaseConfig{
MaxOpenConns: 40,
MaxIdleConns: 8,
ConnMaxLifetimeMinutes: 15,
ConnMaxIdleTimeMinutes: 3,
},
}
db, err := sql.Open("postgres", "host=127.0.0.1 port=5432 user=postgres sslmode=disable")
require.NoError(t, err)
t.Cleanup(func() {
_ = db.Close()
})
applyDBPoolSettings(db, cfg)
stats := db.Stats()
require.Equal(t, 40, stats.MaxOpenConnections)
}
......@@ -51,6 +51,7 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
if err != nil {
return nil, nil, err
}
applyDBPoolSettings(drv.DB(), cfg)
// 确保数据库 schema 已准备就绪。
// SQL 迁移文件是 schema 的权威来源(source of truth)。
......
package infrastructure
import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/redis/go-redis/v9"
)
// InitRedis 初始化 Redis 客户端
//
// 性能优化说明:
// 原实现使用 go-redis 默认配置,未设置连接池和超时参数:
// 1. 默认连接池大小可能不足以支撑高并发
// 2. 无超时控制可能导致慢操作阻塞
//
// 新实现支持可配置的连接池和超时参数:
// 1. PoolSize: 控制最大并发连接数(默认 128)
// 2. MinIdleConns: 保持最小空闲连接,减少冷启动延迟(默认 10)
// 3. DialTimeout/ReadTimeout/WriteTimeout: 精确控制各阶段超时
func InitRedis(cfg *config.Config) *redis.Client {
return redis.NewClient(&redis.Options{
Addr: cfg.Redis.Address(),
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
})
return redis.NewClient(buildRedisOptions(cfg))
}
// buildRedisOptions 构建 Redis 连接选项
// 从配置文件读取连接池和超时参数,支持生产环境调优
func buildRedisOptions(cfg *config.Config) *redis.Options {
return &redis.Options{
Addr: cfg.Redis.Address(),
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
DialTimeout: time.Duration(cfg.Redis.DialTimeoutSeconds) * time.Second, // 建连超时
ReadTimeout: time.Duration(cfg.Redis.ReadTimeoutSeconds) * time.Second, // 读取超时
WriteTimeout: time.Duration(cfg.Redis.WriteTimeoutSeconds) * time.Second, // 写入超时
PoolSize: cfg.Redis.PoolSize, // 连接池大小
MinIdleConns: cfg.Redis.MinIdleConns, // 最小空闲连接
}
}
package infrastructure
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestBuildRedisOptions(t *testing.T) {
cfg := &config.Config{
Redis: config.RedisConfig{
Host: "localhost",
Port: 6379,
Password: "secret",
DB: 2,
DialTimeoutSeconds: 5,
ReadTimeoutSeconds: 3,
WriteTimeoutSeconds: 4,
PoolSize: 100,
MinIdleConns: 10,
},
}
opts := buildRedisOptions(cfg)
require.Equal(t, "localhost:6379", opts.Addr)
require.Equal(t, "secret", opts.Password)
require.Equal(t, 2, opts.DB)
require.Equal(t, 5*time.Second, opts.DialTimeout)
require.Equal(t, 3*time.Second, opts.ReadTimeout)
require.Equal(t, 4*time.Second, opts.WriteTimeout)
require.Equal(t, 100, opts.PoolSize)
require.Equal(t, 10, opts.MinIdleConns)
}
// Package httpclient 提供共享 HTTP 客户端池
//
// 性能优化说明:
// 原实现在多个服务中重复创建 http.Client:
// 1. proxy_probe_service.go: 每次探测创建新客户端
// 2. pricing_service.go: 每次请求创建新客户端
// 3. turnstile_service.go: 每次验证创建新客户端
// 4. github_release_service.go: 每次请求创建新客户端
// 5. claude_usage_service.go: 每次请求创建新客户端
//
// 新实现使用统一的客户端池:
// 1. 相同配置复用同一 http.Client 实例
// 2. 复用 Transport 连接池,减少 TCP/TLS 握手开销
// 3. 支持 HTTP/HTTPS/SOCKS5 代理
// 4. 支持严格代理模式(代理失败则返回错误)
package httpclient
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
"golang.org/x/net/proxy"
)
// Transport 连接池默认配置
const (
defaultMaxIdleConns = 100 // 最大空闲连接数
defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数
defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间
)
// Options 定义共享 HTTP 客户端的构建参数
type Options struct {
ProxyURL string // 代理 URL(支持 http/https/socks5)
Timeout time.Duration // 请求总超时时间
ResponseHeaderTimeout time.Duration // 等待响应头超时时间
InsecureSkipVerify bool // 是否跳过 TLS 证书验证
ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退
// 可选的连接池参数(不设置则使用默认值)
MaxIdleConns int // 最大空闲连接总数(默认 100)
MaxIdleConnsPerHost int // 每主机最大空闲连接(默认 10)
MaxConnsPerHost int // 每主机最大连接数(默认 0 无限制)
}
// sharedClients 存储按配置参数缓存的 http.Client 实例
var sharedClients sync.Map
// GetClient 返回共享的 HTTP 客户端实例
// 性能优化:相同配置复用同一客户端,避免重复创建 Transport
func GetClient(opts Options) (*http.Client, error) {
key := buildClientKey(opts)
if cached, ok := sharedClients.Load(key); ok {
return cached.(*http.Client), nil
}
client, err := buildClient(opts)
if err != nil {
if opts.ProxyStrict {
return nil, err
}
fallback := opts
fallback.ProxyURL = ""
client, _ = buildClient(fallback)
}
actual, _ := sharedClients.LoadOrStore(key, client)
return actual.(*http.Client), nil
}
func buildClient(opts Options) (*http.Client, error) {
transport, err := buildTransport(opts)
if err != nil {
return nil, err
}
return &http.Client{
Transport: transport,
Timeout: opts.Timeout,
}, nil
}
func buildTransport(opts Options) (*http.Transport, error) {
// 使用自定义值或默认值
maxIdleConns := opts.MaxIdleConns
if maxIdleConns <= 0 {
maxIdleConns = defaultMaxIdleConns
}
maxIdleConnsPerHost := opts.MaxIdleConnsPerHost
if maxIdleConnsPerHost <= 0 {
maxIdleConnsPerHost = defaultMaxIdleConnsPerHost
}
transport := &http.Transport{
MaxIdleConns: maxIdleConns,
MaxIdleConnsPerHost: maxIdleConnsPerHost,
MaxConnsPerHost: opts.MaxConnsPerHost, // 0 表示无限制
IdleConnTimeout: defaultIdleConnTimeout,
ResponseHeaderTimeout: opts.ResponseHeaderTimeout,
}
if opts.InsecureSkipVerify {
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
}
proxyURL := strings.TrimSpace(opts.ProxyURL)
if proxyURL == "" {
return transport, nil
}
parsed, err := url.Parse(proxyURL)
if err != nil {
return nil, err
}
switch strings.ToLower(parsed.Scheme) {
case "http", "https":
transport.Proxy = http.ProxyURL(parsed)
case "socks5", "socks5h":
dialer, err := proxy.FromURL(parsed, proxy.Direct)
if err != nil {
return nil, err
}
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
}
default:
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsed.Scheme)
}
return transport, nil
}
func buildClientKey(opts Options) string {
return fmt.Sprintf("%s|%s|%s|%t|%t|%d|%d|%d",
strings.TrimSpace(opts.ProxyURL),
opts.Timeout.String(),
opts.ResponseHeaderTimeout.String(),
opts.InsecureSkipVerify,
opts.ProxyStrict,
opts.MaxIdleConns,
opts.MaxIdleConnsPerHost,
opts.MaxConnsPerHost,
)
}
......@@ -233,15 +233,11 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
}
func createReqClient(proxyURL string) *req.Client {
client := req.C().
ImpersonateChrome().
SetTimeout(60 * time.Second)
if proxyURL != "" {
client.SetProxyURL(proxyURL)
}
return client
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 60 * time.Second,
Impersonate: true,
})
}
func prefix(s string, n int) string {
......
......@@ -6,9 +6,9 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
......@@ -23,20 +23,12 @@ func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
}
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
transport, ok := http.DefaultTransport.(*http.Transport)
if !ok {
return nil, fmt.Errorf("failed to get default transport")
}
transport = transport.Clone()
if proxyURL != "" {
if parsedURL, err := url.Parse(proxyURL); err == nil {
transport.Proxy = http.ProxyURL(parsedURL)
}
}
client := &http.Client{
Transport: transport,
Timeout: 30 * time.Second,
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 30 * time.Second,
})
if err != nil {
client = &http.Client{Timeout: 30 * time.Second}
}
req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil)
......
......@@ -3,67 +3,90 @@ package repository
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// 并发控制缓存常量定义
//
// 性能优化说明:
// 原实现使用 SCAN 命令遍历独立的槽位键(concurrency:account:{id}:{requestID}),
// 在高并发场景下 SCAN 需要多次往返,且遍历大量键时性能下降明显。
//
// 新实现改用 Redis 有序集合(Sorted Set):
// 1. 每个账号/用户只有一个键,成员为 requestID,分数为时间戳
// 2. 使用 ZCARD 原子获取并发数,时间复杂度 O(1)
// 3. 使用 ZREMRANGEBYSCORE 清理过期槽位,避免手动管理 TTL
// 4. 单次 Redis 调用完成计数,减少网络往返
const (
// Key prefixes for independent slot keys
// Format: concurrency:account:{accountID}:{requestID}
// 并发槽位键前缀(有序集合)
// 格式: concurrency:account:{accountID}
accountSlotKeyPrefix = "concurrency:account:"
// Format: concurrency:user:{userID}:{requestID}
// 格式: concurrency:user:{userID}
userSlotKeyPrefix = "concurrency:user:"
// Wait queue keeps counter format: concurrency:wait:{userID}
// 等待队列计数器格式: concurrency:wait:{userID}
waitQueueKeyPrefix = "concurrency:wait:"
// Slot TTL - each slot expires independently
slotTTL = 5 * time.Minute
// 默认槽位过期时间(分钟),可通过配置覆盖
defaultSlotTTLMinutes = 15
)
var (
// acquireScript uses SCAN to count existing slots and creates new slot if under limit
// KEYS[1] = pattern for SCAN (e.g., "concurrency:account:2:*")
// KEYS[2] = full slot key (e.g., "concurrency:account:2:req_xxx")
// acquireScript 使用有序集合计数并在未达上限时添加槽位
// 使用 Redis TIME 命令获取服务器时间,避免多实例时钟不同步问题
// KEYS[1] = 有序集合键 (concurrency:account:{id} / concurrency:user:{id})
// ARGV[1] = maxConcurrency
// ARGV[2] = TTL in seconds
// ARGV[2] = TTL(秒)
// ARGV[3] = requestID
acquireScript = redis.NewScript(`
local pattern = KEYS[1]
local slotKey = KEYS[2]
local key = KEYS[1]
local maxConcurrency = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
local requestID = ARGV[3]
-- 使用 Redis 服务器时间,确保多实例时钟一致
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
-- Count existing slots using SCAN
local cursor = "0"
local count = 0
repeat
local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100)
cursor = result[1]
count = count + #result[2]
until cursor == "0"
-- 清理过期槽位
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
-- 检查是否已存在(支持重试场景刷新时间戳)
local exists = redis.call('ZSCORE', key, requestID)
if exists ~= false then
redis.call('ZADD', key, now, requestID)
redis.call('EXPIRE', key, ttl)
return 1
end
-- Check if we can acquire a slot
-- 检查是否达到并发上限
local count = redis.call('ZCARD', key)
if count < maxConcurrency then
redis.call('SET', slotKey, '1', 'EX', ttl)
redis.call('ZADD', key, now, requestID)
redis.call('EXPIRE', key, ttl)
return 1
end
return 0
`)
// getCountScript counts slots using SCAN
// KEYS[1] = pattern for SCAN
// getCountScript 统计有序集合中的槽位数量并清理过期条目
// 使用 Redis TIME 命令获取服务器时间
// KEYS[1] = 有序集合键
// ARGV[1] = TTL(秒)
getCountScript = redis.NewScript(`
local pattern = KEYS[1]
local cursor = "0"
local count = 0
repeat
local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100)
cursor = result[1]
count = count + #result[2]
until cursor == "0"
return count
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
-- 使用 Redis 服务器时间
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
return redis.call('ZCARD', key)
`)
// incrementWaitScript - only sets TTL on first creation to avoid refreshing
......@@ -103,28 +126,29 @@ var (
)
type concurrencyCache struct {
rdb *redis.Client
rdb *redis.Client
slotTTLSeconds int // 槽位过期时间(秒)
}
func NewConcurrencyCache(rdb *redis.Client) service.ConcurrencyCache {
return &concurrencyCache{rdb: rdb}
// NewConcurrencyCache 创建并发控制缓存
// slotTTLMinutes: 槽位过期时间(分钟),0 或负数使用默认值 15 分钟
func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int) service.ConcurrencyCache {
if slotTTLMinutes <= 0 {
slotTTLMinutes = defaultSlotTTLMinutes
}
return &concurrencyCache{
rdb: rdb,
slotTTLSeconds: slotTTLMinutes * 60,
}
}
// Helper functions for key generation
func accountSlotKey(accountID int64, requestID string) string {
return fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, requestID)
}
func accountSlotPattern(accountID int64) string {
return fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID)
func accountSlotKey(accountID int64) string {
return fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
}
func userSlotKey(userID int64, requestID string) string {
return fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, requestID)
}
func userSlotPattern(userID int64) string {
return fmt.Sprintf("%s%d:*", userSlotKeyPrefix, userID)
func userSlotKey(userID int64) string {
return fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
}
func waitQueueKey(userID int64) string {
......@@ -134,10 +158,9 @@ func waitQueueKey(userID int64) string {
// Account slot operations
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
pattern := accountSlotPattern(accountID)
slotKey := accountSlotKey(accountID, requestID)
result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int()
key := accountSlotKey(accountID)
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取,确保多实例时钟一致
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int()
if err != nil {
return false, err
}
......@@ -145,13 +168,14 @@ func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int
}
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
slotKey := accountSlotKey(accountID, requestID)
return c.rdb.Del(ctx, slotKey).Err()
key := accountSlotKey(accountID)
return c.rdb.ZRem(ctx, key, requestID).Err()
}
func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
pattern := accountSlotPattern(accountID)
result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int()
key := accountSlotKey(accountID)
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取
result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int()
if err != nil {
return 0, err
}
......@@ -161,10 +185,9 @@ func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID
// User slot operations
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
pattern := userSlotPattern(userID)
slotKey := userSlotKey(userID, requestID)
result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int()
key := userSlotKey(userID)
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取,确保多实例时钟一致
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int()
if err != nil {
return false, err
}
......@@ -172,13 +195,14 @@ func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, ma
}
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
slotKey := userSlotKey(userID, requestID)
return c.rdb.Del(ctx, slotKey).Err()
key := userSlotKey(userID)
return c.rdb.ZRem(ctx, key, requestID).Err()
}
func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
pattern := userSlotPattern(userID)
result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int()
key := userSlotKey(userID)
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取
result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int()
if err != nil {
return 0, err
}
......@@ -189,7 +213,7 @@ func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64)
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
key := waitQueueKey(userID)
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(slotTTL.Seconds())).Int()
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.slotTTLSeconds).Int()
if err != nil {
return false, err
}
......
package repository
import (
"context"
"fmt"
"os"
"testing"
"time"
"github.com/redis/go-redis/v9"
)
// 基准测试用 TTL 配置
const benchSlotTTLMinutes = 15
var benchSlotTTL = time.Duration(benchSlotTTLMinutes) * time.Minute
// BenchmarkAccountConcurrency 用于对比 SCAN 与有序集合的计数性能。
func BenchmarkAccountConcurrency(b *testing.B) {
rdb := newBenchmarkRedisClient(b)
defer func() {
_ = rdb.Close()
}()
cache := NewConcurrencyCache(rdb, benchSlotTTLMinutes).(*concurrencyCache)
ctx := context.Background()
for _, size := range []int{10, 100, 1000} {
size := size
b.Run(fmt.Sprintf("zset/slots=%d", size), func(b *testing.B) {
accountID := time.Now().UnixNano()
key := accountSlotKey(accountID)
b.StopTimer()
members := make([]redis.Z, 0, size)
now := float64(time.Now().Unix())
for i := 0; i < size; i++ {
members = append(members, redis.Z{
Score: now,
Member: fmt.Sprintf("req_%d", i),
})
}
if err := rdb.ZAdd(ctx, key, members...).Err(); err != nil {
b.Fatalf("初始化有序集合失败: %v", err)
}
if err := rdb.Expire(ctx, key, benchSlotTTL).Err(); err != nil {
b.Fatalf("设置有序集合 TTL 失败: %v", err)
}
b.StartTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
if _, err := cache.GetAccountConcurrency(ctx, accountID); err != nil {
b.Fatalf("获取并发数量失败: %v", err)
}
}
b.StopTimer()
if err := rdb.Del(ctx, key).Err(); err != nil {
b.Fatalf("清理有序集合失败: %v", err)
}
})
b.Run(fmt.Sprintf("scan/slots=%d", size), func(b *testing.B) {
accountID := time.Now().UnixNano()
pattern := fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID)
keys := make([]string, 0, size)
b.StopTimer()
pipe := rdb.Pipeline()
for i := 0; i < size; i++ {
key := fmt.Sprintf("%s%d:req_%d", accountSlotKeyPrefix, accountID, i)
keys = append(keys, key)
pipe.Set(ctx, key, "1", benchSlotTTL)
}
if _, err := pipe.Exec(ctx); err != nil {
b.Fatalf("初始化扫描键失败: %v", err)
}
b.StartTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
if _, err := scanSlotCount(ctx, rdb, pattern); err != nil {
b.Fatalf("SCAN 计数失败: %v", err)
}
}
b.StopTimer()
if err := rdb.Del(ctx, keys...).Err(); err != nil {
b.Fatalf("清理扫描键失败: %v", err)
}
})
}
}
func scanSlotCount(ctx context.Context, rdb *redis.Client, pattern string) (int, error) {
var cursor uint64
count := 0
for {
keys, nextCursor, err := rdb.Scan(ctx, cursor, pattern, 100).Result()
if err != nil {
return 0, err
}
count += len(keys)
if nextCursor == 0 {
break
}
cursor = nextCursor
}
return count, nil
}
func newBenchmarkRedisClient(b *testing.B) *redis.Client {
b.Helper()
redisURL := os.Getenv("TEST_REDIS_URL")
if redisURL == "" {
b.Skip("未设置 TEST_REDIS_URL,跳过 Redis 基准测试")
}
opt, err := redis.ParseURL(redisURL)
if err != nil {
b.Fatalf("解析 TEST_REDIS_URL 失败: %v", err)
}
client := redis.NewClient(opt)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
if err := client.Ping(ctx).Err(); err != nil {
b.Fatalf("Redis 连接失败: %v", err)
}
return client
}
......@@ -14,6 +14,12 @@ import (
"github.com/stretchr/testify/suite"
)
// 测试用 TTL 配置(15 分钟,与默认值一致)
const testSlotTTLMinutes = 15
// 测试用 TTL Duration,用于 TTL 断言
var testSlotTTL = time.Duration(testSlotTTLMinutes) * time.Minute
type ConcurrencyCacheSuite struct {
IntegrationRedisSuite
cache service.ConcurrencyCache
......@@ -21,7 +27,7 @@ type ConcurrencyCacheSuite struct {
func (s *ConcurrencyCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewConcurrencyCache(s.rdb)
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes)
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
......@@ -54,7 +60,7 @@ func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() {
accountID := int64(11)
reqID := "req_ttl_test"
slotKey := fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, reqID)
slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, reqID)
require.NoError(s.T(), err, "AcquireAccountSlot")
......@@ -62,7 +68,7 @@ func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() {
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_DuplicateReqID() {
......@@ -139,7 +145,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_AcquireAndRelease() {
func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
userID := int64(200)
reqID := "req_ttl_test"
slotKey := fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, reqID)
slotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 5, reqID)
require.NoError(s.T(), err, "AcquireUserSlot")
......@@ -147,7 +153,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
}
func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
......@@ -168,7 +174,7 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
require.NoError(s.T(), err, "TTL waitKey")
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment