package repository import ( "context" "encoding/json" "fmt" "strconv" "time" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/redis/go-redis/v9" ) const ( schedulerBucketSetKey = "sched:buckets" schedulerOutboxWatermarkKey = "sched:outbox:watermark" schedulerAccountPrefix = "sched:acc:" schedulerAccountMetaPrefix = "sched:meta:" schedulerActivePrefix = "sched:active:" schedulerReadyPrefix = "sched:ready:" schedulerVersionPrefix = "sched:ver:" schedulerSnapshotPrefix = "sched:" schedulerLockPrefix = "sched:lock:" defaultSchedulerSnapshotMGetChunkSize = 128 defaultSchedulerSnapshotWriteChunkSize = 256 ) type schedulerCache struct { rdb *redis.Client mgetChunkSize int writeChunkSize int } func NewSchedulerCache(rdb *redis.Client) service.SchedulerCache { return newSchedulerCacheWithChunkSizes(rdb, defaultSchedulerSnapshotMGetChunkSize, defaultSchedulerSnapshotWriteChunkSize) } func newSchedulerCacheWithChunkSizes(rdb *redis.Client, mgetChunkSize, writeChunkSize int) service.SchedulerCache { if mgetChunkSize <= 0 { mgetChunkSize = defaultSchedulerSnapshotMGetChunkSize } if writeChunkSize <= 0 { writeChunkSize = defaultSchedulerSnapshotWriteChunkSize } return &schedulerCache{ rdb: rdb, mgetChunkSize: mgetChunkSize, writeChunkSize: writeChunkSize, } } func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) { readyKey := schedulerBucketKey(schedulerReadyPrefix, bucket) readyVal, err := c.rdb.Get(ctx, readyKey).Result() if err == redis.Nil { return nil, false, nil } if err != nil { return nil, false, err } if readyVal != "1" { return nil, false, nil } activeKey := schedulerBucketKey(schedulerActivePrefix, bucket) activeVal, err := c.rdb.Get(ctx, activeKey).Result() if err == redis.Nil { return nil, false, nil } if err != nil { return nil, false, err } snapshotKey := schedulerSnapshotKey(bucket, activeVal) ids, err := c.rdb.ZRange(ctx, snapshotKey, 0, -1).Result() if err != nil { return nil, false, err } if len(ids) == 0 { // 空快照视为缓存未命中,触发数据库回退查询 // 这解决了新分组创建后立即绑定账号时的竞态条件问题 return nil, false, nil } keys := make([]string, 0, len(ids)) for _, id := range ids { keys = append(keys, schedulerAccountMetaKey(id)) } values, err := c.mgetChunked(ctx, keys) if err != nil { return nil, false, err } accounts := make([]*service.Account, 0, len(values)) for _, val := range values { if val == nil { return nil, false, nil } account, err := decodeCachedAccount(val) if err != nil { return nil, false, err } accounts = append(accounts, account) } return accounts, true, nil } func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error { activeKey := schedulerBucketKey(schedulerActivePrefix, bucket) oldActive, _ := c.rdb.Get(ctx, activeKey).Result() versionKey := schedulerBucketKey(schedulerVersionPrefix, bucket) version, err := c.rdb.Incr(ctx, versionKey).Result() if err != nil { return err } versionStr := strconv.FormatInt(version, 10) snapshotKey := schedulerSnapshotKey(bucket, versionStr) if err := c.writeAccounts(ctx, accounts); err != nil { return err } pipe := c.rdb.Pipeline() if len(accounts) > 0 { // 使用序号作为 score,保持数据库返回的排序语义。 members := make([]redis.Z, 0, len(accounts)) for idx, account := range accounts { members = append(members, redis.Z{ Score: float64(idx), Member: strconv.FormatInt(account.ID, 10), }) } for start := 0; start < len(members); start += c.writeChunkSize { end := start + c.writeChunkSize if end > len(members) { end = len(members) } pipe.ZAdd(ctx, snapshotKey, members[start:end]...) } } else { pipe.Del(ctx, snapshotKey) } pipe.Set(ctx, activeKey, versionStr, 0) pipe.Set(ctx, schedulerBucketKey(schedulerReadyPrefix, bucket), "1", 0) pipe.SAdd(ctx, schedulerBucketSetKey, bucket.String()) if _, err := pipe.Exec(ctx); err != nil { return err } if oldActive != "" && oldActive != versionStr { _ = c.rdb.Del(ctx, schedulerSnapshotKey(bucket, oldActive)).Err() } return nil } func (c *schedulerCache) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) { key := schedulerAccountKey(strconv.FormatInt(accountID, 10)) val, err := c.rdb.Get(ctx, key).Result() if err == redis.Nil { return nil, nil } if err != nil { return nil, err } return decodeCachedAccount(val) } func (c *schedulerCache) SetAccount(ctx context.Context, account *service.Account) error { if account == nil || account.ID <= 0 { return nil } return c.writeAccounts(ctx, []service.Account{*account}) } func (c *schedulerCache) DeleteAccount(ctx context.Context, accountID int64) error { if accountID <= 0 { return nil } id := strconv.FormatInt(accountID, 10) return c.rdb.Del(ctx, schedulerAccountKey(id), schedulerAccountMetaKey(id)).Err() } func (c *schedulerCache) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error { if len(updates) == 0 { return nil } keys := make([]string, 0, len(updates)) ids := make([]int64, 0, len(updates)) for id := range updates { keys = append(keys, schedulerAccountKey(strconv.FormatInt(id, 10))) ids = append(ids, id) } values, err := c.mgetChunked(ctx, keys) if err != nil { return err } pipe := c.rdb.Pipeline() for i, val := range values { if val == nil { continue } account, err := decodeCachedAccount(val) if err != nil { return err } account.LastUsedAt = ptrTime(updates[ids[i]]) updated, err := json.Marshal(account) if err != nil { return err } metaPayload, err := json.Marshal(buildSchedulerMetadataAccount(*account)) if err != nil { return err } pipe.Set(ctx, keys[i], updated, 0) pipe.Set(ctx, schedulerAccountMetaKey(strconv.FormatInt(ids[i], 10)), metaPayload, 0) } _, err = pipe.Exec(ctx) return err } func (c *schedulerCache) TryLockBucket(ctx context.Context, bucket service.SchedulerBucket, ttl time.Duration) (bool, error) { key := schedulerBucketKey(schedulerLockPrefix, bucket) return c.rdb.SetNX(ctx, key, time.Now().UnixNano(), ttl).Result() } func (c *schedulerCache) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) { raw, err := c.rdb.SMembers(ctx, schedulerBucketSetKey).Result() if err != nil { return nil, err } out := make([]service.SchedulerBucket, 0, len(raw)) for _, entry := range raw { bucket, ok := service.ParseSchedulerBucket(entry) if !ok { continue } out = append(out, bucket) } return out, nil } func (c *schedulerCache) GetOutboxWatermark(ctx context.Context) (int64, error) { val, err := c.rdb.Get(ctx, schedulerOutboxWatermarkKey).Result() if err == redis.Nil { return 0, nil } if err != nil { return 0, err } id, err := strconv.ParseInt(val, 10, 64) if err != nil { return 0, err } return id, nil } func (c *schedulerCache) SetOutboxWatermark(ctx context.Context, id int64) error { return c.rdb.Set(ctx, schedulerOutboxWatermarkKey, strconv.FormatInt(id, 10), 0).Err() } func schedulerBucketKey(prefix string, bucket service.SchedulerBucket) string { return fmt.Sprintf("%s%d:%s:%s", prefix, bucket.GroupID, bucket.Platform, bucket.Mode) } func schedulerSnapshotKey(bucket service.SchedulerBucket, version string) string { return fmt.Sprintf("%s%d:%s:%s:v%s", schedulerSnapshotPrefix, bucket.GroupID, bucket.Platform, bucket.Mode, version) } func schedulerAccountKey(id string) string { return schedulerAccountPrefix + id } func schedulerAccountMetaKey(id string) string { return schedulerAccountMetaPrefix + id } func ptrTime(t time.Time) *time.Time { return &t } func decodeCachedAccount(val any) (*service.Account, error) { var payload []byte switch raw := val.(type) { case string: payload = []byte(raw) case []byte: payload = raw default: return nil, fmt.Errorf("unexpected account cache type: %T", val) } var account service.Account if err := json.Unmarshal(payload, &account); err != nil { return nil, err } return &account, nil } func (c *schedulerCache) writeAccounts(ctx context.Context, accounts []service.Account) error { if len(accounts) == 0 { return nil } pipe := c.rdb.Pipeline() pending := 0 flush := func() error { if pending == 0 { return nil } if _, err := pipe.Exec(ctx); err != nil { return err } pipe = c.rdb.Pipeline() pending = 0 return nil } for _, account := range accounts { fullPayload, err := json.Marshal(account) if err != nil { return err } metaPayload, err := json.Marshal(buildSchedulerMetadataAccount(account)) if err != nil { return err } id := strconv.FormatInt(account.ID, 10) pipe.Set(ctx, schedulerAccountKey(id), fullPayload, 0) pipe.Set(ctx, schedulerAccountMetaKey(id), metaPayload, 0) pending++ if pending >= c.writeChunkSize { if err := flush(); err != nil { return err } } } return flush() } func (c *schedulerCache) mgetChunked(ctx context.Context, keys []string) ([]any, error) { if len(keys) == 0 { return []any{}, nil } out := make([]any, 0, len(keys)) chunkSize := c.mgetChunkSize if chunkSize <= 0 { chunkSize = defaultSchedulerSnapshotMGetChunkSize } for start := 0; start < len(keys); start += chunkSize { end := start + chunkSize if end > len(keys) { end = len(keys) } part, err := c.rdb.MGet(ctx, keys[start:end]...).Result() if err != nil { return nil, err } out = append(out, part...) } return out, nil } func buildSchedulerMetadataAccount(account service.Account) service.Account { return service.Account{ ID: account.ID, Name: account.Name, Platform: account.Platform, Type: account.Type, Concurrency: account.Concurrency, Priority: account.Priority, RateMultiplier: account.RateMultiplier, Status: account.Status, LastUsedAt: account.LastUsedAt, ExpiresAt: account.ExpiresAt, AutoPauseOnExpired: account.AutoPauseOnExpired, Schedulable: account.Schedulable, RateLimitedAt: account.RateLimitedAt, RateLimitResetAt: account.RateLimitResetAt, OverloadUntil: account.OverloadUntil, TempUnschedulableUntil: account.TempUnschedulableUntil, TempUnschedulableReason: account.TempUnschedulableReason, SessionWindowStart: account.SessionWindowStart, SessionWindowEnd: account.SessionWindowEnd, SessionWindowStatus: account.SessionWindowStatus, Credentials: filterSchedulerCredentials(account.Credentials), Extra: filterSchedulerExtra(account.Extra), } } func filterSchedulerCredentials(credentials map[string]any) map[string]any { if len(credentials) == 0 { return nil } keys := []string{"model_mapping", "api_key", "project_id", "oauth_type"} filtered := make(map[string]any) for _, key := range keys { if value, ok := credentials[key]; ok && value != nil { filtered[key] = value } } if len(filtered) == 0 { return nil } return filtered } func filterSchedulerExtra(extra map[string]any) map[string]any { if len(extra) == 0 { return nil } keys := []string{ "mixed_scheduling", "window_cost_limit", "window_cost_sticky_reserve", "max_sessions", "session_idle_timeout_minutes", } filtered := make(map[string]any) for _, key := range keys { if value, ok := extra[key]; ok && value != nil { filtered[key] = value } } if len(filtered) == 0 { return nil } return filtered }