Commit a88698f3 authored by erio's avatar erio
Browse files

feat: cleanup stale concurrency slots on startup



When the service restarts, concurrency slots held by the old process
remain in Redis, causing phantom occupancy. On startup, scan all
concurrency sorted sets and remove every member whose prefix does not
match the current process, then clear the orphaned wait queue counters.

Uses Go-side SCAN to discover keys (compatible with Redis client
prefix hooks in tests), then passes them to a Lua script for
atomic member-level cleanup.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
parent c8eff343
...@@ -127,6 +127,7 @@ func (f *fakeConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, acc ...@@ -127,6 +127,7 @@ func (f *fakeConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, acc
return result, nil return result, nil
} }
func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil } func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil }
// CleanupStaleProcessSlots is a no-op on the fake cache: it ignores both
// arguments and always reports success.
func (f *fakeConcurrencyCache) CleanupStaleProcessSlots(_ context.Context, _ string) error {
	return nil
}
func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) { func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) {
t.Helper() t.Helper()
......
...@@ -89,6 +89,10 @@ func (m *concurrencyCacheMock) CleanupExpiredAccountSlots(ctx context.Context, a ...@@ -89,6 +89,10 @@ func (m *concurrencyCacheMock) CleanupExpiredAccountSlots(ctx context.Context, a
return nil return nil
} }
// CleanupStaleProcessSlots is a no-op on the mock: it ignores both arguments
// and always reports success.
func (m *concurrencyCacheMock) CleanupStaleProcessSlots(_ context.Context, _ string) error {
	return nil
}
func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) { func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) {
cache := &concurrencyCacheMock{ cache := &concurrencyCacheMock{
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
......
...@@ -120,6 +120,10 @@ func (s *helperConcurrencyCacheStub) CleanupExpiredAccountSlots(ctx context.Cont ...@@ -120,6 +120,10 @@ func (s *helperConcurrencyCacheStub) CleanupExpiredAccountSlots(ctx context.Cont
return nil return nil
} }
// CleanupStaleProcessSlots is a no-op on the stub: it ignores both arguments
// and always reports success.
func (s *helperConcurrencyCacheStub) CleanupStaleProcessSlots(_ context.Context, _ string) error {
	return nil
}
func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) { func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
......
...@@ -147,17 +147,47 @@ var ( ...@@ -147,17 +147,47 @@ var (
return 1 return 1
`) `)
// cleanupExpiredSlotsScript - remove expired slots // cleanupExpiredSlotsScript 清理单个账号/用户有序集合中过期槽位
// KEYS[1] = concurrency:account:{accountID} // KEYS[1] = 有序集合键
// ARGV[1] = TTL (seconds) // ARGV[1] = TTL(秒)
cleanupExpiredSlotsScript = redis.NewScript(` cleanupExpiredSlotsScript = redis.NewScript(`
local key = KEYS[1] local key = KEYS[1]
local ttl = tonumber(ARGV[1]) local ttl = tonumber(ARGV[1])
local timeResult = redis.call('TIME') local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1]) local now = tonumber(timeResult[1])
local expireBefore = now - ttl local expireBefore = now - ttl
return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore) redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
`) if redis.call('ZCARD', key) == 0 then
redis.call('DEL', key)
else
redis.call('EXPIRE', key, ttl)
end
return 1
`)
// startupCleanupScript removes slot members that do not carry the current process prefix.
// KEYS is the list of sorted-set keys, ARGV[1] is the current process prefix, ARGV[2] is the slot TTL.
// For each KEYS[i] it removes every member whose leading characters differ from the prefix,
// deletes the key when it ends up empty, and otherwise refreshes its EXPIRE.
// The script returns the total number of members removed across all keys.
startupCleanupScript = redis.NewScript(`
local activePrefix = ARGV[1]
local slotTTL = tonumber(ARGV[2])
local removed = 0
for i = 1, #KEYS do
local key = KEYS[i]
local members = redis.call('ZRANGE', key, 0, -1)
for _, member in ipairs(members) do
if string.sub(member, 1, string.len(activePrefix)) ~= activePrefix then
removed = removed + redis.call('ZREM', key, member)
end
end
if redis.call('ZCARD', key) == 0 then
redis.call('DEL', key)
else
redis.call('EXPIRE', key, slotTTL)
end
end
return removed
`)
) )
type concurrencyCache struct { type concurrencyCache struct {
...@@ -463,3 +493,72 @@ func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accou ...@@ -463,3 +493,72 @@ func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accou
_, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result() _, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result()
return err return err
} }
// CleanupStaleProcessSlots purges concurrency state that does not belong to
// the process identified by activeRequestPrefix.
//
// It first strips members without that prefix from every account and user
// slot sorted set, then deletes every account and user wait-queue counter.
// An empty prefix is treated as a no-op. The first error encountered aborts
// the cleanup and is returned.
func (c *concurrencyCache) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
	if len(activeRequestPrefix) == 0 {
		return nil
	}

	// Step 1: drop sorted-set members that lack the current process prefix.
	for _, keyPrefix := range []string{accountSlotKeyPrefix, userSlotKeyPrefix} {
		err := c.cleanupSlotsByPattern(ctx, keyPrefix+"*", activeRequestPrefix)
		if err != nil {
			return err
		}
	}

	// Step 2: delete all wait-queue counters; they are no longer valid after a restart.
	for _, keyPrefix := range []string{accountWaitKeyPrefix, waitQueueKeyPrefix} {
		err := c.deleteKeysByPattern(ctx, keyPrefix+"*")
		if err != nil {
			return err
		}
	}

	return nil
}
// cleanupSlotsByPattern walks every key matching pattern with SCAN and hands
// each returned batch of keys to startupCleanupScript, which removes the
// members that do not start with activePrefix.
//
// Scan and script failures are wrapped with the pattern and returned
// immediately; the walk ends when SCAN returns cursor 0.
func (c *concurrencyCache) cleanupSlotsByPattern(ctx context.Context, pattern, activePrefix string) error {
	const batchHint = 200

	cursor := uint64(0)
	for {
		batch, next, err := c.rdb.Scan(ctx, cursor, pattern, batchHint).Result()
		if err != nil {
			return fmt.Errorf("scan %s: %w", pattern, err)
		}

		// SCAN may return an empty page with a non-zero cursor; only run the
		// script when there is something to clean.
		if len(batch) != 0 {
			if _, runErr := startupCleanupScript.Run(ctx, c.rdb, batch, activePrefix, c.slotTTLSeconds).Result(); runErr != nil {
				return fmt.Errorf("cleanup slots %s: %w", pattern, runErr)
			}
		}

		if next == 0 {
			return nil
		}
		cursor = next
	}
}
// deleteKeysByPattern walks every key matching pattern with SCAN and deletes
// each returned batch with a single DEL call.
//
// Scan and delete failures are wrapped with the pattern and returned
// immediately; the walk ends when SCAN returns cursor 0.
func (c *concurrencyCache) deleteKeysByPattern(ctx context.Context, pattern string) error {
	const batchHint = 200

	for cursor := uint64(0); ; {
		batch, next, err := c.rdb.Scan(ctx, cursor, pattern, batchHint).Result()
		if err != nil {
			return fmt.Errorf("scan %s: %w", pattern, err)
		}

		// Skip DEL for empty pages; SCAN may still have more to return.
		if len(batch) != 0 {
			delErr := c.rdb.Del(ctx, batch...).Err()
			if delErr != nil {
				return fmt.Errorf("del %s: %w", pattern, delErr)
			}
		}

		if next == 0 {
			return nil
		}
		cursor = next
	}
}
...@@ -25,6 +25,10 @@ type ConcurrencyCacheSuite struct { ...@@ -25,6 +25,10 @@ type ConcurrencyCacheSuite struct {
cache service.ConcurrencyCache cache service.ConcurrencyCache
} }
func TestConcurrencyCacheSuite(t *testing.T) {
suite.Run(t, new(ConcurrencyCacheSuite))
}
func (s *ConcurrencyCacheSuite) SetupTest() { func (s *ConcurrencyCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest() s.IntegrationRedisSuite.SetupTest()
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds())) s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds()))
...@@ -247,17 +251,41 @@ func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() { ...@@ -247,17 +251,41 @@ func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {
require.Equal(s.T(), 1, val, "expected account wait count 1") require.Equal(s.T(), 1, val, "expected account wait count 1")
} }
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_DecrementNoNegative() { func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots() {
accountID := int64(301) accountID := int64(901)
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID) userID := int64(902)
accountKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
userKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount on non-existent key") now := time.Now().Unix()
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountKey,
redis.Z{Score: float64(now), Member: "oldproc-1"},
redis.Z{Score: float64(now), Member: "keep-1"},
).Err())
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userKey,
redis.Z{Score: float64(now), Member: "oldproc-2"},
redis.Z{Score: float64(now), Member: "keep-2"},
).Err())
require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, time.Minute).Err())
require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, time.Minute).Err())
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "keep-"))
accountMembers, err := s.rdb.ZRange(s.ctx, accountKey, 0, -1).Result()
require.NoError(s.T(), err)
require.Equal(s.T(), []string{"keep-1"}, accountMembers)
val, err := s.rdb.Get(s.ctx, waitKey).Int() userMembers, err := s.rdb.ZRange(s.ctx, userKey, 0, -1).Result()
if !errors.Is(err, redis.Nil) { require.NoError(s.T(), err)
require.NoError(s.T(), err, "Get waitKey") require.Equal(s.T(), []string{"keep-2"}, userMembers)
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative account wait count after decrement on empty") _, err = s.rdb.Get(s.ctx, userWaitKey).Result()
require.True(s.T(), errors.Is(err, redis.Nil))
_, err = s.rdb.Get(s.ctx, accountWaitKey).Result()
require.True(s.T(), errors.Is(err, redis.Nil))
} }
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() { func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
...@@ -407,6 +435,53 @@ func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() { ...@@ -407,6 +435,53 @@ func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() {
require.Equal(s.T(), 2, cur) require.Equal(s.T(), 2, cur)
} }
func TestConcurrencyCacheSuite(t *testing.T) { func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_RemovesOldPrefixesAndWaitCounters() {
suite.Run(t, new(ConcurrencyCacheSuite)) accountID := int64(901)
userID := int64(902)
accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
userSlotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
now := float64(time.Now().Unix())
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey,
redis.Z{Score: now, Member: "oldproc-1"},
redis.Z{Score: now, Member: "activeproc-1"},
).Err())
require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err())
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userSlotKey,
redis.Z{Score: now, Member: "oldproc-2"},
redis.Z{Score: now, Member: "activeproc-2"},
).Err())
require.NoError(s.T(), s.rdb.Expire(s.ctx, userSlotKey, testSlotTTL).Err())
require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, testSlotTTL).Err())
require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, testSlotTTL).Err())
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-"))
accountMembers, err := s.rdb.ZRange(s.ctx, accountSlotKey, 0, -1).Result()
require.NoError(s.T(), err)
require.Equal(s.T(), []string{"activeproc-1"}, accountMembers)
userMembers, err := s.rdb.ZRange(s.ctx, userSlotKey, 0, -1).Result()
require.NoError(s.T(), err)
require.Equal(s.T(), []string{"activeproc-2"}, userMembers)
_, err = s.rdb.Get(s.ctx, userWaitKey).Result()
require.ErrorIs(s.T(), err, redis.Nil)
_, err = s.rdb.Get(s.ctx, accountWaitKey).Result()
require.ErrorIs(s.T(), err, redis.Nil)
}
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_DeletesEmptySlotKeys() {
accountID := int64(903)
accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey, redis.Z{Score: float64(time.Now().Unix()), Member: "oldproc-1"}).Err())
require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err())
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-"))
exists, err := s.rdb.Exists(s.ctx, accountSlotKey).Result()
require.NoError(s.T(), err)
require.EqualValues(s.T(), 0, exists)
} }
...@@ -43,6 +43,9 @@ type ConcurrencyCache interface { ...@@ -43,6 +43,9 @@ type ConcurrencyCache interface {
// 清理过期槽位(后台任务) // 清理过期槽位(后台任务)
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
// 启动时清理旧进程遗留槽位与等待计数
CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error
} }
var ( var (
...@@ -59,13 +62,22 @@ func initRequestIDPrefix() string { ...@@ -59,13 +62,22 @@ func initRequestIDPrefix() string {
return "r" + strconv.FormatUint(fallback, 36) return "r" + strconv.FormatUint(fallback, 36)
} }
// generateRequestID generates a unique request ID for concurrency slot tracking. func RequestIDPrefix() string {
// Format: {process_random_prefix}-{base36_counter} return requestIDPrefix
}
func generateRequestID() string { func generateRequestID() string {
seq := requestIDCounter.Add(1) seq := requestIDCounter.Add(1)
return requestIDPrefix + "-" + strconv.FormatUint(seq, 36) return requestIDPrefix + "-" + strconv.FormatUint(seq, 36)
} }
// CleanupStaleProcessSlots asks the cache to remove slots and wait counters
// that were not created under this process's request-ID prefix.
//
// It is safe to call on a nil service or a service without a cache; both
// cases return nil without doing anything.
//
// NOTE(review): request IDs are built as prefix + "-" + counter (see
// generateRequestID), but the bare prefix is passed here without the "-".
// Confirm that one process prefix can never be a leading substring of
// another, otherwise the stale process's members would be kept.
func (s *ConcurrencyService) CleanupStaleProcessSlots(ctx context.Context) error {
	if s == nil {
		return nil
	}
	cache := s.cache
	if cache == nil {
		return nil
	}
	activePrefix := RequestIDPrefix()
	return cache.CleanupStaleProcessSlots(ctx, activePrefix)
}
const ( const (
// Default extra wait slots beyond concurrency limit // Default extra wait slots beyond concurrency limit
defaultExtraWaitSlots = 20 defaultExtraWaitSlots = 20
......
...@@ -91,6 +91,32 @@ func (c *stubConcurrencyCacheForTest) CleanupExpiredAccountSlots(_ context.Conte ...@@ -91,6 +91,32 @@ func (c *stubConcurrencyCacheForTest) CleanupExpiredAccountSlots(_ context.Conte
return c.cleanupErr return c.cleanupErr
} }
// CleanupStaleProcessSlots ignores its arguments and returns the stub's
// configured cleanupErr.
func (c *stubConcurrencyCacheForTest) CleanupStaleProcessSlots(context.Context, string) error {
	err := c.cleanupErr
	return err
}
// trackingConcurrencyCache embeds stubConcurrencyCacheForTest and additionally
// records the prefix passed to CleanupStaleProcessSlots so tests can assert on it.
type trackingConcurrencyCache struct {
stubConcurrencyCacheForTest
// cleanupPrefix holds the prefix from the most recent CleanupStaleProcessSlots call.
cleanupPrefix string
}
// CleanupStaleProcessSlots stores prefix in cleanupPrefix and returns the
// embedded stub's configured cleanupErr.
func (c *trackingConcurrencyCache) CleanupStaleProcessSlots(_ context.Context, prefix string) error {
	result := c.cleanupErr
	c.cleanupPrefix = prefix
	return result
}
// TestCleanupStaleProcessSlots_NilCache verifies that cleanup on a service
// without a cache succeeds without error.
func TestCleanupStaleProcessSlots_NilCache(t *testing.T) {
	svc := &ConcurrencyService{cache: nil}

	err := svc.CleanupStaleProcessSlots(context.Background())

	require.NoError(t, err)
}
// TestCleanupStaleProcessSlots_DelegatesPrefix verifies that the service
// forwards RequestIDPrefix() to the cache's CleanupStaleProcessSlots.
func TestCleanupStaleProcessSlots_DelegatesPrefix(t *testing.T) {
	tracker := new(trackingConcurrencyCache)
	svc := NewConcurrencyService(tracker)

	err := svc.CleanupStaleProcessSlots(context.Background())

	require.NoError(t, err)
	require.Equal(t, RequestIDPrefix(), tracker.cleanupPrefix)
}
func TestAcquireAccountSlot_Success(t *testing.T) { func TestAcquireAccountSlot_Success(t *testing.T) {
cache := &stubConcurrencyCacheForTest{acquireResult: true} cache := &stubConcurrencyCacheForTest{acquireResult: true}
svc := NewConcurrencyService(cache) svc := NewConcurrencyService(cache)
......
...@@ -1986,6 +1986,10 @@ func (m *mockConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, a ...@@ -1986,6 +1986,10 @@ func (m *mockConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, a
return nil return nil
} }
// CleanupStaleProcessSlots is a no-op on the mock: it ignores both arguments
// and always reports success.
func (m *mockConcurrencyCache) CleanupStaleProcessSlots(_ context.Context, _ string) error {
	return nil
}
func (m *mockConcurrencyCache) GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error) { func (m *mockConcurrencyCache) GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error) {
result := make(map[int64]*UserLoadInfo, len(users)) result := make(map[int64]*UserLoadInfo, len(users))
for _, user := range users { for _, user := range users {
......
...@@ -105,6 +105,9 @@ func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWh ...@@ -105,6 +105,9 @@ func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWh
// ProvideConcurrencyService creates ConcurrencyService and starts slot cleanup worker. // ProvideConcurrencyService creates ConcurrencyService and starts slot cleanup worker.
func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountRepository, cfg *config.Config) *ConcurrencyService { func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountRepository, cfg *config.Config) *ConcurrencyService {
svc := NewConcurrencyService(cache) svc := NewConcurrencyService(cache)
if err := svc.CleanupStaleProcessSlots(context.Background()); err != nil {
logger.LegacyPrintf("service.concurrency", "Warning: startup cleanup stale process slots failed: %v", err)
}
if cfg != nil { if cfg != nil {
svc.StartSlotCleanupWorker(accountRepo, cfg.Gateway.Scheduling.SlotCleanupInterval) svc.StartSlotCleanupWorker(accountRepo, cfg.Gateway.Scheduling.SlotCleanupInterval)
} }
......
...@@ -76,6 +76,9 @@ func (c StubConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, acco ...@@ -76,6 +76,9 @@ func (c StubConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, acco
func (c StubConcurrencyCache) CleanupExpiredAccountSlots(_ context.Context, _ int64) error { func (c StubConcurrencyCache) CleanupExpiredAccountSlots(_ context.Context, _ int64) error {
return nil return nil
} }
// CleanupStaleProcessSlots is a no-op on the stub: it ignores both arguments
// and always reports success.
func (c StubConcurrencyCache) CleanupStaleProcessSlots(context.Context, string) error {
	return nil
}
// ============================================================ // ============================================================
// StubGatewayCache — service.GatewayCache 的空实现 // StubGatewayCache — service.GatewayCache 的空实现
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment