Unverified commit 1b79f6a7, authored by Wesley Liddick and committed by GitHub
Browse files

Merge pull request #1522 from xvhuan/fix/redis-snapshot-meta-fix

优化调度快照缓存,避免 1.5 万账号场景下 Redis 大 MGET
parents 155d3474 265687b5
...@@ -100,7 +100,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -100,7 +100,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
} }
dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig) dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig)
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService) dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
schedulerCache := repository.NewSchedulerCache(redisClient) schedulerCache := repository.ProvideSchedulerCache(redisClient, configConfig)
accountRepository := repository.NewAccountRepository(client, db, schedulerCache) accountRepository := repository.NewAccountRepository(client, db, schedulerCache)
proxyRepository := repository.NewProxyRepository(client, db) proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
......
...@@ -620,6 +620,10 @@ type GatewaySchedulingConfig struct { ...@@ -620,6 +620,10 @@ type GatewaySchedulingConfig struct {
// 负载计算 // 负载计算
LoadBatchEnabled bool `mapstructure:"load_batch_enabled"` LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
// 快照桶读取时的 MGET 分块大小
SnapshotMGetChunkSize int `mapstructure:"snapshot_mget_chunk_size"`
// 快照重建时的缓存写入分块大小
SnapshotWriteChunkSize int `mapstructure:"snapshot_write_chunk_size"`
// 过期槽位清理周期(0 表示禁用) // 过期槽位清理周期(0 表示禁用)
SlotCleanupInterval time.Duration `mapstructure:"slot_cleanup_interval"` SlotCleanupInterval time.Duration `mapstructure:"slot_cleanup_interval"`
...@@ -1340,6 +1344,8 @@ func setDefaults() { ...@@ -1340,6 +1344,8 @@ func setDefaults() {
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100) viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
viper.SetDefault("gateway.scheduling.fallback_selection_mode", "last_used") viper.SetDefault("gateway.scheduling.fallback_selection_mode", "last_used")
viper.SetDefault("gateway.scheduling.load_batch_enabled", true) viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
viper.SetDefault("gateway.scheduling.snapshot_mget_chunk_size", 128)
viper.SetDefault("gateway.scheduling.snapshot_write_chunk_size", 256)
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second) viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
viper.SetDefault("gateway.scheduling.db_fallback_enabled", true) viper.SetDefault("gateway.scheduling.db_fallback_enabled", true)
viper.SetDefault("gateway.scheduling.db_fallback_timeout_seconds", 0) viper.SetDefault("gateway.scheduling.db_fallback_timeout_seconds", 0)
...@@ -2001,6 +2007,12 @@ func (c *Config) Validate() error { ...@@ -2001,6 +2007,12 @@ func (c *Config) Validate() error {
if c.Gateway.Scheduling.FallbackMaxWaiting <= 0 { if c.Gateway.Scheduling.FallbackMaxWaiting <= 0 {
return fmt.Errorf("gateway.scheduling.fallback_max_waiting must be positive") return fmt.Errorf("gateway.scheduling.fallback_max_waiting must be positive")
} }
if c.Gateway.Scheduling.SnapshotMGetChunkSize <= 0 {
return fmt.Errorf("gateway.scheduling.snapshot_mget_chunk_size must be positive")
}
if c.Gateway.Scheduling.SnapshotWriteChunkSize <= 0 {
return fmt.Errorf("gateway.scheduling.snapshot_write_chunk_size must be positive")
}
if c.Gateway.Scheduling.SlotCleanupInterval < 0 { if c.Gateway.Scheduling.SlotCleanupInterval < 0 {
return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative") return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative")
} }
......
...@@ -34,7 +34,12 @@ func (f *fakeSchedulerCache) GetSnapshot(_ context.Context, _ service.SchedulerB ...@@ -34,7 +34,12 @@ func (f *fakeSchedulerCache) GetSnapshot(_ context.Context, _ service.SchedulerB
func (f *fakeSchedulerCache) SetSnapshot(_ context.Context, _ service.SchedulerBucket, _ []service.Account) error { func (f *fakeSchedulerCache) SetSnapshot(_ context.Context, _ service.SchedulerBucket, _ []service.Account) error {
return nil return nil
} }
func (f *fakeSchedulerCache) GetAccount(_ context.Context, _ int64) (*service.Account, error) { func (f *fakeSchedulerCache) GetAccount(_ context.Context, id int64) (*service.Account, error) {
for _, account := range f.accounts {
if account != nil && account.ID == id {
return account, nil
}
}
return nil, nil return nil, nil
} }
func (f *fakeSchedulerCache) SetAccount(_ context.Context, _ *service.Account) error { return nil } func (f *fakeSchedulerCache) SetAccount(_ context.Context, _ *service.Account) error { return nil }
......
...@@ -332,6 +332,10 @@ func (h prefixHook) prefixCmd(cmd redisclient.Cmder) { ...@@ -332,6 +332,10 @@ func (h prefixHook) prefixCmd(cmd redisclient.Cmder) {
"hgetall", "hget", "hset", "hdel", "hincrbyfloat", "exists", "hgetall", "hget", "hset", "hdel", "hincrbyfloat", "exists",
"zadd", "zcard", "zrange", "zrangebyscore", "zrem", "zremrangebyscore", "zrevrange", "zrevrangebyscore", "zscore": "zadd", "zcard", "zrange", "zrangebyscore", "zrem", "zremrangebyscore", "zrevrange", "zrevrangebyscore", "zscore":
prefixOne(1) prefixOne(1)
case "mget":
for i := 1; i < len(args); i++ {
prefixOne(i)
}
case "del", "unlink": case "del", "unlink":
for i := 1; i < len(args); i++ { for i := 1; i < len(args); i++ {
prefixOne(i) prefixOne(i)
......
...@@ -15,19 +15,39 @@ const ( ...@@ -15,19 +15,39 @@ const (
schedulerBucketSetKey = "sched:buckets" schedulerBucketSetKey = "sched:buckets"
schedulerOutboxWatermarkKey = "sched:outbox:watermark" schedulerOutboxWatermarkKey = "sched:outbox:watermark"
schedulerAccountPrefix = "sched:acc:" schedulerAccountPrefix = "sched:acc:"
schedulerAccountMetaPrefix = "sched:meta:"
schedulerActivePrefix = "sched:active:" schedulerActivePrefix = "sched:active:"
schedulerReadyPrefix = "sched:ready:" schedulerReadyPrefix = "sched:ready:"
schedulerVersionPrefix = "sched:ver:" schedulerVersionPrefix = "sched:ver:"
schedulerSnapshotPrefix = "sched:" schedulerSnapshotPrefix = "sched:"
schedulerLockPrefix = "sched:lock:" schedulerLockPrefix = "sched:lock:"
defaultSchedulerSnapshotMGetChunkSize = 128
defaultSchedulerSnapshotWriteChunkSize = 256
) )
type schedulerCache struct { type schedulerCache struct {
rdb *redis.Client rdb *redis.Client
mgetChunkSize int
writeChunkSize int
} }
func NewSchedulerCache(rdb *redis.Client) service.SchedulerCache { func NewSchedulerCache(rdb *redis.Client) service.SchedulerCache {
return &schedulerCache{rdb: rdb} return newSchedulerCacheWithChunkSizes(rdb, defaultSchedulerSnapshotMGetChunkSize, defaultSchedulerSnapshotWriteChunkSize)
}
func newSchedulerCacheWithChunkSizes(rdb *redis.Client, mgetChunkSize, writeChunkSize int) service.SchedulerCache {
if mgetChunkSize <= 0 {
mgetChunkSize = defaultSchedulerSnapshotMGetChunkSize
}
if writeChunkSize <= 0 {
writeChunkSize = defaultSchedulerSnapshotWriteChunkSize
}
return &schedulerCache{
rdb: rdb,
mgetChunkSize: mgetChunkSize,
writeChunkSize: writeChunkSize,
}
} }
func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) { func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
...@@ -65,9 +85,9 @@ func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.Schedul ...@@ -65,9 +85,9 @@ func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.Schedul
keys := make([]string, 0, len(ids)) keys := make([]string, 0, len(ids))
for _, id := range ids { for _, id := range ids {
keys = append(keys, schedulerAccountKey(id)) keys = append(keys, schedulerAccountMetaKey(id))
} }
values, err := c.rdb.MGet(ctx, keys...).Result() values, err := c.mgetChunked(ctx, keys)
if err != nil { if err != nil {
return nil, false, err return nil, false, err
} }
...@@ -100,14 +120,11 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul ...@@ -100,14 +120,11 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul
versionStr := strconv.FormatInt(version, 10) versionStr := strconv.FormatInt(version, 10)
snapshotKey := schedulerSnapshotKey(bucket, versionStr) snapshotKey := schedulerSnapshotKey(bucket, versionStr)
pipe := c.rdb.Pipeline() if err := c.writeAccounts(ctx, accounts); err != nil {
for _, account := range accounts { return err
payload, err := json.Marshal(account)
if err != nil {
return err
}
pipe.Set(ctx, schedulerAccountKey(strconv.FormatInt(account.ID, 10)), payload, 0)
} }
pipe := c.rdb.Pipeline()
if len(accounts) > 0 { if len(accounts) > 0 {
// 使用序号作为 score,保持数据库返回的排序语义。 // 使用序号作为 score,保持数据库返回的排序语义。
members := make([]redis.Z, 0, len(accounts)) members := make([]redis.Z, 0, len(accounts))
...@@ -117,7 +134,13 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul ...@@ -117,7 +134,13 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul
Member: strconv.FormatInt(account.ID, 10), Member: strconv.FormatInt(account.ID, 10),
}) })
} }
pipe.ZAdd(ctx, snapshotKey, members...) for start := 0; start < len(members); start += c.writeChunkSize {
end := start + c.writeChunkSize
if end > len(members) {
end = len(members)
}
pipe.ZAdd(ctx, snapshotKey, members[start:end]...)
}
} else { } else {
pipe.Del(ctx, snapshotKey) pipe.Del(ctx, snapshotKey)
} }
...@@ -151,20 +174,15 @@ func (c *schedulerCache) SetAccount(ctx context.Context, account *service.Accoun ...@@ -151,20 +174,15 @@ func (c *schedulerCache) SetAccount(ctx context.Context, account *service.Accoun
if account == nil || account.ID <= 0 { if account == nil || account.ID <= 0 {
return nil return nil
} }
payload, err := json.Marshal(account) return c.writeAccounts(ctx, []service.Account{*account})
if err != nil {
return err
}
key := schedulerAccountKey(strconv.FormatInt(account.ID, 10))
return c.rdb.Set(ctx, key, payload, 0).Err()
} }
func (c *schedulerCache) DeleteAccount(ctx context.Context, accountID int64) error { func (c *schedulerCache) DeleteAccount(ctx context.Context, accountID int64) error {
if accountID <= 0 { if accountID <= 0 {
return nil return nil
} }
key := schedulerAccountKey(strconv.FormatInt(accountID, 10)) id := strconv.FormatInt(accountID, 10)
return c.rdb.Del(ctx, key).Err() return c.rdb.Del(ctx, schedulerAccountKey(id), schedulerAccountMetaKey(id)).Err()
} }
func (c *schedulerCache) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error { func (c *schedulerCache) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
...@@ -179,7 +197,7 @@ func (c *schedulerCache) UpdateLastUsed(ctx context.Context, updates map[int64]t ...@@ -179,7 +197,7 @@ func (c *schedulerCache) UpdateLastUsed(ctx context.Context, updates map[int64]t
ids = append(ids, id) ids = append(ids, id)
} }
values, err := c.rdb.MGet(ctx, keys...).Result() values, err := c.mgetChunked(ctx, keys)
if err != nil { if err != nil {
return err return err
} }
...@@ -198,7 +216,12 @@ func (c *schedulerCache) UpdateLastUsed(ctx context.Context, updates map[int64]t ...@@ -198,7 +216,12 @@ func (c *schedulerCache) UpdateLastUsed(ctx context.Context, updates map[int64]t
if err != nil { if err != nil {
return err return err
} }
metaPayload, err := json.Marshal(buildSchedulerMetadataAccount(*account))
if err != nil {
return err
}
pipe.Set(ctx, keys[i], updated, 0) pipe.Set(ctx, keys[i], updated, 0)
pipe.Set(ctx, schedulerAccountMetaKey(strconv.FormatInt(ids[i], 10)), metaPayload, 0)
} }
_, err = pipe.Exec(ctx) _, err = pipe.Exec(ctx)
return err return err
...@@ -256,6 +279,10 @@ func schedulerAccountKey(id string) string { ...@@ -256,6 +279,10 @@ func schedulerAccountKey(id string) string {
return schedulerAccountPrefix + id return schedulerAccountPrefix + id
} }
func schedulerAccountMetaKey(id string) string {
return schedulerAccountMetaPrefix + id
}
func ptrTime(t time.Time) *time.Time { func ptrTime(t time.Time) *time.Time {
return &t return &t
} }
...@@ -276,3 +303,137 @@ func decodeCachedAccount(val any) (*service.Account, error) { ...@@ -276,3 +303,137 @@ func decodeCachedAccount(val any) (*service.Account, error) {
} }
return &account, nil return &account, nil
} }
// writeAccounts stores every account in Redis under two keys: the full JSON
// payload under schedulerAccountKey and the slimmed scheduler metadata
// (see buildSchedulerMetadataAccount) under schedulerAccountMetaKey. Both keys
// are written with an expiration argument of 0.
//
// Writes are pipelined and flushed each time the configured number of accounts
// has been queued (two SET commands per account), so a large rebuild is sent
// as several bounded pipelines instead of one huge one.
//
// The first marshal or Exec error aborts the write and is returned unchanged.
// Chunks that were flushed before the failure are not rolled back; commands
// still queued in the current pipeline are never sent.
func (c *schedulerCache) writeAccounts(ctx context.Context, accounts []service.Account) error {
	if len(accounts) == 0 {
		return nil
	}

	// Same defensive fallback as mgetChunked: a zero or negative chunk size
	// would otherwise make the threshold check below always true and cost one
	// round trip per account.
	chunkSize := c.writeChunkSize
	if chunkSize <= 0 {
		chunkSize = defaultSchedulerSnapshotWriteChunkSize
	}

	pipe := c.rdb.Pipeline()
	pending := 0

	// flush sends the queued commands (if any) and starts a fresh pipeline.
	flush := func() error {
		if pending == 0 {
			return nil
		}
		if _, err := pipe.Exec(ctx); err != nil {
			return err
		}
		pipe = c.rdb.Pipeline()
		pending = 0
		return nil
	}

	for _, account := range accounts {
		fullPayload, err := json.Marshal(account)
		if err != nil {
			return err
		}
		metaPayload, err := json.Marshal(buildSchedulerMetadataAccount(account))
		if err != nil {
			return err
		}

		id := strconv.FormatInt(account.ID, 10)
		pipe.Set(ctx, schedulerAccountKey(id), fullPayload, 0)
		pipe.Set(ctx, schedulerAccountMetaKey(id), metaPayload, 0)

		// pending counts accounts, not individual commands.
		pending++
		if pending >= chunkSize {
			if err := flush(); err != nil {
				return err
			}
		}
	}

	// Send whatever is left over from the final, partially filled chunk.
	return flush()
}
// mgetChunked reads keys with a series of MGET calls, each covering at most
// mgetChunkSize keys (or the package default when that value is not positive),
// and returns the per-call results appended in request order.
//
// An empty key list yields an empty, non-nil slice without touching Redis.
// The first failing MGET aborts the read; values gathered so far are dropped
// and the error is returned unchanged.
func (c *schedulerCache) mgetChunked(ctx context.Context, keys []string) ([]any, error) {
	total := len(keys)
	if total == 0 {
		return []any{}, nil
	}

	step := c.mgetChunkSize
	if step <= 0 {
		step = defaultSchedulerSnapshotMGetChunkSize
	}

	values := make([]any, 0, total)
	for remaining := keys; len(remaining) > 0; {
		// Take a full step, or whatever is left for the final batch.
		n := step
		if n > len(remaining) {
			n = len(remaining)
		}

		batch, err := c.rdb.MGet(ctx, remaining[:n]...).Result()
		if err != nil {
			return nil, err
		}
		values = append(values, batch...)
		remaining = remaining[n:]
	}
	return values, nil
}
// buildSchedulerMetadataAccount returns a reduced copy of account that carries
// only the fields copied below. Credentials and Extra are passed through
// filterSchedulerCredentials and filterSchedulerExtra respectively; every
// field that is not assigned here is left at its zero value.
func buildSchedulerMetadataAccount(account service.Account) service.Account {
	var meta service.Account

	// Identity and classification.
	meta.ID = account.ID
	meta.Name = account.Name
	meta.Platform = account.Platform
	meta.Type = account.Type

	// Capacity and ordering inputs.
	meta.Concurrency = account.Concurrency
	meta.Priority = account.Priority
	meta.RateMultiplier = account.RateMultiplier

	// Availability state.
	meta.Status = account.Status
	meta.LastUsedAt = account.LastUsedAt
	meta.ExpiresAt = account.ExpiresAt
	meta.AutoPauseOnExpired = account.AutoPauseOnExpired
	meta.Schedulable = account.Schedulable
	meta.RateLimitedAt = account.RateLimitedAt
	meta.RateLimitResetAt = account.RateLimitResetAt
	meta.OverloadUntil = account.OverloadUntil
	meta.TempUnschedulableUntil = account.TempUnschedulableUntil
	meta.TempUnschedulableReason = account.TempUnschedulableReason

	// Session window.
	meta.SessionWindowStart = account.SessionWindowStart
	meta.SessionWindowEnd = account.SessionWindowEnd
	meta.SessionWindowStatus = account.SessionWindowStatus

	// Free-form maps are trimmed to an allow-list of keys.
	meta.Credentials = filterSchedulerCredentials(account.Credentials)
	meta.Extra = filterSchedulerExtra(account.Extra)

	return meta
}
// filterSchedulerCredentials returns a new map holding only the allow-listed
// credential keys ("model_mapping", "api_key", "project_id", "oauth_type")
// that are present in credentials with a non-nil value. The input map is not
// modified. It returns nil when the input is empty or none of the allow-listed
// keys qualify.
func filterSchedulerCredentials(credentials map[string]any) map[string]any {
	if len(credentials) == 0 {
		return nil
	}

	// Allocated lazily so that "nothing kept" naturally comes back as nil.
	var kept map[string]any
	for _, name := range [...]string{"model_mapping", "api_key", "project_id", "oauth_type"} {
		value, present := credentials[name]
		if !present || value == nil {
			continue
		}
		if kept == nil {
			kept = make(map[string]any)
		}
		kept[name] = value
	}
	return kept
}
// filterSchedulerExtra returns a new map holding only the allow-listed extra
// keys ("mixed_scheduling", "window_cost_limit", "window_cost_sticky_reserve",
// "max_sessions", "session_idle_timeout_minutes") that are present in extra
// with a non-nil value. The input map is not modified. It returns nil when the
// input is empty or none of the allow-listed keys qualify.
func filterSchedulerExtra(extra map[string]any) map[string]any {
	if len(extra) == 0 {
		return nil
	}

	allowed := [...]string{
		"mixed_scheduling",
		"window_cost_limit",
		"window_cost_sticky_reserve",
		"max_sessions",
		"session_idle_timeout_minutes",
	}

	// Allocated on first hit so an all-miss scan returns nil.
	var kept map[string]any
	for i := range allowed {
		name := allowed[i]
		if value, present := extra[name]; present && value != nil {
			if kept == nil {
				kept = make(map[string]any)
			}
			kept[name] = value
		}
	}
	return kept
}
//go:build integration
package repository
import (
"context"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func TestSchedulerCacheSnapshotUsesSlimMetadataButKeepsFullAccount(t *testing.T) {
ctx := context.Background()
rdb := testRedis(t)
cache := NewSchedulerCache(rdb)
bucket := service.SchedulerBucket{GroupID: 2, Platform: service.PlatformGemini, Mode: service.SchedulerModeSingle}
now := time.Now().UTC().Truncate(time.Second)
limitReset := now.Add(10 * time.Minute)
overloadUntil := now.Add(2 * time.Minute)
tempUnschedUntil := now.Add(3 * time.Minute)
windowEnd := now.Add(5 * time.Hour)
account := service.Account{
ID: 101,
Name: "gemini-heavy",
Platform: service.PlatformGemini,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Schedulable: true,
Concurrency: 3,
Priority: 7,
LastUsedAt: &now,
Credentials: map[string]any{
"api_key": "gemini-api-key",
"access_token": "secret-access-token",
"project_id": "proj-1",
"oauth_type": "ai_studio",
"model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"},
"huge_blob": strings.Repeat("x", 4096),
},
Extra: map[string]any{
"mixed_scheduling": true,
"window_cost_limit": 12.5,
"window_cost_sticky_reserve": 8.0,
"max_sessions": 4,
"session_idle_timeout_minutes": 11,
"unused_large_field": strings.Repeat("y", 4096),
},
RateLimitResetAt: &limitReset,
OverloadUntil: &overloadUntil,
TempUnschedulableUntil: &tempUnschedUntil,
SessionWindowStart: &now,
SessionWindowEnd: &windowEnd,
SessionWindowStatus: "active",
}
require.NoError(t, cache.SetSnapshot(ctx, bucket, []service.Account{account}))
snapshot, hit, err := cache.GetSnapshot(ctx, bucket)
require.NoError(t, err)
require.True(t, hit)
require.Len(t, snapshot, 1)
got := snapshot[0]
require.NotNil(t, got)
require.Equal(t, "gemini-api-key", got.GetCredential("api_key"))
require.Equal(t, "proj-1", got.GetCredential("project_id"))
require.Equal(t, "ai_studio", got.GetCredential("oauth_type"))
require.NotEmpty(t, got.GetModelMapping())
require.Empty(t, got.GetCredential("access_token"))
require.Empty(t, got.GetCredential("huge_blob"))
require.Equal(t, true, got.Extra["mixed_scheduling"])
require.Equal(t, 12.5, got.GetWindowCostLimit())
require.Equal(t, 8.0, got.GetWindowCostStickyReserve())
require.Equal(t, 4, got.GetMaxSessions())
require.Equal(t, 11, got.GetSessionIdleTimeoutMinutes())
require.Nil(t, got.Extra["unused_large_field"])
full, err := cache.GetAccount(ctx, account.ID)
require.NoError(t, err)
require.NotNil(t, full)
require.Equal(t, "secret-access-token", full.GetCredential("access_token"))
require.Equal(t, strings.Repeat("x", 4096), full.GetCredential("huge_blob"))
}
...@@ -47,6 +47,21 @@ func ProvideSessionLimitCache(rdb *redis.Client, cfg *config.Config) service.Ses ...@@ -47,6 +47,21 @@ func ProvideSessionLimitCache(rdb *redis.Client, cfg *config.Config) service.Ses
return NewSessionLimitCache(rdb, defaultIdleTimeoutMinutes) return NewSessionLimitCache(rdb, defaultIdleTimeoutMinutes)
} }
// ProvideSchedulerCache builds the scheduler snapshot cache and injects the
// snapshot chunk sizes taken from cfg.Gateway.Scheduling. A nil cfg, or a
// chunk size that is not positive, selects the corresponding package default.
func ProvideSchedulerCache(rdb *redis.Client, cfg *config.Config) service.SchedulerCache {
	if cfg == nil {
		return newSchedulerCacheWithChunkSizes(
			rdb,
			defaultSchedulerSnapshotMGetChunkSize,
			defaultSchedulerSnapshotWriteChunkSize,
		)
	}

	readChunk := cfg.Gateway.Scheduling.SnapshotMGetChunkSize
	if readChunk <= 0 {
		readChunk = defaultSchedulerSnapshotMGetChunkSize
	}

	writeChunk := cfg.Gateway.Scheduling.SnapshotWriteChunkSize
	if writeChunk <= 0 {
		writeChunk = defaultSchedulerSnapshotWriteChunkSize
	}

	return newSchedulerCacheWithChunkSizes(rdb, readChunk, writeChunk)
}
// ProviderSet is the Wire provider set for all repositories // ProviderSet is the Wire provider set for all repositories
var ProviderSet = wire.NewSet( var ProviderSet = wire.NewSet(
NewUserRepository, NewUserRepository,
...@@ -92,7 +107,7 @@ var ProviderSet = wire.NewSet( ...@@ -92,7 +107,7 @@ var ProviderSet = wire.NewSet(
NewRedeemCache, NewRedeemCache,
NewUpdateCache, NewUpdateCache,
NewGeminiTokenCache, NewGeminiTokenCache,
NewSchedulerCache, ProvideSchedulerCache,
NewSchedulerOutboxRepository, NewSchedulerOutboxRepository,
NewProxyLatencyCache, NewProxyLatencyCache,
NewTotpCache, NewTotpCache,
......
...@@ -1192,12 +1192,20 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context ...@@ -1192,12 +1192,20 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户) // anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
// 注意:强制平台模式不走混合调度 // 注意:强制平台模式不走混合调度
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform { if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) account, err := s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
if err != nil {
return nil, err
}
return s.hydrateSelectedAccount(ctx, account)
} }
// antigravity 分组、强制平台模式或无分组使用单平台选择 // antigravity 分组、强制平台模式或无分组使用单平台选择
// 注意:强制平台模式也必须遵守分组限制,不再回退到全平台查询 // 注意:强制平台模式也必须遵守分组限制,不再回退到全平台查询
return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
if err != nil {
return nil, err
}
return s.hydrateSelectedAccount(ctx, account)
} }
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan. // SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
...@@ -1273,11 +1281,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1273,11 +1281,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
localExcluded[account.ID] = struct{}{} // 排除此账号 localExcluded[account.ID] = struct{}{} // 排除此账号
continue // 重新选择 continue // 重新选择
} }
return &AccountSelectionResult{ return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
} }
// 对于等待计划的情况,也需要先检查会话限制 // 对于等待计划的情况,也需要先检查会话限制
...@@ -1289,26 +1293,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1289,26 +1293,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
if waitingCount < cfg.StickySessionMaxWaiting { if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{ return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
Account: account, AccountID: account.ID,
WaitPlan: &AccountWaitPlan{ MaxConcurrency: account.Concurrency,
AccountID: account.ID, Timeout: cfg.StickySessionWaitTimeout,
MaxConcurrency: account.Concurrency, MaxWaiting: cfg.StickySessionMaxWaiting,
Timeout: cfg.StickySessionWaitTimeout, })
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
} }
} }
return &AccountSelectionResult{ return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
Account: account, AccountID: account.ID,
WaitPlan: &AccountWaitPlan{ MaxConcurrency: account.Concurrency,
AccountID: account.ID, Timeout: cfg.FallbackWaitTimeout,
MaxConcurrency: account.Concurrency, MaxWaiting: cfg.FallbackMaxWaiting,
Timeout: cfg.FallbackWaitTimeout, })
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
} }
} }
...@@ -1455,11 +1453,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1455,11 +1453,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if s.debugModelRoutingEnabled() { if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID) logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
} }
return &AccountSelectionResult{ return s.newSelectionResult(ctx, stickyAccount, true, result.ReleaseFunc, nil)
Account: stickyAccount,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
} }
} }
...@@ -1570,11 +1564,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1570,11 +1564,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if s.debugModelRoutingEnabled() { if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
} }
return &AccountSelectionResult{ return s.newSelectionResult(ctx, item.account, true, result.ReleaseFunc, nil)
Account: item.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
} }
} }
...@@ -1587,15 +1577,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1587,15 +1577,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if s.debugModelRoutingEnabled() { if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
} }
return &AccountSelectionResult{ return s.newSelectionResult(ctx, item.account, false, nil, &AccountWaitPlan{
Account: item.account, AccountID: item.account.ID,
WaitPlan: &AccountWaitPlan{ MaxConcurrency: item.account.Concurrency,
AccountID: item.account.ID, Timeout: cfg.StickySessionWaitTimeout,
MaxConcurrency: item.account.Concurrency, MaxWaiting: cfg.StickySessionMaxWaiting,
Timeout: cfg.StickySessionWaitTimeout, })
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
} }
// 所有路由账号会话限制都已满,继续到 Layer 2 回退 // 所有路由账号会话限制都已满,继续到 Layer 2 回退
} }
...@@ -1631,11 +1618,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1631,11 +1618,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.checkAndRegisterSession(ctx, account, sessionHash) { if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续到 Layer 2 result.ReleaseFunc() // 释放槽位,继续到 Layer 2
} else { } else {
return &AccountSelectionResult{ if s.cache != nil {
Account: account, _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
Acquired: true, }
ReleaseFunc: result.ReleaseFunc, return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
}, nil
} }
} }
...@@ -1647,15 +1633,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1647,15 +1633,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
// 会话限制已满,继续到 Layer 2 // 会话限制已满,继续到 Layer 2
// Session limit full, continue to Layer 2 // Session limit full, continue to Layer 2
} else { } else {
return &AccountSelectionResult{ return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
Account: account, AccountID: accountID,
WaitPlan: &AccountWaitPlan{ MaxConcurrency: account.Concurrency,
AccountID: accountID, Timeout: cfg.StickySessionWaitTimeout,
MaxConcurrency: account.Concurrency, MaxWaiting: cfg.StickySessionMaxWaiting,
Timeout: cfg.StickySessionWaitTimeout, })
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
} }
} }
} }
...@@ -1714,7 +1697,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1714,7 +1697,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil { if err != nil {
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok { if result, ok, legacyErr := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); legacyErr != nil {
return nil, legacyErr
} else if ok {
return result, nil return result, nil
} }
} else { } else {
...@@ -1753,11 +1738,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1753,11 +1738,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL) _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
} }
return &AccountSelectionResult{ return s.newSelectionResult(ctx, selected.account, true, result.ReleaseFunc, nil)
Account: selected.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
} }
} }
...@@ -1780,20 +1761,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1780,20 +1761,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.checkAndRegisterSession(ctx, acc, sessionHash) { if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
continue // 会话限制已满,尝试下一个账号 continue // 会话限制已满,尝试下一个账号
} }
return &AccountSelectionResult{ return s.newSelectionResult(ctx, acc, false, nil, &AccountWaitPlan{
Account: acc, AccountID: acc.ID,
WaitPlan: &AccountWaitPlan{ MaxConcurrency: acc.Concurrency,
AccountID: acc.ID, Timeout: cfg.FallbackWaitTimeout,
MaxConcurrency: acc.Concurrency, MaxWaiting: cfg.FallbackMaxWaiting,
Timeout: cfg.FallbackWaitTimeout, })
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
} }
return nil, ErrNoAvailableAccounts return nil, ErrNoAvailableAccounts
} }
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) { func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool, error) {
ordered := append([]*Account(nil), candidates...) ordered := append([]*Account(nil), candidates...)
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth) sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
...@@ -1808,15 +1786,15 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates ...@@ -1808,15 +1786,15 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL) _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL)
} }
return &AccountSelectionResult{ selection, err := s.newSelectionResult(ctx, acc, true, result.ReleaseFunc, nil)
Account: acc, if err != nil {
Acquired: true, return nil, false, err
ReleaseFunc: result.ReleaseFunc, }
}, true return selection, true, nil
} }
} }
return nil, false return nil, false, nil
} }
func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig { func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
...@@ -2431,6 +2409,33 @@ func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID in ...@@ -2431,6 +2409,33 @@ func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID in
return s.accountRepo.GetByID(ctx, accountID) return s.accountRepo.GetByID(ctx, accountID)
} }
// hydrateSelectedAccount looks the selected account up again through the
// scheduler snapshot service and returns that record in place of the one
// used during selection.
//
// A nil account, or a service without a scheduler snapshot, is handed back
// unchanged. A lookup error is returned as-is, and a lookup that yields no
// account is reported as an error.
func (s *GatewayService) hydrateSelectedAccount(ctx context.Context, account *Account) (*Account, error) {
	if account == nil || s.schedulerSnapshot == nil {
		return account, nil
	}

	full, lookupErr := s.schedulerSnapshot.GetAccount(ctx, account.ID)
	switch {
	case lookupErr != nil:
		return nil, lookupErr
	case full == nil:
		return nil, fmt.Errorf("selected gateway account %d not found during hydration", account.ID)
	default:
		return full, nil
	}
}
// newSelectionResult builds the AccountSelectionResult for a chosen account,
// substituting the account returned by hydrateSelectedAccount.
//
// Parameters:
//   - account: the account picked during selection.
//   - acquired: stored as-is in the result's Acquired field.
//   - release: stored as the result's ReleaseFunc; may be nil.
//   - waitPlan: stored as the result's WaitPlan; may be nil.
//
// When hydration fails no result is returned, so the caller never receives
// ReleaseFunc. To avoid leaking whatever release guards (callers pass the
// ReleaseFunc of a successful slot acquisition), release is invoked here
// before the error is returned.
func (s *GatewayService) newSelectionResult(ctx context.Context, account *Account, acquired bool, release func(), waitPlan *AccountWaitPlan) (*AccountSelectionResult, error) {
	hydrated, err := s.hydrateSelectedAccount(ctx, account)
	if err != nil {
		// The caller only gets (nil, err) and cannot release on its own.
		if release != nil {
			release()
		}
		return nil, err
	}
	return &AccountSelectionResult{
		Account:     hydrated,
		Acquired:    acquired,
		ReleaseFunc: release,
		WaitPlan:    waitPlan,
	}, nil
}
// filterByMinPriority 过滤出优先级最小的账号集合 // filterByMinPriority 过滤出优先级最小的账号集合
func filterByMinPriority(accounts []accountWithLoad) []accountWithLoad { func filterByMinPriority(accounts []accountWithLoad) []accountWithLoad {
if len(accounts) == 0 { if len(accounts) == 0 {
......
...@@ -137,7 +137,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co ...@@ -137,7 +137,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, geminiStickySessionTTL) _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, geminiStickySessionTTL)
} }
return selected, nil return s.hydrateSelectedAccount(ctx, selected)
} }
// resolvePlatformAndSchedulingMode 解析目标平台和调度模式。 // resolvePlatformAndSchedulingMode 解析目标平台和调度模式。
...@@ -416,6 +416,20 @@ func (s *GeminiMessagesCompatService) getSchedulableAccount(ctx context.Context, ...@@ -416,6 +416,20 @@ func (s *GeminiMessagesCompatService) getSchedulableAccount(ctx context.Context,
return s.accountRepo.GetByID(ctx, accountID) return s.accountRepo.GetByID(ctx, accountID)
} }
// hydrateSelectedAccount looks the selected account up again through the
// scheduler snapshot service and returns that record in place of the one
// used during selection.
//
// A nil account, or a service without a scheduler snapshot, is handed back
// unchanged. A lookup error is returned as-is, and a lookup that yields no
// account is reported as an error.
func (s *GeminiMessagesCompatService) hydrateSelectedAccount(ctx context.Context, account *Account) (*Account, error) {
	if account == nil || s.schedulerSnapshot == nil {
		return account, nil
	}

	full, lookupErr := s.schedulerSnapshot.GetAccount(ctx, account.ID)
	switch {
	case lookupErr != nil:
		return nil, lookupErr
	case full == nil:
		return nil, fmt.Errorf("selected gemini account %d not found during hydration", account.ID)
	default:
		return full, nil
	}
}
func (s *GeminiMessagesCompatService) listSchedulableAccountsOnce(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, error) { func (s *GeminiMessagesCompatService) listSchedulableAccountsOnce(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, error) {
if s.schedulerSnapshot != nil { if s.schedulerSnapshot != nil {
accounts, _, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) accounts, _, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
...@@ -546,7 +560,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont ...@@ -546,7 +560,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
if selected == nil { if selected == nil {
return nil, errors.New("no available Gemini accounts") return nil, errors.New("no available Gemini accounts")
} }
return selected, nil return s.hydrateSelectedAccount(ctx, selected)
} }
func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) { func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
......
...@@ -1243,7 +1243,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C ...@@ -1243,7 +1243,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, selected.ID, openaiStickySessionTTL) _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, selected.ID, openaiStickySessionTTL)
} }
return selected, nil return s.hydrateSelectedAccount(ctx, selected)
} }
// tryStickySessionHit 尝试从粘性会话获取账号。 // tryStickySessionHit 尝试从粘性会话获取账号。
...@@ -1408,35 +1408,25 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ...@@ -1408,35 +1408,25 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
} }
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
return &AccountSelectionResult{ return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
} }
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
if waitingCount < cfg.StickySessionMaxWaiting { if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{ return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
Account: account, AccountID: account.ID,
WaitPlan: &AccountWaitPlan{ MaxConcurrency: account.Concurrency,
AccountID: account.ID, Timeout: cfg.StickySessionWaitTimeout,
MaxConcurrency: account.Concurrency, MaxWaiting: cfg.StickySessionMaxWaiting,
Timeout: cfg.StickySessionWaitTimeout, })
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
} }
} }
return &AccountSelectionResult{ return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
Account: account, AccountID: account.ID,
WaitPlan: &AccountWaitPlan{ MaxConcurrency: account.Concurrency,
AccountID: account.ID, Timeout: cfg.FallbackWaitTimeout,
MaxConcurrency: account.Concurrency, MaxWaiting: cfg.FallbackMaxWaiting,
Timeout: cfg.FallbackWaitTimeout, })
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
} }
accounts, err := s.listSchedulableAccounts(ctx, groupID) accounts, err := s.listSchedulableAccounts(ctx, groupID)
...@@ -1476,24 +1466,17 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ...@@ -1476,24 +1466,17 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
_ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL) _ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL)
return &AccountSelectionResult{ return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
} }
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting { if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{ return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
Account: account, AccountID: accountID,
WaitPlan: &AccountWaitPlan{ MaxConcurrency: account.Concurrency,
AccountID: accountID, Timeout: cfg.StickySessionWaitTimeout,
MaxConcurrency: account.Concurrency, MaxWaiting: cfg.StickySessionMaxWaiting,
Timeout: cfg.StickySessionWaitTimeout, })
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
} }
} }
} }
...@@ -1552,11 +1535,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ...@@ -1552,11 +1535,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if sessionHash != "" { if sessionHash != "" {
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL) _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL)
} }
return &AccountSelectionResult{ return s.newSelectionResult(ctx, fresh, true, result.ReleaseFunc, nil)
Account: fresh,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
} }
} }
} else { } else {
...@@ -1609,11 +1588,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ...@@ -1609,11 +1588,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if sessionHash != "" { if sessionHash != "" {
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL) _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL)
} }
return &AccountSelectionResult{ return s.newSelectionResult(ctx, fresh, true, result.ReleaseFunc, nil)
Account: fresh,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
} }
} }
} }
...@@ -1629,15 +1604,12 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ...@@ -1629,15 +1604,12 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) { if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
continue continue
} }
return &AccountSelectionResult{ return s.newSelectionResult(ctx, fresh, false, nil, &AccountWaitPlan{
Account: fresh, AccountID: fresh.ID,
WaitPlan: &AccountWaitPlan{ MaxConcurrency: fresh.Concurrency,
AccountID: fresh.ID, Timeout: cfg.FallbackWaitTimeout,
MaxConcurrency: fresh.Concurrency, MaxWaiting: cfg.FallbackMaxWaiting,
Timeout: cfg.FallbackWaitTimeout, })
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
} }
return nil, ErrNoAvailableAccounts return nil, ErrNoAvailableAccounts
...@@ -1732,6 +1704,33 @@ func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accoun ...@@ -1732,6 +1704,33 @@ func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accoun
return account, nil return account, nil
} }
// hydrateSelectedAccount looks the selected account up again through the
// scheduler snapshot service and returns that record in place of the one
// used during selection.
//
// A nil account, or a service without a scheduler snapshot, is handed back
// unchanged. A lookup error is returned as-is, and a lookup that yields no
// account is reported as an error.
func (s *OpenAIGatewayService) hydrateSelectedAccount(ctx context.Context, account *Account) (*Account, error) {
	if account == nil || s.schedulerSnapshot == nil {
		return account, nil
	}

	full, lookupErr := s.schedulerSnapshot.GetAccount(ctx, account.ID)
	switch {
	case lookupErr != nil:
		return nil, lookupErr
	case full == nil:
		return nil, fmt.Errorf("selected openai account %d not found during hydration", account.ID)
	default:
		return full, nil
	}
}
// newSelectionResult builds the AccountSelectionResult for a chosen account,
// substituting the account returned by hydrateSelectedAccount.
//
// Parameters:
//   - account: the account picked during selection.
//   - acquired: stored as-is in the result's Acquired field.
//   - release: stored as the result's ReleaseFunc; may be nil.
//   - waitPlan: stored as the result's WaitPlan; may be nil.
//
// When hydration fails no result is returned, so the caller never receives
// ReleaseFunc. To avoid leaking whatever release guards (callers pass the
// ReleaseFunc of a successful slot acquisition), release is invoked here
// before the error is returned.
func (s *OpenAIGatewayService) newSelectionResult(ctx context.Context, account *Account, acquired bool, release func(), waitPlan *AccountWaitPlan) (*AccountSelectionResult, error) {
	hydrated, err := s.hydrateSelectedAccount(ctx, account)
	if err != nil {
		// The caller only gets (nil, err) and cannot release on its own.
		if release != nil {
			release()
		}
		return nil, err
	}
	return &AccountSelectionResult{
		Account:     hydrated,
		Acquired:    acquired,
		ReleaseFunc: release,
		WaitPlan:    waitPlan,
	}, nil
}
func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig { func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig {
if s.cfg != nil { if s.cfg != nil {
return s.cfg.Gateway.Scheduling return s.cfg.Gateway.Scheduling
......
//go:build unit
package service
import (
"context"
"testing"
"time"
)
// snapshotHydrationCache is an in-memory cache stub for the hydration tests.
// Only GetSnapshot and GetAccount return fixture data; every other method is
// a no-op that reports success.
type snapshotHydrationCache struct {
	// snapshot is returned verbatim by GetSnapshot for any bucket.
	snapshot []*Account
	// accounts backs GetAccount, keyed by account ID.
	accounts map[int64]*Account
}

// GetSnapshot always reports a hit and returns the fixture snapshot,
// regardless of the requested bucket.
func (c *snapshotHydrationCache) GetSnapshot(_ context.Context, _ SchedulerBucket) ([]*Account, bool, error) {
	return c.snapshot, true, nil
}

// SetSnapshot discards its input and reports success.
func (c *snapshotHydrationCache) SetSnapshot(_ context.Context, _ SchedulerBucket, _ []Account) error {
	return nil
}

// GetAccount returns the fixture account stored under accountID, or nil when
// there is none. Indexing a nil map yields the zero value, so an unset
// accounts map also produces nil.
func (c *snapshotHydrationCache) GetAccount(_ context.Context, accountID int64) (*Account, error) {
	return c.accounts[accountID], nil
}

// SetAccount discards its input and reports success.
func (c *snapshotHydrationCache) SetAccount(_ context.Context, _ *Account) error {
	return nil
}

// DeleteAccount does nothing and reports success.
func (c *snapshotHydrationCache) DeleteAccount(_ context.Context, _ int64) error {
	return nil
}

// UpdateLastUsed discards its input and reports success.
func (c *snapshotHydrationCache) UpdateLastUsed(_ context.Context, _ map[int64]time.Time) error {
	return nil
}

// TryLockBucket always reports that the lock was obtained.
func (c *snapshotHydrationCache) TryLockBucket(_ context.Context, _ SchedulerBucket, _ time.Duration) (bool, error) {
	return true, nil
}

// ListBuckets reports no buckets.
func (c *snapshotHydrationCache) ListBuckets(_ context.Context) ([]SchedulerBucket, error) {
	return nil, nil
}

// GetOutboxWatermark always reports a watermark of zero.
func (c *snapshotHydrationCache) GetOutboxWatermark(_ context.Context) (int64, error) {
	return 0, nil
}

// SetOutboxWatermark discards its input and reports success.
func (c *snapshotHydrationCache) SetOutboxWatermark(_ context.Context, _ int64) error {
	return nil
}
func TestOpenAISelectAccountWithLoadAwareness_HydratesSelectedAccountFromSchedulerSnapshot(t *testing.T) {
cache := &snapshotHydrationCache{
snapshot: []*Account{
{
ID: 1,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 1,
Credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-4": "gpt-4",
},
},
},
},
accounts: map[int64]*Account{
1: {
ID: 1,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 1,
Credentials: map[string]any{
"api_key": "sk-live",
"model_mapping": map[string]any{"gpt-4": "gpt-4"},
},
},
},
}
schedulerSnapshot := NewSchedulerSnapshotService(cache, nil, nil, nil, nil)
groupID := int64(2)
svc := &OpenAIGatewayService{
schedulerSnapshot: schedulerSnapshot,
cache: &stubGatewayCache{},
}
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
if err != nil {
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
}
if selection == nil || selection.Account == nil {
t.Fatalf("expected selected account")
}
if got := selection.Account.GetOpenAIApiKey(); got != "sk-live" {
t.Fatalf("expected hydrated api key, got %q", got)
}
}
func TestGatewaySelectAccountWithLoadAwareness_HydratesSelectedAccountFromSchedulerSnapshot(t *testing.T) {
cache := &snapshotHydrationCache{
snapshot: []*Account{
{
ID: 9,
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 1,
},
},
accounts: map[int64]*Account{
9: {
ID: 9,
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 1,
Credentials: map[string]any{
"api_key": "anthropic-live-key",
},
},
},
}
schedulerSnapshot := NewSchedulerSnapshotService(cache, nil, nil, nil, nil)
svc := &GatewayService{
schedulerSnapshot: schedulerSnapshot,
cache: &mockGatewayCacheForPlatform{},
cfg: testConfig(),
}
result, err := svc.SelectAccountWithLoadAwareness(context.Background(), nil, "", "claude-3-5-sonnet-20241022", nil, "", 0)
if err != nil {
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
}
if result == nil || result.Account == nil {
t.Fatalf("expected selected account")
}
if got := result.Account.GetCredential("api_key"); got != "anthropic-live-key" {
t.Fatalf("expected hydrated api key, got %q", got)
}
}
...@@ -347,6 +347,12 @@ gateway: ...@@ -347,6 +347,12 @@ gateway:
# Enable batch load calculation for scheduling # Enable batch load calculation for scheduling
# 启用调度批量负载计算 # 启用调度批量负载计算
load_batch_enabled: true load_batch_enabled: true
# MGET chunk size used when reading scheduler snapshot buckets
# 调度快照分桶读取时的 MGET 分块大小
snapshot_mget_chunk_size: 128
# Chunk size for cache writes when rebuilding scheduler snapshots
# 调度快照重建写入时的分块大小
snapshot_write_chunk_size: 256
# Slot cleanup interval (duration) # Slot cleanup interval (duration)
# 并发槽位清理周期(时间段) # 并发槽位清理周期(时间段)
slot_cleanup_interval: 30s slot_cleanup_interval: 30s
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment