From 265687b56d876bd8b672c7b3adf70f9676a19719 Mon Sep 17 00:00:00 2001 From: ius Date: Wed, 8 Apr 2026 10:10:15 -0700 Subject: [PATCH 01/52] =?UTF-8?q?fix:=20=E4=BC=98=E5=8C=96=E8=B0=83?= =?UTF-8?q?=E5=BA=A6=E5=BF=AB=E7=85=A7=E7=BC=93=E5=AD=98=E4=BB=A5=E9=81=BF?= =?UTF-8?q?=E5=85=8D=20Redis=20=E5=A4=A7=20MGET?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/cmd/server/wire_gen.go | 2 +- backend/internal/config/config.go | 12 ++ ...eway_handler_warmup_intercept_unit_test.go | 7 +- .../repository/integration_harness_test.go | 4 + .../internal/repository/scheduler_cache.go | 203 ++++++++++++++++-- .../scheduler_cache_integration_test.go | 88 ++++++++ backend/internal/repository/wire.go | 17 +- backend/internal/service/gateway_service.go | 165 +++++++------- .../service/gemini_messages_compat_service.go | 18 +- .../service/openai_gateway_service.go | 113 +++++----- .../scheduler_snapshot_hydration_test.go | 159 ++++++++++++++ deploy/config.example.yaml | 6 + 12 files changed, 631 insertions(+), 163 deletions(-) create mode 100644 backend/internal/repository/scheduler_cache_integration_test.go create mode 100644 backend/internal/service/scheduler_snapshot_hydration_test.go diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index fdc5c6ac..aff9a0ff 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -100,7 +100,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { } dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig) dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService) - schedulerCache := repository.NewSchedulerCache(redisClient) + schedulerCache := repository.ProvideSchedulerCache(redisClient, configConfig) accountRepository := repository.NewAccountRepository(client, db, schedulerCache) proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 9b430377..ad023dc1 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -620,6 +620,10 @@ type GatewaySchedulingConfig struct { // 负载计算 LoadBatchEnabled bool `mapstructure:"load_batch_enabled"` + // 快照桶读取时的 MGET 分块大小 + SnapshotMGetChunkSize int `mapstructure:"snapshot_mget_chunk_size"` + // 快照重建时的缓存写入分块大小 + SnapshotWriteChunkSize int `mapstructure:"snapshot_write_chunk_size"` // 过期槽位清理周期(0 表示禁用) SlotCleanupInterval time.Duration `mapstructure:"slot_cleanup_interval"` @@ -1340,6 +1344,8 @@ func setDefaults() { viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100) viper.SetDefault("gateway.scheduling.fallback_selection_mode", "last_used") viper.SetDefault("gateway.scheduling.load_batch_enabled", true) + viper.SetDefault("gateway.scheduling.snapshot_mget_chunk_size", 128) + viper.SetDefault("gateway.scheduling.snapshot_write_chunk_size", 256) viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second) viper.SetDefault("gateway.scheduling.db_fallback_enabled", true) viper.SetDefault("gateway.scheduling.db_fallback_timeout_seconds", 0) @@ -2001,6 +2007,12 @@ func (c *Config) Validate() error { if c.Gateway.Scheduling.FallbackMaxWaiting <= 0 { return fmt.Errorf("gateway.scheduling.fallback_max_waiting must be positive") } + if c.Gateway.Scheduling.SnapshotMGetChunkSize <= 0 { + return fmt.Errorf("gateway.scheduling.snapshot_mget_chunk_size must be positive") + } + if c.Gateway.Scheduling.SnapshotWriteChunkSize <= 0 { + return fmt.Errorf("gateway.scheduling.snapshot_write_chunk_size must be positive") + } if c.Gateway.Scheduling.SlotCleanupInterval < 0 { return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative") } diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go index 4caef955..acea3780 100644 --- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -34,7 +34,12 @@ func (f *fakeSchedulerCache) GetSnapshot(_ context.Context, _ service.SchedulerB func (f *fakeSchedulerCache) SetSnapshot(_ context.Context, _ service.SchedulerBucket, _ []service.Account) error { return nil } -func (f *fakeSchedulerCache) GetAccount(_ context.Context, _ int64) (*service.Account, error) { +func (f *fakeSchedulerCache) GetAccount(_ context.Context, id int64) (*service.Account, error) { + for _, account := range f.accounts { + if account != nil && account.ID == id { + return account, nil + } + } return nil, nil } func (f *fakeSchedulerCache) SetAccount(_ context.Context, _ *service.Account) error { return nil } diff --git a/backend/internal/repository/integration_harness_test.go b/backend/internal/repository/integration_harness_test.go index fb9c26c4..5857fbcb 100644 --- a/backend/internal/repository/integration_harness_test.go +++ b/backend/internal/repository/integration_harness_test.go @@ -332,6 +332,10 @@ func (h prefixHook) prefixCmd(cmd redisclient.Cmder) { "hgetall", "hget", "hset", "hdel", "hincrbyfloat", "exists", "zadd", "zcard", "zrange", "zrangebyscore", "zrem", "zremrangebyscore", "zrevrange", "zrevrangebyscore", "zscore": prefixOne(1) + case "mget": + for i := 1; i < len(args); i++ { + prefixOne(i) + } case "del", "unlink": for i := 1; i < len(args); i++ { prefixOne(i) diff --git a/backend/internal/repository/scheduler_cache.go b/backend/internal/repository/scheduler_cache.go index 4f447e4f..35345a8b 100644 --- a/backend/internal/repository/scheduler_cache.go +++ b/backend/internal/repository/scheduler_cache.go @@ -15,19 +15,39 @@ const ( schedulerBucketSetKey = "sched:buckets" schedulerOutboxWatermarkKey = "sched:outbox:watermark" schedulerAccountPrefix = "sched:acc:" + schedulerAccountMetaPrefix = "sched:meta:" schedulerActivePrefix = "sched:active:" schedulerReadyPrefix = "sched:ready:" schedulerVersionPrefix = "sched:ver:" schedulerSnapshotPrefix = "sched:" schedulerLockPrefix = "sched:lock:" + + defaultSchedulerSnapshotMGetChunkSize = 128 + defaultSchedulerSnapshotWriteChunkSize = 256 ) type schedulerCache struct { - rdb *redis.Client + rdb *redis.Client + mgetChunkSize int + writeChunkSize int } func NewSchedulerCache(rdb *redis.Client) service.SchedulerCache { - return &schedulerCache{rdb: rdb} + return newSchedulerCacheWithChunkSizes(rdb, defaultSchedulerSnapshotMGetChunkSize, defaultSchedulerSnapshotWriteChunkSize) +} + +func newSchedulerCacheWithChunkSizes(rdb *redis.Client, mgetChunkSize, writeChunkSize int) service.SchedulerCache { + if mgetChunkSize <= 0 { + mgetChunkSize = defaultSchedulerSnapshotMGetChunkSize + } + if writeChunkSize <= 0 { + writeChunkSize = defaultSchedulerSnapshotWriteChunkSize + } + return &schedulerCache{ + rdb: rdb, + mgetChunkSize: mgetChunkSize, + writeChunkSize: writeChunkSize, + } } func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) { @@ -65,9 +85,9 @@ func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.Schedul keys := make([]string, 0, len(ids)) for _, id := range ids { - keys = append(keys, schedulerAccountKey(id)) + keys = append(keys, schedulerAccountMetaKey(id)) } - values, err := c.rdb.MGet(ctx, keys...).Result() + values, err := c.mgetChunked(ctx, keys) if err != nil { return nil, false, err } @@ -100,14 +120,11 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul versionStr := strconv.FormatInt(version, 10) snapshotKey := schedulerSnapshotKey(bucket, versionStr) - pipe := c.rdb.Pipeline() - for _, account := range accounts { - payload, err := json.Marshal(account) - if err != nil { - return err - } - pipe.Set(ctx, schedulerAccountKey(strconv.FormatInt(account.ID, 10)), payload, 0) + if err := c.writeAccounts(ctx, accounts); err != nil { + return err } + + pipe := c.rdb.Pipeline() if len(accounts) > 0 { // 使用序号作为 score,保持数据库返回的排序语义。 members := make([]redis.Z, 0, len(accounts)) @@ -117,7 +134,13 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul Member: strconv.FormatInt(account.ID, 10), }) } - pipe.ZAdd(ctx, snapshotKey, members...) + for start := 0; start < len(members); start += c.writeChunkSize { + end := start + c.writeChunkSize + if end > len(members) { + end = len(members) + } + pipe.ZAdd(ctx, snapshotKey, members[start:end]...) + } } else { pipe.Del(ctx, snapshotKey) } @@ -151,20 +174,15 @@ func (c *schedulerCache) SetAccount(ctx context.Context, account *service.Accoun if account == nil || account.ID <= 0 { return nil } - payload, err := json.Marshal(account) - if err != nil { - return err - } - key := schedulerAccountKey(strconv.FormatInt(account.ID, 10)) - return c.rdb.Set(ctx, key, payload, 0).Err() + return c.writeAccounts(ctx, []service.Account{*account}) } func (c *schedulerCache) DeleteAccount(ctx context.Context, accountID int64) error { if accountID <= 0 { return nil } - key := schedulerAccountKey(strconv.FormatInt(accountID, 10)) - return c.rdb.Del(ctx, key).Err() + id := strconv.FormatInt(accountID, 10) + return c.rdb.Del(ctx, schedulerAccountKey(id), schedulerAccountMetaKey(id)).Err() } func (c *schedulerCache) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error { @@ -179,7 +197,7 @@ func (c *schedulerCache) UpdateLastUsed(ctx context.Context, updates map[int64]t ids = append(ids, id) } - values, err := c.rdb.MGet(ctx, keys...).Result() + values, err := c.mgetChunked(ctx, keys) if err != nil { return err } @@ -198,7 +216,12 @@ func (c *schedulerCache) UpdateLastUsed(ctx context.Context, updates map[int64]t if err != nil { return err } + metaPayload, err := json.Marshal(buildSchedulerMetadataAccount(*account)) + if err != nil { + return err + } pipe.Set(ctx, keys[i], updated, 0) + pipe.Set(ctx, schedulerAccountMetaKey(strconv.FormatInt(ids[i], 10)), metaPayload, 0) } _, err = pipe.Exec(ctx) return err @@ -256,6 +279,10 @@ func schedulerAccountKey(id string) string { return schedulerAccountPrefix + id } +func schedulerAccountMetaKey(id string) string { + return schedulerAccountMetaPrefix + id +} + func ptrTime(t time.Time) *time.Time { return &t } @@ -276,3 +303,137 @@ func decodeCachedAccount(val any) (*service.Account, error) { } return &account, nil } + +func (c *schedulerCache) writeAccounts(ctx context.Context, accounts []service.Account) error { + if len(accounts) == 0 { + return nil + } + + pipe := c.rdb.Pipeline() + pending := 0 + flush := func() error { + if pending == 0 { + return nil + } + if _, err := pipe.Exec(ctx); err != nil { + return err + } + pipe = c.rdb.Pipeline() + pending = 0 + return nil + } + + for _, account := range accounts { + fullPayload, err := json.Marshal(account) + if err != nil { + return err + } + metaPayload, err := json.Marshal(buildSchedulerMetadataAccount(account)) + if err != nil { + return err + } + + id := strconv.FormatInt(account.ID, 10) + pipe.Set(ctx, schedulerAccountKey(id), fullPayload, 0) + pipe.Set(ctx, schedulerAccountMetaKey(id), metaPayload, 0) + pending++ + if pending >= c.writeChunkSize { + if err := flush(); err != nil { + return err + } + } + } + + return flush() +} + +func (c *schedulerCache) mgetChunked(ctx context.Context, keys []string) ([]any, error) { + if len(keys) == 0 { + return []any{}, nil + } + + out := make([]any, 0, len(keys)) + chunkSize := c.mgetChunkSize + if chunkSize <= 0 { + chunkSize = defaultSchedulerSnapshotMGetChunkSize + } + for start := 0; start < len(keys); start += chunkSize { + end := start + chunkSize + if end > len(keys) { + end = len(keys) + } + part, err := c.rdb.MGet(ctx, keys[start:end]...).Result() + if err != nil { + return nil, err + } + out = append(out, part...) + } + return out, nil +} + +func buildSchedulerMetadataAccount(account service.Account) service.Account { + return service.Account{ + ID: account.ID, + Name: account.Name, + Platform: account.Platform, + Type: account.Type, + Concurrency: account.Concurrency, + Priority: account.Priority, + RateMultiplier: account.RateMultiplier, + Status: account.Status, + LastUsedAt: account.LastUsedAt, + ExpiresAt: account.ExpiresAt, + AutoPauseOnExpired: account.AutoPauseOnExpired, + Schedulable: account.Schedulable, + RateLimitedAt: account.RateLimitedAt, + RateLimitResetAt: account.RateLimitResetAt, + OverloadUntil: account.OverloadUntil, + TempUnschedulableUntil: account.TempUnschedulableUntil, + TempUnschedulableReason: account.TempUnschedulableReason, + SessionWindowStart: account.SessionWindowStart, + SessionWindowEnd: account.SessionWindowEnd, + SessionWindowStatus: account.SessionWindowStatus, + Credentials: filterSchedulerCredentials(account.Credentials), + Extra: filterSchedulerExtra(account.Extra), + } +} + +func filterSchedulerCredentials(credentials map[string]any) map[string]any { + if len(credentials) == 0 { + return nil + } + keys := []string{"model_mapping", "api_key", "project_id", "oauth_type"} + filtered := make(map[string]any) + for _, key := range keys { + if value, ok := credentials[key]; ok && value != nil { + filtered[key] = value + } + } + if len(filtered) == 0 { + return nil + } + return filtered +} + +func filterSchedulerExtra(extra map[string]any) map[string]any { + if len(extra) == 0 { + return nil + } + keys := []string{ + "mixed_scheduling", + "window_cost_limit", + "window_cost_sticky_reserve", + "max_sessions", + "session_idle_timeout_minutes", + } + filtered := make(map[string]any) + for _, key := range keys { + if value, ok := extra[key]; ok && value != nil { + filtered[key] = value + } + } + if len(filtered) == 0 { + return nil + } + return filtered +} diff --git a/backend/internal/repository/scheduler_cache_integration_test.go b/backend/internal/repository/scheduler_cache_integration_test.go new file mode 100644 index 00000000..134a6a07 --- /dev/null +++ b/backend/internal/repository/scheduler_cache_integration_test.go @@ -0,0 +1,88 @@ +//go:build integration + +package repository + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestSchedulerCacheSnapshotUsesSlimMetadataButKeepsFullAccount(t *testing.T) { + ctx := context.Background() + rdb := testRedis(t) + cache := NewSchedulerCache(rdb) + + bucket := service.SchedulerBucket{GroupID: 2, Platform: service.PlatformGemini, Mode: service.SchedulerModeSingle} + now := time.Now().UTC().Truncate(time.Second) + limitReset := now.Add(10 * time.Minute) + overloadUntil := now.Add(2 * time.Minute) + tempUnschedUntil := now.Add(3 * time.Minute) + windowEnd := now.Add(5 * time.Hour) + + account := service.Account{ + ID: 101, + Name: "gemini-heavy", + Platform: service.PlatformGemini, + Type: service.AccountTypeOAuth, + Status: service.StatusActive, + Schedulable: true, + Concurrency: 3, + Priority: 7, + LastUsedAt: &now, + Credentials: map[string]any{ + "api_key": "gemini-api-key", + "access_token": "secret-access-token", + "project_id": "proj-1", + "oauth_type": "ai_studio", + "model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"}, + "huge_blob": strings.Repeat("x", 4096), + }, + Extra: map[string]any{ + "mixed_scheduling": true, + "window_cost_limit": 12.5, + "window_cost_sticky_reserve": 8.0, + "max_sessions": 4, + "session_idle_timeout_minutes": 11, + "unused_large_field": strings.Repeat("y", 4096), + }, + RateLimitResetAt: &limitReset, + OverloadUntil: &overloadUntil, + TempUnschedulableUntil: &tempUnschedUntil, + SessionWindowStart: &now, + SessionWindowEnd: &windowEnd, + SessionWindowStatus: "active", + } + + require.NoError(t, cache.SetSnapshot(ctx, bucket, []service.Account{account})) + + snapshot, hit, err := cache.GetSnapshot(ctx, bucket) + require.NoError(t, err) + require.True(t, hit) + require.Len(t, snapshot, 1) + + got := snapshot[0] + require.NotNil(t, got) + require.Equal(t, "gemini-api-key", got.GetCredential("api_key")) + require.Equal(t, "proj-1", got.GetCredential("project_id")) + require.Equal(t, "ai_studio", got.GetCredential("oauth_type")) + require.NotEmpty(t, got.GetModelMapping()) + require.Empty(t, got.GetCredential("access_token")) + require.Empty(t, got.GetCredential("huge_blob")) + require.Equal(t, true, got.Extra["mixed_scheduling"]) + require.Equal(t, 12.5, got.GetWindowCostLimit()) + require.Equal(t, 8.0, got.GetWindowCostStickyReserve()) + require.Equal(t, 4, got.GetMaxSessions()) + require.Equal(t, 11, got.GetSessionIdleTimeoutMinutes()) + require.Nil(t, got.Extra["unused_large_field"]) + + full, err := cache.GetAccount(ctx, account.ID) + require.NoError(t, err) + require.NotNil(t, full) + require.Equal(t, "secret-access-token", full.GetCredential("access_token")) + require.Equal(t, strings.Repeat("x", 4096), full.GetCredential("huge_blob")) +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 657e3ed6..d3adb4a0 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -47,6 +47,21 @@ func ProvideSessionLimitCache(rdb *redis.Client, cfg *config.Config) service.Ses return NewSessionLimitCache(rdb, defaultIdleTimeoutMinutes) } +// ProvideSchedulerCache 创建调度快照缓存,并注入快照分块参数。 +func ProvideSchedulerCache(rdb *redis.Client, cfg *config.Config) service.SchedulerCache { + mgetChunkSize := defaultSchedulerSnapshotMGetChunkSize + writeChunkSize := defaultSchedulerSnapshotWriteChunkSize + if cfg != nil { + if cfg.Gateway.Scheduling.SnapshotMGetChunkSize > 0 { + mgetChunkSize = cfg.Gateway.Scheduling.SnapshotMGetChunkSize + } + if cfg.Gateway.Scheduling.SnapshotWriteChunkSize > 0 { + writeChunkSize = cfg.Gateway.Scheduling.SnapshotWriteChunkSize + } + } + return newSchedulerCacheWithChunkSizes(rdb, mgetChunkSize, writeChunkSize) +} + // ProviderSet is the Wire provider set for all repositories var ProviderSet = wire.NewSet( NewUserRepository, @@ -92,7 +107,7 @@ var ProviderSet = wire.NewSet( NewRedeemCache, NewUpdateCache, NewGeminiTokenCache, - NewSchedulerCache, + ProvideSchedulerCache, NewSchedulerOutboxRepository, NewProxyLatencyCache, NewTotpCache, diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index a4733649..8b0bdc2a 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -1192,12 +1192,20 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context // anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户) // 注意:强制平台模式不走混合调度 if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform { - return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) + account, err := s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) + if err != nil { + return nil, err + } + return s.hydrateSelectedAccount(ctx, account) } // antigravity 分组、强制平台模式或无分组使用单平台选择 // 注意:强制平台模式也必须遵守分组限制,不再回退到全平台查询 - return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) + account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) + if err != nil { + return nil, err + } + return s.hydrateSelectedAccount(ctx, account) } // SelectAccountWithLoadAwareness selects account with load-awareness and wait plan. @@ -1273,11 +1281,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro localExcluded[account.ID] = struct{}{} // 排除此账号 continue // 重新选择 } - return &AccountSelectionResult{ - Account: account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil) } // 对于等待计划的情况,也需要先检查会话限制 @@ -1289,26 +1293,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) if waitingCount < cfg.StickySessionMaxWaiting { - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }) } } - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }, - }, nil + return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }) } } @@ -1455,11 +1453,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID) } - return &AccountSelectionResult{ - Account: stickyAccount, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + return s.newSelectionResult(ctx, stickyAccount, true, result.ReleaseFunc, nil) } } @@ -1570,11 +1564,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) } - return &AccountSelectionResult{ - Account: item.account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + return s.newSelectionResult(ctx, item.account, true, result.ReleaseFunc, nil) } } @@ -1587,15 +1577,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) } - return &AccountSelectionResult{ - Account: item.account, - WaitPlan: &AccountWaitPlan{ - AccountID: item.account.ID, - MaxConcurrency: item.account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + return s.newSelectionResult(ctx, item.account, false, nil, &AccountWaitPlan{ + AccountID: item.account.ID, + MaxConcurrency: item.account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }) } // 所有路由账号会话限制都已满,继续到 Layer 2 回退 } @@ -1631,11 +1618,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if !s.checkAndRegisterSession(ctx, account, sessionHash) { result.ReleaseFunc() // 释放槽位,继续到 Layer 2 } else { - return &AccountSelectionResult{ - Account: account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + if s.cache != nil { + _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) + } + return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil) } } @@ -1647,15 +1633,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro // 会话限制已满,继续到 Layer 2 // Session limit full, continue to Layer 2 } else { - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: accountID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }) } } } @@ -1714,7 +1697,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) if err != nil { - if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok { + if result, ok, legacyErr := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); legacyErr != nil { + return nil, legacyErr + } else if ok { return result, nil } } else { @@ -1753,11 +1738,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if sessionHash != "" && s.cache != nil { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL) } - return &AccountSelectionResult{ - Account: selected.account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + return s.newSelectionResult(ctx, selected.account, true, result.ReleaseFunc, nil) } } @@ -1780,20 +1761,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if !s.checkAndRegisterSession(ctx, acc, sessionHash) { continue // 会话限制已满,尝试下一个账号 } - return &AccountSelectionResult{ - Account: acc, - WaitPlan: &AccountWaitPlan{ - AccountID: acc.ID, - MaxConcurrency: acc.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }, - }, nil + return s.newSelectionResult(ctx, acc, false, nil, &AccountWaitPlan{ + AccountID: acc.ID, + MaxConcurrency: acc.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }) } return nil, ErrNoAvailableAccounts } -func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) { +func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool, error) { ordered := append([]*Account(nil), candidates...) sortAccountsByPriorityAndLastUsed(ordered, preferOAuth) @@ -1808,15 +1786,15 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates if sessionHash != "" && s.cache != nil { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL) } - return &AccountSelectionResult{ - Account: acc, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, true + selection, err := s.newSelectionResult(ctx, acc, true, result.ReleaseFunc, nil) + if err != nil { + return nil, false, err + } + return selection, true, nil } } - return nil, false + return nil, false, nil } func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig { @@ -2431,6 +2409,33 @@ func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID in return s.accountRepo.GetByID(ctx, accountID) } +func (s *GatewayService) hydrateSelectedAccount(ctx context.Context, account *Account) (*Account, error) { + if account == nil || s.schedulerSnapshot == nil { + return account, nil + } + hydrated, err := s.schedulerSnapshot.GetAccount(ctx, account.ID) + if err != nil { + return nil, err + } + if hydrated == nil { + return nil, fmt.Errorf("selected gateway account %d not found during hydration", account.ID) + } + return hydrated, nil +} + +func (s *GatewayService) newSelectionResult(ctx context.Context, account *Account, acquired bool, release func(), waitPlan *AccountWaitPlan) (*AccountSelectionResult, error) { + hydrated, err := s.hydrateSelectedAccount(ctx, account) + if err != nil { + return nil, err + } + return &AccountSelectionResult{ + Account: hydrated, + Acquired: acquired, + ReleaseFunc: release, + WaitPlan: waitPlan, + }, nil +} + // filterByMinPriority 过滤出优先级最小的账号集合 func filterByMinPriority(accounts []accountWithLoad) []accountWithLoad { if len(accounts) == 0 { diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 32bf21c0..5a9490f3 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -137,7 +137,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, geminiStickySessionTTL) } - return selected, nil + return s.hydrateSelectedAccount(ctx, selected) } // resolvePlatformAndSchedulingMode 解析目标平台和调度模式。 @@ -416,6 +416,20 @@ func (s *GeminiMessagesCompatService) getSchedulableAccount(ctx context.Context, return s.accountRepo.GetByID(ctx, accountID) } +func (s *GeminiMessagesCompatService) hydrateSelectedAccount(ctx context.Context, account *Account) (*Account, error) { + if account == nil || s.schedulerSnapshot == nil { + return account, nil + } + hydrated, err := s.schedulerSnapshot.GetAccount(ctx, account.ID) + if err != nil { + return nil, err + } + if hydrated == nil { + return nil, fmt.Errorf("selected gemini account %d not found during hydration", account.ID) + } + return hydrated, nil +} + func (s *GeminiMessagesCompatService) listSchedulableAccountsOnce(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, error) { if s.schedulerSnapshot != nil { accounts, _, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) @@ -546,7 +560,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont if selected == nil { return nil, errors.New("no available Gemini accounts") } - return selected, nil + return s.hydrateSelectedAccount(ctx, selected) } func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) { diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 2623d773..dbc53869 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1243,7 +1243,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, selected.ID, openaiStickySessionTTL) } - return selected, nil + return s.hydrateSelectedAccount(ctx, selected) } // tryStickySessionHit 尝试从粘性会话获取账号。 @@ -1408,35 +1408,25 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex } result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) if err == nil && result.Acquired { - return &AccountSelectionResult{ - Account: account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil) } if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) if waitingCount < cfg.StickySessionMaxWaiting { - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }) } } - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }, - }, nil + return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }) } accounts, err := s.listSchedulableAccounts(ctx, groupID) @@ -1476,24 +1466,17 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if err == nil && result.Acquired { _ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL) - return &AccountSelectionResult{ - Account: account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil) } waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) if waitingCount < cfg.StickySessionMaxWaiting { - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: accountID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }) } } } @@ -1552,11 +1535,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex if sessionHash != "" { _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL) } - return &AccountSelectionResult{ - Account: fresh, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + return s.newSelectionResult(ctx, fresh, true, result.ReleaseFunc, nil) } } } else { @@ -1609,11 +1588,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex if sessionHash != "" { _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL) } - return &AccountSelectionResult{ - Account: fresh, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + return s.newSelectionResult(ctx, fresh, true, result.ReleaseFunc, nil) } } } @@ -1629,15 +1604,12 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) { continue } - return &AccountSelectionResult{ - Account: fresh, - WaitPlan: &AccountWaitPlan{ - AccountID: fresh.ID, - MaxConcurrency: fresh.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }, - }, nil + return s.newSelectionResult(ctx, fresh, false, nil, &AccountWaitPlan{ + AccountID: fresh.ID, + MaxConcurrency: fresh.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }) } return nil, ErrNoAvailableAccounts @@ -1732,6 +1704,33 @@ func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accoun return account, nil } +func (s *OpenAIGatewayService) hydrateSelectedAccount(ctx context.Context, account *Account) (*Account, error) { + if account == nil || s.schedulerSnapshot == nil { + return account, nil + } + hydrated, err := s.schedulerSnapshot.GetAccount(ctx, account.ID) + if err != nil { + return nil, err + } + if hydrated == nil { + return nil, fmt.Errorf("selected openai account %d not found during hydration", account.ID) + } + return hydrated, nil +} + +func (s *OpenAIGatewayService) newSelectionResult(ctx context.Context, account *Account, acquired bool, release func(), waitPlan *AccountWaitPlan) (*AccountSelectionResult, error) { + hydrated, err := s.hydrateSelectedAccount(ctx, account) + if err != nil { + return nil, err + } + return &AccountSelectionResult{ + Account: hydrated, + Acquired: acquired, + ReleaseFunc: release, + WaitPlan: waitPlan, + }, nil +} + func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig { if s.cfg != nil { return s.cfg.Gateway.Scheduling diff --git a/backend/internal/service/scheduler_snapshot_hydration_test.go b/backend/internal/service/scheduler_snapshot_hydration_test.go new file mode 100644 index 00000000..5c0b289b --- /dev/null +++ b/backend/internal/service/scheduler_snapshot_hydration_test.go @@ -0,0 +1,159 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + "time" +) + +type snapshotHydrationCache struct { + snapshot []*Account + accounts map[int64]*Account +} + +func (c *snapshotHydrationCache) GetSnapshot(ctx context.Context, bucket SchedulerBucket) ([]*Account, bool, error) { + return c.snapshot, true, nil +} + +func (c *snapshotHydrationCache) SetSnapshot(ctx context.Context, bucket SchedulerBucket, accounts []Account) error { + return nil +} + +func (c *snapshotHydrationCache) GetAccount(ctx context.Context, accountID int64) (*Account, error) { + if c.accounts == nil { + return nil, nil + } + return c.accounts[accountID], nil +} + +func (c *snapshotHydrationCache) SetAccount(ctx context.Context, account *Account) error { + return nil +} + +func (c *snapshotHydrationCache) DeleteAccount(ctx context.Context, accountID int64) error { + return nil +} + +func (c *snapshotHydrationCache) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error { + return nil +} + +func (c *snapshotHydrationCache) TryLockBucket(ctx context.Context, bucket SchedulerBucket, ttl time.Duration) (bool, error) { + return true, nil +} + +func (c *snapshotHydrationCache) ListBuckets(ctx context.Context) ([]SchedulerBucket, error) { + return nil, nil +} + +func (c *snapshotHydrationCache) GetOutboxWatermark(ctx context.Context) (int64, error) { + return 0, nil +} + +func (c *snapshotHydrationCache) SetOutboxWatermark(ctx context.Context, id int64) error { + return nil +} + +func TestOpenAISelectAccountWithLoadAwareness_HydratesSelectedAccountFromSchedulerSnapshot(t *testing.T) { + cache := &snapshotHydrationCache{ + snapshot: []*Account{ + { + ID: 1, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 1, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-4": "gpt-4", + }, + }, + }, + }, + accounts: map[int64]*Account{ + 1: { + ID: 1, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 1, + Credentials: map[string]any{ + "api_key": "sk-live", + "model_mapping": map[string]any{"gpt-4": "gpt-4"}, + }, + }, + }, + } + + schedulerSnapshot := NewSchedulerSnapshotService(cache, nil, nil, nil, nil) + groupID := int64(2) + svc := &OpenAIGatewayService{ + schedulerSnapshot: schedulerSnapshot, + cache: &stubGatewayCache{}, + } + + selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil) + if err != nil { + t.Fatalf("SelectAccountWithLoadAwareness error: %v", err) + } + if selection == nil || selection.Account == nil { + t.Fatalf("expected selected account") + } + if got := selection.Account.GetOpenAIApiKey(); got != "sk-live" { + t.Fatalf("expected hydrated api key, got %q", got) + } +} + +func TestGatewaySelectAccountWithLoadAwareness_HydratesSelectedAccountFromSchedulerSnapshot(t *testing.T) { + cache := &snapshotHydrationCache{ + snapshot: []*Account{ + { + ID: 9, + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 1, + }, + }, + accounts: map[int64]*Account{ + 9: { + ID: 9, + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 1, + Credentials: map[string]any{ + "api_key": "anthropic-live-key", + }, + }, + }, + } + + schedulerSnapshot := NewSchedulerSnapshotService(cache, nil, nil, nil, nil) + svc := &GatewayService{ + schedulerSnapshot: schedulerSnapshot, + cache: &mockGatewayCacheForPlatform{}, + cfg: testConfig(), + } + + result, err := svc.SelectAccountWithLoadAwareness(context.Background(), nil, "", "claude-3-5-sonnet-20241022", nil, "", 0) + if err != nil { + t.Fatalf("SelectAccountWithLoadAwareness error: %v", err) + } + if result == nil || result.Account == nil { + t.Fatalf("expected selected account") + } + if got := result.Account.GetCredential("api_key"); got != "anthropic-live-key" { + t.Fatalf("expected hydrated api key, got %q", got) + } +} diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 8f60acd5..45440761 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -347,6 +347,12 @@ gateway: # Enable batch load calculation for scheduling # 启用调度批量负载计算 load_batch_enabled: true + # Snapshot bucket MGET chunk size + # 调度快照分桶读取时的 MGET 分块大小 + snapshot_mget_chunk_size: 128 + # Snapshot bucket write chunk size + # 调度快照重建写入时的分块大小 + snapshot_write_chunk_size: 256 # Slot cleanup interval (duration) # 并发槽位清理周期(时间段) slot_cleanup_interval: 30s -- GitLab From 155d3474d6b82bc0fb0d83abb574740efb6fcbd2 Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 9 Apr 2026 09:21:37 +0800 Subject: [PATCH 02/52] chore: update Sponsors --- README.md | 6 ++++++ README_CN.md | 6 ++++++ README_JA.md | 6 ++++++ assets/partners/logos/ylscode.png | Bin 0 -> 26820 bytes 4 files changed, 18 insertions(+) create mode 100644 assets/partners/logos/ylscode.png diff --git a/README.md b/README.md index 2f73e92a..25bef473 100644 --- a/README.md +++ b/README.md @@ -74,6 +74,12 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot silkapi Thanks to SilkAPI for sponsoring this project! SilkAPI is a relay service built on Sub2API, specializing in providing high-speed and stable Codex API relay. + + +ylscode +Thanks to YLS Code for sponsoring this project! YLS Code is dedicated to building secure enterprise-grade Coding Agent productivity services, offering stable and fast Codex / Claude / Gemini subscription services along with pay-as-you-go API options for flexible choices. Register now for a limited-time 3-day Codex trial bonus! + + ## Ecosystem diff --git a/README_CN.md b/README_CN.md index a0c3fd4b..003a9530 100644 --- a/README_CN.md +++ b/README_CN.md @@ -73,6 +73,12 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的 silkapi 感谢 丝绸API 赞助了本项目! 丝绸API 是基于 Sub2API 搭建的中转服务,专注于提供 Codex 高速稳定API中转。 + + +silkapi +感谢 伊莉思Code 赞助了本项目! 伊莉思Code 致力于构建安全的企业级Coding Agent生产力服务,提供稳定快速的 Codex / Claude / Gemini 订阅服务与即用即付API多种方案灵活选择,限时注册赠送 3 天 Codex 试用福利! + + ## 生态项目 diff --git a/README_JA.md b/README_JA.md index bd69e06b..818e944b 100644 --- a/README_JA.md +++ b/README_JA.md @@ -73,6 +73,12 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを silkapi SilkAPI のご支援に感謝します!SilkAPI は Sub2API をベースに構築された中継サービスで、高速かつ安定した Codex API 中継の提供に特化しています。 + + +ylscode +YLS Code のご支援に感謝します!YLS Code は安全なエンタープライズグレードの Coding Agent 生産性サービスの構築に取り組んでおり、安定かつ高速な Codex / Claude / Gemini サブスクリプションサービスと従量課金 API の柔軟なプランを提供しています。期間限定で新規登録者に 3 日間の Codex 試用特典をプレゼント中! + + ## エコシステム diff --git a/assets/partners/logos/ylscode.png b/assets/partners/logos/ylscode.png new file mode 100644 index 0000000000000000000000000000000000000000..4d374f04c2ffea503d71e0022b2937000d206126 GIT binary patch literal 26820 zcmZ^KbzD?k*Y?mbGz=l#CEYo6N{1lQjda&YcQ+D}f`XFL-60An5=spWrPK^441)L` z@8|x#-}C+P{xNXo%$c+IS$oa8u63=oQw;SriNSPW5C}x9t)*rR0%4c}$G-Tuz+dIn zWLDrGMu4%V3aDw0;T!M<+f`Xl83bz0BD}T70p1h%YCQ@7fk?*wzA&bIUO53@NOCu| z2(-}Cm38#-60~>nac~w4_3{O-27%-iLw)TXJ)8qs9GqRIR1ND zhz%<5@8lwDtfukr6~H?MHn+e)Us)lckdP3;5HUd?e^((985tQNVNoGbQ32ozfq*dY zK>JVu?*R6HZ=mKJ;OOt}8|d!i&GPq#_6|NlfeLIvLGDhnF7_@G4o*%I0uJ_KP68q# zPWA%!PGTYgPGVwCQW7H4qAns1Z2t})=q`xUEIu#`S{dAEcpS76FhA(x0*g{RCC}lW@98SR zy)idf4!Fp;`n^Z*qcO-W#wMD}5PCe%`+_C9Js&&7OM#2Hpy#k6D<=Md7ttK|<{BLo z&nfNt_6|iChs3Ca-3uYx8J z8W82R3l^6S$XM?ae%tDG^#t|)5lVG+vvT%gS<`ALlwnEDF=QaH2f0v9cU31%gSD!AU>{7v`{ z^jOg22$SC2VDkG#7FILke%`59tbtkdo2KgQ&z1LIOm%!9(|Kit%gXJkpl#CnQL#cE zz6z;eE_X77mQ(8U?;!nT9iW%AAMSP~aCZq=$q4aE;TJ6hGoMv&o%j?D+ImIn$GJX9 zi7?0Qf~&HbU*j=%X*!(u>5iA-BaO--GVcM?x);%7gkf>6*ImpHtgLB}FF zKl8+8Q+AfWs9!y-FktfaOr6hCTIh0|_*uhCF`0>FDVA=PS7TEJtI7L_%kWOU$8&zlo)^|Rx_Ah7P+)<= zrYj|%sfuUV=MKjB=Z4O>>B6AL1`DRJ!;1Nwv`5Y`xBv$RsgA?>M>oKuD+vWhUEof9 zF~f@6(bC}CzW%?QOONB2>D!=cE<%jA8VqEi=B zGr1pdOZaGwwrrl%piKfsa*W50n`z)cGeUzj_4l@3KwCv_Ju|hVDa!IFebZFN8T8I6 z8x`}Gebo{g{f(givMLHXRJKAJPil29Bp!;!$ERqBg{i1MG_N8XvSFUd$Nf#2b)NL` ztu4+Q-9x2muHERXDEv>j8>I%_E2I^K#d^(7-&J*?QKEa>d6zQ>Ng?r#+W}|Ozj#LsFW zfa0wP6n_n1Vc)j#6b-dfP8oLr0T1+mJRZCvQ+y7XC`-x74F;vccF*{k zhs%6+FF8*aF}%bcVwy)}oZr}+^YF?4HsyK6?YP0fJk%u-(i_?T zOqVhK*8>DnTaXNQr@%|L7`NxvxL%KgIUW`%CfozcGcV!_tVCR zQsj>lIa568=Uz}Qs(sEcZ5FpE(Ihe!G{|WYN z`l5SZQv_&6{*CoXK&(eKc@>aJ%Gj=C8=GOUk3Hlv@qPNLGlu^}L`{aQjo@%b9|4E$ z4E*&DM&#W;K!b%ey3=2mi-5|QC!t|XO}3&L9ShpJf z4P*KN4IrDnUvs)jxoQGDk^ zPekpD!~-ABSm146#2yrNNB#VOH~3y7aQNon;%3T&Xqa0I7Gog||318i;Q%F#<1k?p zx`?1l%F@}5ZvQ_K;bNnInxNxnU;@j_*V9K0ze9zj3o{Nq-^bUwC5%Op{~3=2Ab@;NX}ZaExRn>GC_hsP`_DASz9nC;I*Q-qtc& z*SGIso8t$eyD8{D!q#o(JKbkx1e3i6@8|N)YhsX*9bWFIDVbsrh=0&MDrQD} zLBC&gE?2aN0T`%!mK*T{o%4YOn z1e&Ds2RnHLPuh<}1&WRPs10m_B*7OCM<(#I@j>vz83vyAXe`dg`q~e?+j)?mEkliFR z`Lcc_s%Sw)d3F+c^;L`*T#p5;^Xn@sZL7iPOt^*Q?u@wa=T+Do4ntd~i1c4L_dv@} z=hk;iD)~UKgfb04bQ9c7_w&G>Y(4NVzFsLy1#e&I7ZZs>eS@(d&*BLy#H>Jn%pVk8 z0OT&_J|ffZ$mntPee6eisv}|jA@E<^($pX+I+Z^vmHvO^L&K6*>?o8~GK${}%dspZ zyDhMWsU=`WNHXIr)KXg(rX;#oy3oBUw*LHHkghJK(7pPBK&kP*>R%*a;S!~vS8Po* z_{v?uyh#p_yO@pH@fI9dWZ&wH>BIf$rzi31ED$`r94q{Dg59#cVHp18$d@%XD9wsU zaW524n8&t-pY~qef#H+!h%0D)g0WdT+RG(}l4&517nB{Bbz6`n=psrb%_5(iP1xTV zZP2YrpsVw-irN%}NP%elfP4Z+(8qoQ2`s98ICiL6o7CvSnvE=y>Sq|eDAJXhI!Hp| zm%{&rs15{E28X1GAlZ=ZsI+wQSc0#D!oTs=4w`(@6d=0ykG(lUy{Rw{8)Dn(df9_m zwgZ$@U^%`&cVUr#LK`c`^0|OdSwFh-XD$`D}kj>XyepfBZyI@CdAIY<4X72@N9|d_Dt#u;4LY(-z1L(}!l` zC^V=HT%#2*XRC#ICan14`ZVfL>|znV(o%E_S)*|i&M=p2HYOec`+*G&osXmCtI21! zAR#Ujdgyf_$0ADwcg~@v|HysB*O&1@d9d|OYE)0?2!}~b=j*YeFbd2kU6chaY9-5% z7F1ofHFiauPlzZE$RD|#rENtvEQ6Roo|q8aQ4k+i*2Dt3|O0Vk6+12GBm;Vea*#8P^&-=HbA>JQ|rQ5KppP zc!j|$t)7}I)b(w*&CF+%3=8QR0a}GX;+;}3T@4>(f+n8Z9>&!nT&}hK^y0=(T8d&} zgFh28iCkPGWvhBs{@~d2X82~BOssfv?_c;5+5b->^$&dgGmrXaT{-V+nk6$giD(}p zymYfi$hYSm4{fp{(pNq6{Bt@Q50yEQQg_11f{R6Ft z*j28Fq9<8d%9Fh^&_7J~*|d04e{Vz_pf;PSV&mttXdFb23u5LE$uQ4Y;nq6fZRV<+pI&15=4XOd$eT zXa^~=$z25JME#$|B@Yvsm|ovO_EIh;UI|{h2}TpfZ%jNCLI;Nn`a1S zK1@T8ey4)Dd}AAIi=>)g(s3)fQ|1=&7E0f~;oAQwDo_M%XhQXJnmZ7~%xm%__^UazYA zwrO-l3zUWrwy$;S$R$Mp{wPvEEJ#9QZU8ua^nW0JllmM(*8tu@`upNP)75-|xnR0% z{v9@(@-lsjW;{r%u9KvhJQq82#zRaPddAcb%#p6WMFrDKJ0aPYVy6a#&$?PLgD7BM zw21*(&yE%?Q~TMX=pZWDrluUpsrl8__9f`^uVygklP9a`V}n*0XXRU1ao|m#79n`_ zydW&s0^9i4KhLi=HSy80G>R0E@qgSW=o?%>R*yaZ}IPDie+3b)|oO ztkSA^7DVixF#w)3d54z&6Z#`Zm(ev^3#ys4g0GD^wcBw|9^jg)B={d*5L?2dU#P?= zU{u&Y;&IakoG`>`ixA!-|PA{iP#XJVlyQz1t2HI_O z%(4TZ>6)wUnoLG8(I3US*3;5>+I2hABG^9Q4TDN)-2Qth-K$O{(f0|DwbA^bfEj)p zMGhoC#O|Ce`Vxr)E(Ps}{ z<}90Tt7{3cG}b@Ud)XpZG;--lxAxZkc+S;uA?TSngVc02{WtnJb90;1FFjGj@9|nU zg7oX;aSsA?bNQ6@9V=KrQcVm--fECkU#qduI!(NS_kbN1ocOT*jO&v?f1r~6)Ha9X z-edSW$0*E9co?RpH*9n7lfqp)nZY!8zCDko&a9;qNjkczUfdp9P$&MxjtmTz^5#uX zNW-%DCXOPUVjjz@4dQ@&4CzO#NP(1}CVbq)X2y!KfAmbZ4%zlm!~&$h($Kx^=`gUU zY;083cX8+W!)?L)-io_WCo4pl$IfC$!$eHCKP-~kXAPQ^IT5)FA>;o{kKP?+(|@Sa zjBQfsWwdTMy~8NC-^yN^c(M4*5Fh&1m$<^lLS^oeI}yH`xydSr|C500*{_8OneM19 z)*UOF&#>sUjMX^tCPjOqGF?u|s$b`TywHVQTDdkc$EUcwU-N`zl5P6>_Rtbi`1`Z6 z3TkmF`A`cT8_JZ+3$TVN1GEY=->IAu{9XTm+hC`&gE?2$Q2#fxXZkLqU_l}sRBgw7 z^zA$6QzJzpF!%|E)P53b@Z=`P^p?hkj2dUKKbUIh$2z4j5AgtsY-kYUFOBeIWX4KU z852e|q-fvH>2=vj=M3l75;F)Cp>*rxr&5nB+{9OhAJ5}D&Yr08$<+`<=%fdjI~c|G z*V#+g$AM4NA3PU~d~zn(@46CfjtinWBEG%KJMevDs{EdolY|a`Z!{KGzPxzOc=CgvxaD7YSgTqpI_unb_8cA2 zInOV&c=dg9rxrT+k7t})<%)i`f+WUN7_MIhPx4i?09xpwE&$uOD$O*|Siw$_DyDCCxc z>2V_q!j@k~U$oun35KeM1zLZlV6on zS{F1tFM{!OqJnCzw%)f)Hy&jz=yo>u_9tkCPGq?oV2GVo2 z@K-KN2fmw=83sJzke^|OC!0C8u=FO=5A{hUIveafB_Hehrc zHgDgbot>692uV*JFhq9IYZnL%j;L}(Fp-O3zgGg0Mc8HSETiNj&*2P_N3{{in7E7X z2p3wpbKcW$p0dx!A8>)u^{N5rX~sTO9^OGqlr-U(gF!H?FKz+PRYlj>ENAgcslKcf zu}-kp1G=wI;Ch6SM+0yCb#?G9cka*1ul~cUaIeK-Z-KLVyN6A^mzEW zY92lCn!YvkN5w> zeet;#3*8;q3Pwx!_ge%Q8QMko2{j)$m%U!fCjQB8LrPX%Y5z?6aU;SiV@u!M!@?(I z)CGV;YI?mg^3~_{%dM+!;Ir_w6G46RC(`DPc%2tZvU3l=yU}jT9lq+ueBDSg*kLnL zU09*iS zE@hrb@?OihS{IH>$wwPm#HZwE0=_gNwAG8%T*jngY=j-MtIW*fJY#B3(AJw+*3_G| zqCDih(pjlb4p;m1m>c|1dWIvWF$qaTYYHrRSmMfq=P$1qa* za`&~45$NDM)h())@^syVBv9N?$X|%j!cjQZd{WX6>{q*Z{D~X$3%fo zr~0>W^dF~D6)v{Sul6vyU07#`{Ef`eN^%jgdb5IhpqaI!){)vZKaU7nTs&1u)aKQo zAUK}c>u#ob)BaW3x>3ahVyKhz(CqlVi%1#%rQ)9twX^6w*Yb&E0Vy8g=Q<9WiJZZ7 z4bHV5aLwTNh@>u@8$qEZdh}rr4#R!3oTdC>&-6~`?L-*K*_Q!gxwdc78wY|y+mtla zIYw~-76c%5)^1}T`}(Cs4Xxq4j9{9fu%pO+kC26|dD}-56=haR9cwr@S|p`U{nu8> z?oPI!ujY~7+5IkZq(QD66bT<6)ieeY-!4D6-2NGWz}Wa3_3s+rhT7F1|E$6tzJ1FN z9>a_bUBQO>x-pr2L#LR=&s$kcFf&=J)jPIeXh(ZQe2#P$uWIqx^w2?PLduD`@uu+9 z4#2p50w&F*KMiaK2<6TRZIvmDWod)`H~lSOWEOp3_LyIaR$s2&CC(hL$DvTk$-;NEBp2Z??EMH?A}9q z(ne|&{0TW}-x@BH&D!wq`$wToR_bHIt(&KIwM&yv`rLuc-Ww znO3_qPC{|J;y7V?s&j)6X5D4m=pv4Z*=rkkq{A-@SH65^76~Q6omY%A?)Q#_{8`!d zD%GzuKk{1dvQVk@PzoYFxrWEH)syW85?{wF7!;=E2t_hr2c=F^5!r~IkorTTQ7AfA zf)IgJyDf>NTgR^-Vc*AZwtm?UCEuG)2Nr(9TlunK35wmN0n`BKZ65nGU4rOB zXM3$h)SY!7vX>VcIeQ{`w}0+j*7Q&rB(6(I|D5FAhTuDEtk)-!b$Pt~E>FNchJU`B z)Kc8JY>NWa_h&F3ui2P2!f~~nc=E`tJ~nEwGdBO7ja2*zFHQoE+0Zitz^qqbIiJ}NlI47bs1F4}@-4>W zc@v`L7Oo)?XR9(vXQdUR_qjBES9W<$%jx2j1SNBB&Y8KO zwHI#mZv;O#X@&|f#Ytz$G|GWSIv90c#l~r2?mjPn<-`&YUHL%ZecfC1#DCa;2;5Yf?71 zS2@~3)AFl#ijD6rH30$vTVunk zKs~DliI!D2fI&=B9yymae)dBEX-BJ%oQ;S~s_Ix61Kk{A*pP(qC-v!L-Zq(6RfKAV zh5@~aJ)YpTHs1GZYn@hLVEv19G3+?%9V6xzpU>B|C=~7F^1dKgg1*h?WUw|vzD?@f z&eXp4EIN?tMS)So1Nv{)mzLaB1z{g0aH|Fp*0kg?luox30^8;>sV|A|VFnpq&7gILh zP9q$4^TYc#X`t&@hxD7ERkgbgZ#TB6_$UngTQMGW9V1j?eOP}mTX@QF>c@$*@~0%s zL;e`|%+>;jW$c)q@RLG$V3;vQfknyFdborLu_ZWedQE&tEU0oLb|#(>P)K=VmWVJDt`Z zcYWB#>-u1)wxWzb;-*6LJ|(}#@e%qcEe*YBLIa%Z3;uv`MzDLfe~mmri4m?KXp>y!%YQ%h{s)>mmG$1J?YTCFG$$ zvEG=2U*o~y`HzMZ>J8zrWvGArobvw$YVX^ZbGw3+*?;UNOKI=9;g$LFra2L{rx^xh9 zhn#JD6jp6{tPQGW_>d7*SMDoWq#=$hskw69Z_iKJX$`Q2j1_ftP=U^jS@{pYI4Zv) z7oCP>-@GKl@tRc`6ZK;;w@WdOeNpWoT0FY4TnOZrr>!ryqVZTzq`6Mtt)#n1ND*;2 zq;@`n2jAjbHow6i)6;i9b;3Sd9c08KOc)Lhb;FlyyJsJ@<;xaBpOon9`zl)X(%Pl_ zF$knVNx;CwSNVh%J@tBa`UxF?A6@AXWZZ?=w7chdn6%44<=3OHL5l~hdt-EC_Et6Y>P{Or-%0-7C2ty?%Xs{gOe2@>lJuqDks8T$eBwkn zUP0!7FwjW0Sm(aIoWdV@49DK=_JD|e!IrA^=BVFBTHvl-XcN5@Pv?-Z{~Z_dYQ*V@ z1%y{`dOuJt*En1n& zlkh^nvoTl&V<2obAA8BFfY#kfp@jo4eMsYsz{rn~xh_>+11Q)lNDif@N!(8c4z$@8K$yz+=0$}Ea+;aSyukV69&S#1XXTME*T{`z7 zV;m2akif=dTi>?Fheg0WJ+;~Dnv&lO)RrNERQDoSDbw0Q&WA#7@FP$6Wf+)<2Z`8J zZOl=H8;OC>JS0RfQHxUgh|Ky;MJuTcS6@4h1Va)C3);fln5C8% zz5TQXJuHydZLby?w?755nabLdPLw)j!mS#A`b)V|vdyCu=`0z0s>u;Itjw0H48k02 za%Qb@4}F{^VITmu@-%H-Xoyh3s2~%$b$nV+ooS|H;r1(&RY_wBnOHA#JiE713%X}= z_6tnDZClB)?hC!8Z-@W^;rPU&jHto_V;sxMi#U%Mos9B*D?W^MWT zt>|K$eNl{Q%_oWb(-?=7+M1tu#q?}^a66QMB>0GCZpV-OSghl??=$nJRBFgEa{{$9 zpOS_NvPre3JB-blwC>lMdNO~ut}{zD)y<=rP=j^%S(X5h_dT3SavZ$3W8FbKvb=#j zrPDr-@;9D?(SU4!jB#k#^3{9${_B-dbq1GD?zhwH`dc+IS@Z7of@sH{WaPX6Ig4k)AQKArBuL{!6nNPc%4+p|z_Z7%wUS1%~<`L!f~_I3j-+=b?ux!l^5|eQ|YQxn1|!K{}r#W-fye-djyrp`VSazZvKybawN1h~za{AOPW$=Q}vlGOVw zM~;z@uzm-BTE&hirn3#HBX7kpJCX>(kr6rEKH`8W&K8iaGX+olGJ~&g@A|OuTgUmA zLBkhDpnM7rNGw22Ue;@UA$V5bD%aX$Md7olkQiI3PqM0|c9Uwq@L`yfo`L}%7!PNbL|$+oZeYBL0c zGjXj!WotD+VO78g&)Tad{o_eCs_~w_F6NslU7q%JbG-&6o!dm(NWihm=t|>hk7H`q z1V#M|{xV%GJI;8rKwZieOTG7XpQKk2IYXbn!AF*Y07(K^Gemucv0X&qK=M*B-rH}1 zZgix+4VGhDf>x7&vuDvr`BrTXw=~s|(ZR!-tR}c=s8MRI@d@Er@nq?o6syG_D!nYb zzd8kI+Go8vd+*{qg~ngW>IC>4kekhUQZLqXtbO8rmOqM#d{-}W`~gpBj(d5wJE|wy zI_K%DybutDXrnftZSsZlc00yHnwd!&|ni!{#*aLLGWC?=O*47o5$m|E&@Nu?4-at_J&WgLf>9z*6p2`w76op+@r$3Jlqo3`>cnhnW&)}f}!kW zMl7Q@g*9Q;Dd}jT%z{zo#HV`Z^j*mWlHZFPs6Jv%NvXvL1`>Sg#{%gbu^A1YPId;D z%wS<*I9H8tcJ7ushA-zI%zFO*JVW69s9#7(2(bkXg8;69+^aRSdL`lOsE7h8Z5taHXGtkXWSbcJ#zX)MJNyOy{8N zei}%7rJI3hni1*u_a_p%$2drlSQlTOslv3e#9K`&!`i(YH0VX{~COdhdF%T1>7AWIse;t^iJevHt#-qyk0`y+;hjOCxiPTUAKih zVy>8>Emtr(s-3dof?WxNfoT^0Awn}7xoPQ(`tp|Zb}8V<;#J^B5#CfUO71i-B5pGV z2EscyIoTU=qOvcCl^!Fd&I_)h>Tv5jvfCttP#_+woHo#Zl4q%PpYHTtu}0MtMfmIq zS9zm%{0q<1CyU}D)!cw++0w;&JPTeA`dV|9z#PZy^$NVG8rh}kYD@={ji=^ zkVJ_U`_B_k%;w@}8f!#~4M-8}PQhYX(E=XnM~yo(#MBj0HkPAX{QU7$7^SR@4Yb-E z!Y_`d-H-8DUw)hRkQE&yxQ;wjQyzI)v?RX00f?-^g}U8Yx74-MN3xme5WpFDS~j4v z>Cb!8g}><5=7!F!3TLKM`#QBpKh_}Z4r-5%Z5pHa(Gw2uQNy2FY3~g^w8Z)8-8~=% z7`HOpaDZEzP+B-w!&0>O#}7-aD}jf~Jce3so`P@c;A*W2mPJt$KANuWc(Md9A>$JcEcLCaMu{cuC6XANil#$@?QteP8zWwM66{PH|>>R7-3X)w(u zw2#uKV(r_z+co;S{Z{VMByIQI*L8G8h9`8P-QB6FNzI!2-3MJpJ{O! zEd<%C=+*Oad2XaQVUG1Og^gf!MzOT;1TkMo#TEXh&JYIl)Kwg1yC}T0@Vk*sstpBe z1E&)?^V6QhwygS;xj3ve3d47~RGyRtWKMtp7V*PoTX4xv&LOGl*(_If-Z~ESt&L`* zBW`arkRLmG5JT2(`SC!{<~egdo|s(j`5tAN~ru!bEUn6(oV_6psN-$@pXkK^@a z#Hg?i1KDlh$i_E^j<{>duL?^H~ZrSx=$O|WWHKo~`))0T> z&MNTFu7zk!gFSqEr$p-I0;xM24yNXT1AprUid>Ef(8E7T6iP<-7WEp9sex{k|teq7o{ zL?rR)HGda29OaK=Yi+4+X@3r=%Z993XV0HQ1|NMU1p={D;fi0A<&;NV4b!z)w?}H! zZj@(Zyyn9)O+Yv12dCEtb}}i|#=Iq8E7rl9N|SH{k67d?JTd}tKM0aVgGJK- zsAuA}Ydt!YQ)xl+t(FrSeMb{(&_C8`JDhnq3=A`1%A*7wN`uCwg~Jjg<(7TpdWKFvcv;H*G1*@e+~GdAc1#F`J(s{`pX7uMh)7z5P%33_-Ie z)VfzSBsS3ym4!C`&%{_a0SVIuzk}2Pxim#{dBsSwcO*@T&ev16SatzldFF6Orv?WT z?ppH80M27*ZnJq(w70YKkGw-dNUC~q@+-rhhkZQz=r z(^bq$6ek){%Qx8DzJFW->r|hw1VC?d>IE{?Da< zOFyCYv6|)&ctUzZmskjb3Yhc#vkf6-&621E@MG_@syzxO3&CyWt;+@fd!+nD$l8gZ zK)uu=eIk4Ly=z;~5Y%os4v@nU*mn14PMO~ibu0F}?!qg2(0yh}Kq^>Zbc~#V*k!$l zND}S#k{2pV>isB!9M_>*YD$T1Vr^^-Ih%l=+`}g#3|(yU)77>)yB)11wrc9gF{GMT zR9O(^9gLbvZrL(3#~PcCA z1K;tfuVm5iQ7|N%oa1wPlO2Dfy?np3czW%%6ByJ`4QHM!G&(YlfExe;=QP}U1&pe| z-1g!}=Ru^3jWM6J;=-KV81UFNE-_X^p1w95X0itjW8QFm)=O|(vjKXQwm&ScA$wEq z9bVoq*am$|js59Ozhz}hFs;5?@L}6}Eo2lK!l{Lone0R-jis!G<^er#30X(%Tt9PBXEZxZ>h zb8;AlPeKw*-htqWRYY}S1TT)|Ow>coQ!`Bry@1y zLz%zXn!knq9_K%0D*f?6=A!azmhIGi=duQ~u?Y+SXPQf%jaH44^?pvxRCfY;(88f) zE$lFf(D9F7Fv;X1={E$I*R0>P%ev8Vc{4Nb*C|}*7(&*?BoaxFw^FRHiKyuX-gLIg89epvJ-GtB#nTM-(osP^it!MVRY6zDt*co=hXF9DY-WGGXh-x3ty%ZLF zv>qm1g|LorLFN=4j~pFnrOIAK>L zFD6g7UI88u5D5{w2Nvw6*H{xp+4+4r^}Xy)`F+&-Df2E++v+eic18dFPGJ>!V9&}2 zeDP|#fyK$+9{wG{&9G?dm!+Pv8x98-A z!fS;I0|`Wt+ml}C0(Lpng_QT0g4ozZ<@j@raLJURPd=A1l-w5<~XvKV_)Bx7#>e^Cp;|x`PMFL z40sH!klNLMj3F=ob~ZMrLg8-R8f99YYX`~Xhl$94zy)*Y9;>6TJbX`|N+Prhi4WYdZ3K@J=Jl!D<;Hz|lDz60Q z70jzz)Q;wjRvlXp2yA~g8nykE;_hfh484)(ip)W9BSQ#}Yg2^vO&iczQwG~124#SN zE}ga&c}WBfKcx_qLKkDhZ1ZWx673z#1>-_aOQv)!fwvG^I?r!nai+eD;SQc;zR(j(d4|7Zn&?>S$K{Jl1{8vRVjgFJ#Y$EUF^(wH0&=K3Aa zID)B+fuVW}yPD#ig-T4M_}3-NC+63sfdR9WxOG^1$U)}c`7u&`o#fN{o2n2VG+uc- zb*)4ioN?+!9*p9|K%*pirL!UA@p&H{P`5eEa1zh@jU$-Ce56m(1m5lNdr~Aek*$aH zq$`@Jo6PG#?NnG z89X1Uu(cSzwv2o0`IveT=7lp9a?XEu_#^E0amw2q4swxEo@JvXa(uMeZadf%DWDQ zj(^C5HJOgpS`t=2n#ERHt48k>1i^M^0??y-9ZHxMuR4%Jz+P1P6&+#T34o3*UAaj#pJN{P}FyZ81j!j9jtSandz>6K-8kHeHEx{+RNqAM$qrjIoznS z2S-X&LZu?v@r` z>urz9teq3w{%Axc#&!qJ7B%v@p1&{YJ8P9=x##9 z`cry_C0`d~7`spI>Ct=d{x(e&>nBnu(pkbo3r|0bhpe_g@gq$8KEFXcd7T2ZNCCA+ zE4O82U{5k)i2_)6-!@C-)U5>MtH$W#jZGLYY5%JhKO>+}aPEKXP?hbJw>Qmhc2yd_$dHP_E z=TEX-e#KKep_9iYt(zUbS3mGK+X6~T)2?9pO>!WBBkP6vAXG?}cILF>Zf21SdIPT>ofU<~6mhd0ygPSR!i&VK6~xI4CV z$H#Hkd8*b;V1Li}N$isnh)~50p(dQNAyz?^*Hef?xZ;7`A2+(Kp#WN+buoVHM&?tV zAp$-76}}FdjIf!&htgtpftw;#+4kQXEc2Qz~KQE&&1sR5eE@0_DrapL`7)f;4=tbY^4X zEklM%k`5m}$5#Vw(V?)KS z>qZq8Lfb+@ZCy=e7-}l`K6PAm5CAPi2iGS|a)vesy0R{w!`AiqZ;tkeV#AER|@a+O( z`CiM>InV!T>^$7z`uep$dhccQA)f)T=5@_X{U&pFriUg!J)GkdK)d%kVm_h%cb)|P6|z(Vho5zp5>Ve-= z#i#hd1<$70yob3?dkL+MCaAG{-mv=>f7m38_dAp1SMK4o&1q+55ON+N_V#716{KSF zeeW3#v~eEjs=opu zqEl$uY)&z`JD~-m@Dze6Pj!>78=my?GT_nE-FQANm)&)QP?e_4nV_2!P3aP%F`EOu zd1rOaeU!7<84=wUiY~p5NHZiu%Ly5F@4XCbB+nOV`P{BgsV;Lesl5Vg@BLTnSga+p zIt`7x4R3dInlL**oG}r_Y)-03--tq$-%?iqCOH{f{g|3Ul+oR5&*1}Vepva_9?PhB zI?8FcTw)Bvbc%yZx;FH67QXqjfi{?DyME?`goHK?{Mppr$?C1_vVmuMch&j%eI4Fy z(KMk{Qpx0>gk1jgX4r z`I7}u*b&qxE}JB7^;9B=*UN<}^Wa>#Z=W;vwH1(3$JE!bRiCn79BmP|`Es3Y9NKzK zl2!5EA>_xZoRGte&s=vW97cw{d`-BJ?1x@OinCeSlbD?)}pbUV_MhxMQ69APHo~r+dk|X;D+z z%UlNj;IH2n4O(h=Hc+jcp=-{@K3CD~mYL&i=Fq!oQ!6G@@N?*P3v^HK*SxlW^Q{YK zv%KI7JD5TJzzu?~Ds$lincDkw;!24btGA>Yw>potw3k18_Hy=Bp1#4H_F)J(ikOZz z^Znp)%caLVU8djlU=W}}Lf5X+?NTY+d*?UISW9rc!vO0X{=A-xwdi(d6Z5g_x?Pb` zbZDfp6xt?FAyG-8Huze)to{*?WK-JG^liF}nugFIYs+pWOS8^W5n!rt4c;m!rq=NsPhK*1w&w++a6X4B&Wt`S$&gGfY`FtIEhp;d z>YEJ@d z*JFebXXECup#v7U@i(zS#Jg_APK;fq%9o00WrG3J=~oi7bgM)6I|(l~`lNpw z?a?|qm&~UO1^r0zOlO@t(|bk{@0{|$5;(o8O#w zZ2QT~V0vr-)aW}5PUY6yA1_S{OOw^uJJ}3tlSc?dZJe6%pt>7B5Wr=`QJ~b0KeE&O z3ZB2+9NVtY9#_O{#2dc2ntG$bQE)o)H6QNm1HL&Ky^ivGYTvf4pPaUN+?`g+M*X_l z?o%AqF4G*!x%_yG{7v!;UD^r>IcYuD>2K`wIp(SyypKq;h3Y|*6j4I?aeYPU%&wxI z4B>5ulPBLC;UqO3AH30Q$8j?RZ&9^4Y^atTC>KiJpC!686emQFWPj^e$TjQMOyx3I zmp@VIz%L%jn0@s~Xp33LJ1qFAE`-A74eD*KRe$_4V;muVX^YzV;@w8oVAuYx#&~jL zCj56zp+pO;&a_&@@&Ru`M66yd={J&aetyF}OwdZH^(w@HKjKyUI8KNPUqlV=`M1jM z>EvXJHC;&Qh9lyx2V+`5psz}lgM#g;(~QkxOSNb?ht-^^ASA}j!nfrj6dAE^33II;>r7>K)GE>U zU)W*wpGw$%POoM28c1VLPVQjN(jAXR>z|(^gLHFVSbNrZwM}*)B;!xCh`}W095c>^ z__4v{$tbj#TP7x4plDRF7_wTJwZ&lYIWa-Q+m<)-cqg-abG_etiq<0X=Ww*0R>=t$ z7R;OvRXsucg(=H@L4wzKmet`p9e99bzSY-+fDTM$%AThAn9R{@);}amN-oXT=4Q?i zH_j|VGI}p&EtyHMK1bJ;Cp2(#4){EoV}8@rz+(dO_Fgjbokf}da7 z#3t>#bOTyRE&@4N%20o>j#BTxh{wPXj?Xey;k-W|a>C}E=qTmTxJa~K8?72qq z-hJPlUQ11RX-C}{5n-Ed9$}jv|0_gQ>4t$-@bH{`kY~G|)NQZr_?+cK+L~IzgHAt= zAu?{pVX*d6`FK?tXL~OdWqt{Bw7jyQ`l5j6b(5XQKfSk5shk;B>4a!j{U&^Siutl|TQ%YMsgDUuLUN_~ z0>tMBqa>J8G4Qt=+b#P z7mL88cjM&kZng>Nf)=Ez-Tbo{AE~RIeDPiJH=N~e4DcU2>vo_Dc&+fsrQj==pm6+X zISikd(iBQn#S+cawikbMSAM&7(rovIsRvVop-0gd$>edumdAqjiYA^E_;sWt28?7t z!E{@n$)|}}H8}oR$ZAsvlK)Nv(XHO|QQEZw*(q$cHuD>`-NOYuhv$*kU{fzY6N;ve ziM{xeG(ktoHM94^isY~oIZ+aw-V~xa8(Qg4G-a1Q!B2fu`$p%y;Ud{ahmI%n#nF0Z zyPpr6U!rzoIyNv{E|Uz!y4x#E<#(NmQSOfv(mvMi_vkMAa%jRATsp*c_WKJRP~(q1 z{63s5yj1&-NH3C`WZPc4=mHhJOiUK?6ES^nMJeJwj(1fBeUi2pmih$`V>i3hYd! zEj*V{#l&oum0%i-W8YreUmr-IrL)yM0s0L`Q&SH{(+WQ?*P>_|W_POZBt;Y!V%aDZdk!KPmHABJTgz7cd5?C3qd42O;j5s+52FzX&YzR?2A|O{9+Hj` zvgCbKZMs?byv9c&b3hm6|G=OcAIEW1yB9WeC{b85 z=4o+7t=wg|4S>_#!K7YW_kOqZMjL$OZ1(|luhdD{zIH;zd;*tspM2}C_v+T%gWT!z zFry}X9TQ+~;Z8Y9q^(}Ri0Bl4>U7Ytv3QPWW#skRus^*Cv46T3&vd@Obwzm6d@e?d zI5@kczzyqIC>_1p`zTHBsDpSqQ3Xgre~<`Z-l~b9&cSKM>%m02KG*lL_br7`6p2Dh z7Ugu#N6Wr%?k@ePWzu=Zq^rDDtGB8|h?186yrr?hnEobgLl8S>Y{vDsVDc|g`>O^# zeYk>qj^H#jYYL^~LfkA6-?DoiF?eB>tE;E(%sYK( zh={BodFFC{undh*0k}YU^wa5ou`M9LH-n-iv+2Wtt7>J z-PWqR5AI})R(CBlX0QriV_!rCp611RM+Ov{cm8&Be^l1wlJ8p8%j)xDt=Ynno&6Of zK6OkC^ZsaFAFK9NT~ITbdUGy$4jw2Niyv}93515al9TOP?(s|WI_k&VudJ{unCWpX z57VrV&5DTB-Fk?QEFD=kATiN%^-I#L^~yqbD$RSE-G|;Ka~(f?SmBi>=pV5|m=z=! zi_oN9qhGP))^#%5(~_@kcr#{QZ~e5hg{AhkZnFZtHM&Ee=-5VJ+08@>VL=5N+^(rg z&&BdsX3Vc8{GE2=E~&Ta=(*dh+;|UF=k~Cy?DlMhy_4UgCM_Wr!#N~T9&^(~gOG+% z%6%kdsfTK=mW)Oof`jk=dm_b%HyR}nW%7j!zL4hydBBqb_mofWgR!*X)RPTCZ#*v1 z@40At_%%==eL0pSBuJ;D`56*OsCI#K?jU8wQ$GeSV%--BXG7i-*|q(8e=F_>-^bE1 zpn}-SHR_wI75rFy4Kp?CXRFpzqy%9B`jSqk+|G%6*ZjS4taG?%j%319VnI@}v*3=m zxV>c}zx_Z@iZuI`VfExrM-4A=Q-Z~9;#ZHfS%lKb6n8XpyRo~Rs{ z&jRjWv`f-UxME)TDG6acoRMc_oukPyieo*Md;DoC-<(P5)GQ4j(}>+Ia5 zZbOObzq%FVq`R^1ySr{({cw?coJOsH>TP+BLdko%xdB7sgEEN}{!`C{TUfd=8Hw@B zYV~SCf%!lI-aM4CXThTrzN5d*=>Y`F&wK@AA_ji(vA$vlRrkZCZ;tgsB)}qROvn4q zktH}d=!(p`MtBfu#V8ZKBS{oSlh>45^Vc#>r`op4i^|{0H=Pzo8ada56R}cUl-VMF zQ%BhcF^If**9>3M|A1j%PaEtU47Xg|wp5T=@N)tPjFB4;hbTE;!2ONLQnaw_VX=im z7DK~*hF3fiUk}j_ZexW6X?2~KrfH8p*C`RD{q?gN){~P5T`}ljBPfpuK?bmvx*Ao~ z&nSnwQhd)YGxo0IwtoCUMODx1xChshm2od{Y*2BM{;C1J(qw_s%O4D1W+W>%fiOnG zMjW{fH#W*U2Hh^706)hEXZOJ1iV|R1{@~(@?&nTP9@ds1l}7<-0KY7uIlF;(r2p)>=|LL`v%uM**+11vtt;3U61NNmSi$ zU{3WD1a^on2XSQmN6*cZ>Q-Pd{^3d&US0t8>TQdE%Zx8P!3=DF@1&XuJ2TDc66uSk zv?d5U+nHeft*NQ*jtBP*#J4{=Wq6lrcQFk_<^n3$RU3I>sCNoqJiziDr0W6sG8tpK z^72PVDD%jAfm&vo$p2XQSx|EKaZrp<7G*btP)c>aHo?)(4_NuzwmnHVN8b)T5+ZS8 z28y_;s(F{98(h1;o1$~WSVYVxQ2r4`M!xYoE;&A7FmJFHL$lC7RA-P}(-mhBZmpseUCSCsS}< zLM~G3(%t@zb4-im$h$4P;BgAkO`-Rg#_g&;%&Q4?b zU~Sj#VP@(nwf8G}#?~^#LNjpH*J0JKx2cf5+{NF^gH>pZ`gm4TWriFegCAjDFJ+8Z4BwVxIqKyj!LPnEh3`X>fpP%cB5*#~8F zzF916$-%bKZFX=fdDr1e#fgU+XmWP*WTJWQFuNbVcs?*rG5LOQt+}#p7O3*<0oq|c z8=>-i)Pb5FZ^y#wP)OMBMgTU40?fJx?~)<-8S<^?FHF>$y{?4-+VSwpGW)7(N5t3i zfprNw7o<=tbli(FF)^~D0^XKe-A&J9&{cO|O`PoeO|-^rQ@&8Y3a*VUd@+x`;8z?j zzAA=={fISNg>QJujn|ooWsrhkygkSHIsW+qiPK(Vc`PKP)xDOil1g)9>vT_dON6Kb ztW~d4*?^Foj=1fqJ_Qp81DxBtP`~&3g+tH}qVS&xe>Or<2rMigOP@QVyX=N3`EUaJ zFdok`&G+Uoc&IpgD2XRy?Z_QtlVm-kt6ROAuaat|@)+K|Q#`Tn@!U);qgfYKO~LX( z;&9K%D&nhHB!m&Rb>#N6>#80yYLH5Yr<9#a*}ZkituRV5ntY5d6YUIpe7o+>n|1OYr+!PmG?3ZM;usN_aLTZBx$q=Jj)p zQmfmlgniURf)X#JEy~qCrqbn6h%9k53~1sv$2BjS2FeVDoOH}kxeLK0)3rM1hxO@J z9@4o?n2DtIVx5f!p!dGAT`s;cAXFUUDFJY#ULIEAo5@<;;uTpK;gr1l z83w`WX@IKD%vIhAT|XGQ+yG_OgVt|X0A%o{jku4n>j z@MWAD>6w4GVNNdyx$zEG!`SHb{^wP1$dwZZdLJ3n%p?-~X0~e1WCTY;fRAE(Pz!yKaQHsN|3cn_ z5rz?m%I8IXf5vfJ!e;U9^Yh0<*PmhGYnJFR*7BUwK#BseK@5;FTBm0L1lm-`~iupBQmrK<$O=lB!yn(t@aH^6uLk14Qz)s_!bgw_S|+x@P8ed;)_+p&z+`!ED~ zD|lNRFJLcvlRsASb1Vhce=I;IshysT16r%LJN80qapg4zLGykefz?Qvgm!nyW+@k5o%7I zrHBdDj|NoP7I6*mZ_xf1+;x8^ZT(7kYee?G(!3l}aiJ&fU95J+r_V)N9|-1$CsGdp zs_=JGJ+VyT33gELo=BP>f;C3GNClg`?PFpRP0z&9K%?dz3@-{JZ`tnfYv~G4%Gdi; zd9l`jjn@e9!W>pNj19iQvvk*$z>T~;)5|@1@4FuND#=V^fKMAVwG6gz%%OJq!%CkM z?I+~uo?u7(`SK9<6C#{pBiT@R%w`ribeis20T|bp)LB><22(-lp-WCJ;6Va5AY{x= z$vjj%cxGCD{AtU`5B~k5Qn#oJ(H&S^tW!!_v?2W{C4G4~lo0l{^KV$jF)0Gr$ zNubyVoZzq)qYr?sc7E&bFP;H~m^t2i;$+g4{0BD{gGCoDiyzA=e_Xdqa?o6HL;?F? zbwN^w#>UV;>$n%KWgfRZQ9p940 zj(%j3m@(Q}ShghvT|qeUj8QJ3HB$3l3LS_3{8~4X{vIk_^07n;mZuwt-%UI9u1DbXWXV?zbp-|+&^^hUIg28@bv6u(5Rt*GjOCnrAwR!8H zDfwyrI8lSk*BgT!zwi$r9JRbmS;L7w7?55*_=hg%?(9d=sl2?qR^NZHl1W7rR+c%=gblZ6M~G6I+)|jA^9oG8RsXMR zJ~+G#>*^q^=kH-ySgkp)60xLC-CNw%^x(TslLfct7D(oNAyNe$QtDZIuu(3JA8taV zDTxZ^nE|~2qcEy&dNiJVd^cdcpD-!KUOUahc_iV0-up{t?fOGzr5Z5g0Ib;Fq!Nze zS6LgeG9t=faGN1`h=+S`6cBTeEk558AEbifPUH*^HK+rWn>5J<=As_}Y+6fIMq+wUb{1J-LZ-ka{ zY}P<(WOt*}i&{IlpxgYKFhJmc_{hEhoo0`NZ#|1-HzQ;75(5^I-u;KyG~7W* zR1Oy5O$PnLdd7oiSu4v&lD#}tG*=CXLD{z7w5B~=iajWbXd_CXMKQnpSB7d~-DZ`e z9)O0H^ZWry6P?iv;4J<=s42oYWP=3dRUQ@gHyJGLVX?td-(B1oE9t=}6(?ZVGCmJ$ zxa5iRE_!1wlac$K6h0#KNwI#Gx+3Ap<(>E!ewFg70g`e>!q?=55BIvwl!NA<>;JoZ zY#-@KjxKGG+T493lq_G4As08)mms2LNos%6&LH&NWrn@%>aE$q5XB~Q z_xVS+`2?`HMc(_6X>!687mP^T7b=9l<$?oKcn>$Xjp$;5xYFfIDzf|jM_Q=($l3~b zRxgV=j_UCyqHL1JRlBRtkztxupl0?bCRzfJv7A8eU*HoSkj|oHxeb`c+SwZWk-N|J ztWUX*BAIJ@)c+6Zpw#=nr-Szwsr=p=mE&GY{0CyhNwZAbMmX#dX)Mo(^k7CwE}kb^QPGfDY$> z1b#k&Bsb@nwa_WvG?C zwJz-3LoqXA!|mX4byCL~N#M9Tr91&P0b&9jZb|tk5u%VOT8@-cAo=)$Lmb>}1;wOU)_5;4a#4snJ5;ZmSqi^4P>pdjyI*2DDrvOD|gL zNCr>)Yw!@l0cau)=OJ&(PU!S3_62s;Nq9AyvxkT6??pYDR0;8O+`_4w*}!NZJBYgZ zIpPc{TpimDbw0CZbkfMQB!|Hw^I1dk;C+mF6uxnf3IQda_q*)Fo!aEfjT3Qvr2W+E z_9Q9&Z~!odNN1fm@*rZEY~L=o@T-;<&=A5|?o_}4<$#5KKkHQnedm3WGm?%qn!YXy zESS>ZZe&MCDrQj9iUT)ncu!+1;D65`Y1qVE07nv$BCGHEJ@uP1g(8;ag@1{&aB>@t z8A!^Nx%s4!Um7J6Y!S(OQ`df5Nd=G;8MJF-8CoJMp?54o=(E|lDz XzFW3 Date: Thu, 9 Apr 2026 09:47:27 +0800 Subject: [PATCH 03/52] fix: include home_content URL in CSP frame-src origins (fixes #1519) --- backend/internal/service/setting_service.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 7d0ef5bd..5c90317d 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -333,8 +333,8 @@ func safeRawJSONArray(raw string) json.RawMessage { return json.RawMessage("[]") } -// GetFrameSrcOrigins returns deduplicated http(s) origins from purchase_subscription_url -// and all custom_menu_items URLs. Used by the router layer for CSP frame-src injection. +// GetFrameSrcOrigins returns deduplicated http(s) origins from home_content URL, +// purchase_subscription_url, and all custom_menu_items URLs. Used by the router layer for CSP frame-src injection. func (s *SettingService) GetFrameSrcOrigins(ctx context.Context) ([]string, error) { settings, err := s.GetPublicSettings(ctx) if err != nil { @@ -353,6 +353,9 @@ func (s *SettingService) GetFrameSrcOrigins(ctx context.Context) ([]string, erro } } + // home content URL (when home_content is set to a URL for iframe embedding) + addOrigin(settings.HomeContent) + // purchase subscription URL if settings.PurchaseSubscriptionEnabled { addOrigin(settings.PurchaseSubscriptionURL) -- GitLab From 02a66a01c3da45d29c37080abb4b87cfbd40bd25 Mon Sep 17 00:00:00 2001 From: ruiqurm Date: Fri, 13 Mar 2026 23:38:58 +0800 Subject: [PATCH 04/52] feat: support OIDC login. --- backend/internal/config/config.go | 160 ++++ backend/internal/config/config_test.go | 54 ++ .../internal/handler/admin/setting_handler.go | 281 ++++++ backend/internal/handler/auth_oidc_oauth.go | 865 ++++++++++++++++++ .../internal/handler/auth_oidc_oauth_test.go | 106 +++ backend/internal/handler/dto/settings.go | 26 + backend/internal/handler/setting_handler.go | 3 + backend/internal/server/api_contract_test.go | 46 +- backend/internal/server/routes/auth.go | 8 + backend/internal/service/auth_service.go | 3 +- backend/internal/service/domain_constants.go | 27 + backend/internal/service/setting_service.go | 458 ++++++++++ .../setting_service_oidc_config_test.go | 103 +++ backend/internal/service/settings_view.go | 33 +- deploy/config.example.yaml | 40 + frontend/src/api/admin/settings.ts | 46 + frontend/src/api/auth.ts | 25 +- .../components/auth/LinuxDoOAuthSection.vue | 12 +- .../src/components/auth/OidcOAuthSection.vue | 53 ++ frontend/src/i18n/locales/en.ts | 66 ++ frontend/src/i18n/locales/zh.ts | 65 ++ frontend/src/router/index.ts | 9 + frontend/src/stores/app.ts | 3 + frontend/src/types/index.ts | 3 + frontend/src/views/admin/SettingsView.vue | 386 +++++++- frontend/src/views/auth/LoginView.vue | 28 +- frontend/src/views/auth/OidcCallbackView.vue | 234 +++++ frontend/src/views/auth/RegisterView.vue | 27 +- 28 files changed, 3154 insertions(+), 16 deletions(-) create mode 100644 backend/internal/handler/auth_oidc_oauth.go create mode 100644 backend/internal/handler/auth_oidc_oauth_test.go create mode 100644 backend/internal/service/setting_service_oidc_config_test.go create mode 100644 frontend/src/components/auth/OidcOAuthSection.vue create mode 100644 frontend/src/views/auth/OidcCallbackView.vue diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 9b430377..117d4293 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -65,6 +65,7 @@ type Config struct { JWT JWTConfig `mapstructure:"jwt"` Totp TotpConfig `mapstructure:"totp"` LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` + OIDC OIDCConnectConfig `mapstructure:"oidc_connect"` Default DefaultConfig `mapstructure:"default"` RateLimit RateLimitConfig `mapstructure:"rate_limit"` Pricing PricingConfig `mapstructure:"pricing"` @@ -184,6 +185,34 @@ type LinuxDoConnectConfig struct { UserInfoUsernamePath string `mapstructure:"userinfo_username_path"` } +type OIDCConnectConfig struct { + Enabled bool `mapstructure:"enabled"` + ProviderName string `mapstructure:"provider_name"` // 显示名: "Keycloak" 等 + ClientID string `mapstructure:"client_id"` + ClientSecret string `mapstructure:"client_secret"` + IssuerURL string `mapstructure:"issuer_url"` + DiscoveryURL string `mapstructure:"discovery_url"` + AuthorizeURL string `mapstructure:"authorize_url"` + TokenURL string `mapstructure:"token_url"` + UserInfoURL string `mapstructure:"userinfo_url"` + JWKSURL string `mapstructure:"jwks_url"` + Scopes string `mapstructure:"scopes"` // 默认 "openid email profile" + RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记) + FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/oidc/callback) + TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none + UsePKCE bool `mapstructure:"use_pkce"` + ValidateIDToken bool `mapstructure:"validate_id_token"` + AllowedSigningAlgs string `mapstructure:"allowed_signing_algs"` // 默认 "RS256,ES256,PS256" + ClockSkewSeconds int `mapstructure:"clock_skew_seconds"` // 默认 120 + RequireEmailVerified bool `mapstructure:"require_email_verified"` // 默认 false + + // 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。 + // 为空时,服务端会尝试一组常见字段名。 + UserInfoEmailPath string `mapstructure:"userinfo_email_path"` + UserInfoIDPath string `mapstructure:"userinfo_id_path"` + UserInfoUsernamePath string `mapstructure:"userinfo_username_path"` +} + // TokenRefreshConfig OAuth token自动刷新配置 type TokenRefreshConfig struct { // 是否启用自动刷新 @@ -968,6 +997,23 @@ func load(allowMissingJWTSecret bool) (*Config, error) { cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath) cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath) cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath) + cfg.OIDC.ProviderName = strings.TrimSpace(cfg.OIDC.ProviderName) + cfg.OIDC.ClientID = strings.TrimSpace(cfg.OIDC.ClientID) + cfg.OIDC.ClientSecret = strings.TrimSpace(cfg.OIDC.ClientSecret) + cfg.OIDC.IssuerURL = strings.TrimSpace(cfg.OIDC.IssuerURL) + cfg.OIDC.DiscoveryURL = strings.TrimSpace(cfg.OIDC.DiscoveryURL) + cfg.OIDC.AuthorizeURL = strings.TrimSpace(cfg.OIDC.AuthorizeURL) + cfg.OIDC.TokenURL = strings.TrimSpace(cfg.OIDC.TokenURL) + cfg.OIDC.UserInfoURL = strings.TrimSpace(cfg.OIDC.UserInfoURL) + cfg.OIDC.JWKSURL = strings.TrimSpace(cfg.OIDC.JWKSURL) + cfg.OIDC.Scopes = strings.TrimSpace(cfg.OIDC.Scopes) + cfg.OIDC.RedirectURL = strings.TrimSpace(cfg.OIDC.RedirectURL) + cfg.OIDC.FrontendRedirectURL = strings.TrimSpace(cfg.OIDC.FrontendRedirectURL) + cfg.OIDC.TokenAuthMethod = strings.ToLower(strings.TrimSpace(cfg.OIDC.TokenAuthMethod)) + cfg.OIDC.AllowedSigningAlgs = strings.TrimSpace(cfg.OIDC.AllowedSigningAlgs) + cfg.OIDC.UserInfoEmailPath = strings.TrimSpace(cfg.OIDC.UserInfoEmailPath) + cfg.OIDC.UserInfoIDPath = strings.TrimSpace(cfg.OIDC.UserInfoIDPath) + cfg.OIDC.UserInfoUsernamePath = strings.TrimSpace(cfg.OIDC.UserInfoUsernamePath) cfg.Dashboard.KeyPrefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix) cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins) cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed) @@ -1138,6 +1184,30 @@ func setDefaults() { viper.SetDefault("linuxdo_connect.userinfo_id_path", "") viper.SetDefault("linuxdo_connect.userinfo_username_path", "") + // Generic OIDC OAuth 登录 + viper.SetDefault("oidc_connect.enabled", false) + viper.SetDefault("oidc_connect.provider_name", "OIDC") + viper.SetDefault("oidc_connect.client_id", "") + viper.SetDefault("oidc_connect.client_secret", "") + viper.SetDefault("oidc_connect.issuer_url", "") + viper.SetDefault("oidc_connect.discovery_url", "") + viper.SetDefault("oidc_connect.authorize_url", "") + viper.SetDefault("oidc_connect.token_url", "") + viper.SetDefault("oidc_connect.userinfo_url", "") + viper.SetDefault("oidc_connect.jwks_url", "") + viper.SetDefault("oidc_connect.scopes", "openid email profile") + viper.SetDefault("oidc_connect.redirect_url", "") + viper.SetDefault("oidc_connect.frontend_redirect_url", "/auth/oidc/callback") + viper.SetDefault("oidc_connect.token_auth_method", "client_secret_post") + viper.SetDefault("oidc_connect.use_pkce", false) + viper.SetDefault("oidc_connect.validate_id_token", true) + viper.SetDefault("oidc_connect.allowed_signing_algs", "RS256,ES256,PS256") + viper.SetDefault("oidc_connect.clock_skew_seconds", 120) + viper.SetDefault("oidc_connect.require_email_verified", false) + viper.SetDefault("oidc_connect.userinfo_email_path", "") + viper.SetDefault("oidc_connect.userinfo_id_path", "") + viper.SetDefault("oidc_connect.userinfo_username_path", "") + // Database viper.SetDefault("database.host", "localhost") viper.SetDefault("database.port", 5432) @@ -1572,6 +1642,87 @@ func (c *Config) Validate() error { warnIfInsecureURL("linuxdo_connect.redirect_url", c.LinuxDo.RedirectURL) warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL) } + if c.OIDC.Enabled { + if strings.TrimSpace(c.OIDC.ClientID) == "" { + return fmt.Errorf("oidc_connect.client_id is required when oidc_connect.enabled=true") + } + if strings.TrimSpace(c.OIDC.IssuerURL) == "" { + return fmt.Errorf("oidc_connect.issuer_url is required when oidc_connect.enabled=true") + } + if strings.TrimSpace(c.OIDC.RedirectURL) == "" { + return fmt.Errorf("oidc_connect.redirect_url is required when oidc_connect.enabled=true") + } + if strings.TrimSpace(c.OIDC.FrontendRedirectURL) == "" { + return fmt.Errorf("oidc_connect.frontend_redirect_url is required when oidc_connect.enabled=true") + } + if !scopeContainsOpenID(c.OIDC.Scopes) { + return fmt.Errorf("oidc_connect.scopes must contain openid") + } + + method := strings.ToLower(strings.TrimSpace(c.OIDC.TokenAuthMethod)) + switch method { + case "", "client_secret_post", "client_secret_basic", "none": + default: + return fmt.Errorf("oidc_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none") + } + if method == "none" && !c.OIDC.UsePKCE { + return fmt.Errorf("oidc_connect.use_pkce must be true when oidc_connect.token_auth_method=none") + } + if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && + strings.TrimSpace(c.OIDC.ClientSecret) == "" { + return fmt.Errorf("oidc_connect.client_secret is required when oidc_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic") + } + if c.OIDC.ClockSkewSeconds < 0 || c.OIDC.ClockSkewSeconds > 600 { + return fmt.Errorf("oidc_connect.clock_skew_seconds must be between 0 and 600") + } + if c.OIDC.ValidateIDToken && strings.TrimSpace(c.OIDC.AllowedSigningAlgs) == "" { + return fmt.Errorf("oidc_connect.allowed_signing_algs is required when oidc_connect.validate_id_token=true") + } + + if err := ValidateAbsoluteHTTPURL(c.OIDC.IssuerURL); err != nil { + return fmt.Errorf("oidc_connect.issuer_url invalid: %w", err) + } + if v := strings.TrimSpace(c.OIDC.DiscoveryURL); v != "" { + if err := ValidateAbsoluteHTTPURL(v); err != nil { + return fmt.Errorf("oidc_connect.discovery_url invalid: %w", err) + } + } + if v := strings.TrimSpace(c.OIDC.AuthorizeURL); v != "" { + if err := ValidateAbsoluteHTTPURL(v); err != nil { + return fmt.Errorf("oidc_connect.authorize_url invalid: %w", err) + } + } + if v := strings.TrimSpace(c.OIDC.TokenURL); v != "" { + if err := ValidateAbsoluteHTTPURL(v); err != nil { + return fmt.Errorf("oidc_connect.token_url invalid: %w", err) + } + } + if v := strings.TrimSpace(c.OIDC.UserInfoURL); v != "" { + if err := ValidateAbsoluteHTTPURL(v); err != nil { + return fmt.Errorf("oidc_connect.userinfo_url invalid: %w", err) + } + } + if v := strings.TrimSpace(c.OIDC.JWKSURL); v != "" { + if err := ValidateAbsoluteHTTPURL(v); err != nil { + return fmt.Errorf("oidc_connect.jwks_url invalid: %w", err) + } + } + if err := ValidateAbsoluteHTTPURL(c.OIDC.RedirectURL); err != nil { + return fmt.Errorf("oidc_connect.redirect_url invalid: %w", err) + } + if err := ValidateFrontendRedirectURL(c.OIDC.FrontendRedirectURL); err != nil { + return fmt.Errorf("oidc_connect.frontend_redirect_url invalid: %w", err) + } + + warnIfInsecureURL("oidc_connect.issuer_url", c.OIDC.IssuerURL) + warnIfInsecureURL("oidc_connect.discovery_url", c.OIDC.DiscoveryURL) + warnIfInsecureURL("oidc_connect.authorize_url", c.OIDC.AuthorizeURL) + warnIfInsecureURL("oidc_connect.token_url", c.OIDC.TokenURL) + warnIfInsecureURL("oidc_connect.userinfo_url", c.OIDC.UserInfoURL) + warnIfInsecureURL("oidc_connect.jwks_url", c.OIDC.JWKSURL) + warnIfInsecureURL("oidc_connect.redirect_url", c.OIDC.RedirectURL) + warnIfInsecureURL("oidc_connect.frontend_redirect_url", c.OIDC.FrontendRedirectURL) + } if c.Billing.CircuitBreaker.Enabled { if c.Billing.CircuitBreaker.FailureThreshold <= 0 { return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive") @@ -2184,6 +2335,15 @@ func ValidateFrontendRedirectURL(raw string) error { return nil } +func scopeContainsOpenID(scopes string) bool { + for _, scope := range strings.Fields(strings.ToLower(strings.TrimSpace(scopes))) { + if scope == "openid" { + return true + } + } + return false +} + // isHTTPScheme 检查是否为 HTTP 或 HTTPS 协议 func isHTTPScheme(scheme string) bool { return strings.EqualFold(scheme, "http") || strings.EqualFold(scheme, "https") diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index 2de5451e..b9660b78 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -351,6 +351,60 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) { } } +func TestValidateOIDCScopesMustContainOpenID(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.OIDC.Enabled = true + cfg.OIDC.ClientID = "oidc-client" + cfg.OIDC.ClientSecret = "oidc-secret" + cfg.OIDC.IssuerURL = "https://issuer.example.com" + cfg.OIDC.AuthorizeURL = "https://issuer.example.com/auth" + cfg.OIDC.TokenURL = "https://issuer.example.com/token" + cfg.OIDC.JWKSURL = "https://issuer.example.com/jwks" + cfg.OIDC.RedirectURL = "https://example.com/api/v1/auth/oauth/oidc/callback" + cfg.OIDC.FrontendRedirectURL = "/auth/oidc/callback" + cfg.OIDC.Scopes = "profile email" + + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() expected error when scopes do not include openid, got nil") + } + if !strings.Contains(err.Error(), "oidc_connect.scopes") { + t.Fatalf("Validate() expected oidc_connect.scopes error, got: %v", err) + } +} + +func TestValidateOIDCAllowsIssuerOnlyEndpointsWithDiscoveryFallback(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.OIDC.Enabled = true + cfg.OIDC.ClientID = "oidc-client" + cfg.OIDC.ClientSecret = "oidc-secret" + cfg.OIDC.IssuerURL = "https://issuer.example.com" + cfg.OIDC.AuthorizeURL = "" + cfg.OIDC.TokenURL = "" + cfg.OIDC.JWKSURL = "" + cfg.OIDC.RedirectURL = "https://example.com/api/v1/auth/oauth/oidc/callback" + cfg.OIDC.FrontendRedirectURL = "/auth/oidc/callback" + cfg.OIDC.Scopes = "openid email profile" + cfg.OIDC.ValidateIDToken = true + + err = cfg.Validate() + if err != nil { + t.Fatalf("Validate() expected issuer-only OIDC config to pass with discovery fallback, got: %v", err) + } +} + func TestLoadDefaultDashboardCacheConfig(t *testing.T) { resetViperWithJWTSecret(t) diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 4cbe5188..abae75d9 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -35,6 +35,15 @@ func generateMenuItemID() (string, error) { return hex.EncodeToString(b), nil } +func scopesContainOpenID(scopes string) bool { + for _, scope := range strings.Fields(strings.ToLower(strings.TrimSpace(scopes))) { + if scope == "openid" { + return true + } + } + return false +} + // SettingHandler 系统设置处理器 type SettingHandler struct { settingService *service.SettingService @@ -96,6 +105,28 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { LinuxDoConnectClientID: settings.LinuxDoConnectClientID, LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured, LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL, + OIDCConnectEnabled: settings.OIDCConnectEnabled, + OIDCConnectProviderName: settings.OIDCConnectProviderName, + OIDCConnectClientID: settings.OIDCConnectClientID, + OIDCConnectClientSecretConfigured: settings.OIDCConnectClientSecretConfigured, + OIDCConnectIssuerURL: settings.OIDCConnectIssuerURL, + OIDCConnectDiscoveryURL: settings.OIDCConnectDiscoveryURL, + OIDCConnectAuthorizeURL: settings.OIDCConnectAuthorizeURL, + OIDCConnectTokenURL: settings.OIDCConnectTokenURL, + OIDCConnectUserInfoURL: settings.OIDCConnectUserInfoURL, + OIDCConnectJWKSURL: settings.OIDCConnectJWKSURL, + OIDCConnectScopes: settings.OIDCConnectScopes, + OIDCConnectRedirectURL: settings.OIDCConnectRedirectURL, + OIDCConnectFrontendRedirectURL: settings.OIDCConnectFrontendRedirectURL, + OIDCConnectTokenAuthMethod: settings.OIDCConnectTokenAuthMethod, + OIDCConnectUsePKCE: settings.OIDCConnectUsePKCE, + OIDCConnectValidateIDToken: settings.OIDCConnectValidateIDToken, + OIDCConnectAllowedSigningAlgs: settings.OIDCConnectAllowedSigningAlgs, + OIDCConnectClockSkewSeconds: settings.OIDCConnectClockSkewSeconds, + OIDCConnectRequireEmailVerified: settings.OIDCConnectRequireEmailVerified, + OIDCConnectUserInfoEmailPath: settings.OIDCConnectUserInfoEmailPath, + OIDCConnectUserInfoIDPath: settings.OIDCConnectUserInfoIDPath, + OIDCConnectUserInfoUsernamePath: settings.OIDCConnectUserInfoUsernamePath, SiteName: settings.SiteName, SiteLogo: settings.SiteLogo, SiteSubtitle: settings.SiteSubtitle, @@ -164,6 +195,30 @@ type UpdateSettingsRequest struct { LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"` LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` + // Generic OIDC OAuth 登录 + OIDCConnectEnabled bool `json:"oidc_connect_enabled"` + OIDCConnectProviderName string `json:"oidc_connect_provider_name"` + OIDCConnectClientID string `json:"oidc_connect_client_id"` + OIDCConnectClientSecret string `json:"oidc_connect_client_secret"` + OIDCConnectIssuerURL string `json:"oidc_connect_issuer_url"` + OIDCConnectDiscoveryURL string `json:"oidc_connect_discovery_url"` + OIDCConnectAuthorizeURL string `json:"oidc_connect_authorize_url"` + OIDCConnectTokenURL string `json:"oidc_connect_token_url"` + OIDCConnectUserInfoURL string `json:"oidc_connect_userinfo_url"` + OIDCConnectJWKSURL string `json:"oidc_connect_jwks_url"` + OIDCConnectScopes string `json:"oidc_connect_scopes"` + OIDCConnectRedirectURL string `json:"oidc_connect_redirect_url"` + OIDCConnectFrontendRedirectURL string `json:"oidc_connect_frontend_redirect_url"` + OIDCConnectTokenAuthMethod string `json:"oidc_connect_token_auth_method"` + OIDCConnectUsePKCE bool `json:"oidc_connect_use_pkce"` + OIDCConnectValidateIDToken bool `json:"oidc_connect_validate_id_token"` + OIDCConnectAllowedSigningAlgs string `json:"oidc_connect_allowed_signing_algs"` + OIDCConnectClockSkewSeconds int `json:"oidc_connect_clock_skew_seconds"` + OIDCConnectRequireEmailVerified bool `json:"oidc_connect_require_email_verified"` + OIDCConnectUserInfoEmailPath string `json:"oidc_connect_userinfo_email_path"` + OIDCConnectUserInfoIDPath string `json:"oidc_connect_userinfo_id_path"` + OIDCConnectUserInfoUsernamePath string `json:"oidc_connect_userinfo_username_path"` + // OEM设置 SiteName string `json:"site_name"` SiteLogo string `json:"site_logo"` @@ -324,6 +379,122 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } } + // Generic OIDC 参数验证 + if req.OIDCConnectEnabled { + req.OIDCConnectProviderName = strings.TrimSpace(req.OIDCConnectProviderName) + req.OIDCConnectClientID = strings.TrimSpace(req.OIDCConnectClientID) + req.OIDCConnectClientSecret = strings.TrimSpace(req.OIDCConnectClientSecret) + req.OIDCConnectIssuerURL = strings.TrimSpace(req.OIDCConnectIssuerURL) + req.OIDCConnectDiscoveryURL = strings.TrimSpace(req.OIDCConnectDiscoveryURL) + req.OIDCConnectAuthorizeURL = strings.TrimSpace(req.OIDCConnectAuthorizeURL) + req.OIDCConnectTokenURL = strings.TrimSpace(req.OIDCConnectTokenURL) + req.OIDCConnectUserInfoURL = strings.TrimSpace(req.OIDCConnectUserInfoURL) + req.OIDCConnectJWKSURL = strings.TrimSpace(req.OIDCConnectJWKSURL) + req.OIDCConnectScopes = strings.TrimSpace(req.OIDCConnectScopes) + req.OIDCConnectRedirectURL = strings.TrimSpace(req.OIDCConnectRedirectURL) + req.OIDCConnectFrontendRedirectURL = strings.TrimSpace(req.OIDCConnectFrontendRedirectURL) + req.OIDCConnectTokenAuthMethod = strings.ToLower(strings.TrimSpace(req.OIDCConnectTokenAuthMethod)) + req.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(req.OIDCConnectAllowedSigningAlgs) + req.OIDCConnectUserInfoEmailPath = strings.TrimSpace(req.OIDCConnectUserInfoEmailPath) + req.OIDCConnectUserInfoIDPath = strings.TrimSpace(req.OIDCConnectUserInfoIDPath) + req.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(req.OIDCConnectUserInfoUsernamePath) + + if req.OIDCConnectProviderName == "" { + req.OIDCConnectProviderName = "OIDC" + } + if req.OIDCConnectClientID == "" { + response.BadRequest(c, "OIDC Client ID is required when enabled") + return + } + if req.OIDCConnectIssuerURL == "" { + response.BadRequest(c, "OIDC Issuer URL is required when enabled") + return + } + if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectIssuerURL); err != nil { + response.BadRequest(c, "OIDC Issuer URL must be an absolute http(s) URL") + return + } + if req.OIDCConnectDiscoveryURL != "" { + if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectDiscoveryURL); err != nil { + response.BadRequest(c, "OIDC Discovery URL must be an absolute http(s) URL") + return + } + } + if req.OIDCConnectAuthorizeURL != "" { + if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectAuthorizeURL); err != nil { + response.BadRequest(c, "OIDC Authorize URL must be an absolute http(s) URL") + return + } + } + if req.OIDCConnectTokenURL != "" { + if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectTokenURL); err != nil { + response.BadRequest(c, "OIDC Token URL must be an absolute http(s) URL") + return + } + } + if req.OIDCConnectUserInfoURL != "" { + if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectUserInfoURL); err != nil { + response.BadRequest(c, "OIDC UserInfo URL must be an absolute http(s) URL") + return + } + } + if req.OIDCConnectRedirectURL == "" { + response.BadRequest(c, "OIDC Redirect URL is required when enabled") + return + } + if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectRedirectURL); err != nil { + response.BadRequest(c, "OIDC Redirect URL must be an absolute http(s) URL") + return + } + if req.OIDCConnectFrontendRedirectURL == "" { + response.BadRequest(c, "OIDC Frontend Redirect URL is required when enabled") + return + } + if err := config.ValidateFrontendRedirectURL(req.OIDCConnectFrontendRedirectURL); err != nil { + response.BadRequest(c, "OIDC Frontend Redirect URL is invalid") + return + } + if !scopesContainOpenID(req.OIDCConnectScopes) { + response.BadRequest(c, "OIDC scopes must contain openid") + return + } + switch req.OIDCConnectTokenAuthMethod { + case "", "client_secret_post", "client_secret_basic", "none": + default: + response.BadRequest(c, "OIDC Token Auth Method must be one of client_secret_post/client_secret_basic/none") + return + } + if req.OIDCConnectTokenAuthMethod == "none" && !req.OIDCConnectUsePKCE { + response.BadRequest(c, "OIDC PKCE must be enabled when token_auth_method=none") + return + } + if req.OIDCConnectClockSkewSeconds < 0 || req.OIDCConnectClockSkewSeconds > 600 { + response.BadRequest(c, "OIDC clock skew seconds must be between 0 and 600") + return + } + if req.OIDCConnectValidateIDToken { + if req.OIDCConnectAllowedSigningAlgs == "" { + response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true") + return + } + } + if req.OIDCConnectJWKSURL != "" { + if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectJWKSURL); err != nil { + response.BadRequest(c, "OIDC JWKS URL must be an absolute http(s) URL") + return + } + } + if req.OIDCConnectTokenAuthMethod == "" || req.OIDCConnectTokenAuthMethod == "client_secret_post" || req.OIDCConnectTokenAuthMethod == "client_secret_basic" { + if req.OIDCConnectClientSecret == "" { + if previousSettings.OIDCConnectClientSecret == "" { + response.BadRequest(c, "OIDC Client Secret is required when enabled") + return + } + req.OIDCConnectClientSecret = previousSettings.OIDCConnectClientSecret + } + } + } + // “购买订阅”页面配置验证 purchaseEnabled := previousSettings.PurchaseSubscriptionEnabled if req.PurchaseSubscriptionEnabled != nil { @@ -554,6 +725,28 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { LinuxDoConnectClientID: req.LinuxDoConnectClientID, LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret, LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL, + OIDCConnectEnabled: req.OIDCConnectEnabled, + OIDCConnectProviderName: req.OIDCConnectProviderName, + OIDCConnectClientID: req.OIDCConnectClientID, + OIDCConnectClientSecret: req.OIDCConnectClientSecret, + OIDCConnectIssuerURL: req.OIDCConnectIssuerURL, + OIDCConnectDiscoveryURL: req.OIDCConnectDiscoveryURL, + OIDCConnectAuthorizeURL: req.OIDCConnectAuthorizeURL, + OIDCConnectTokenURL: req.OIDCConnectTokenURL, + OIDCConnectUserInfoURL: req.OIDCConnectUserInfoURL, + OIDCConnectJWKSURL: req.OIDCConnectJWKSURL, + OIDCConnectScopes: req.OIDCConnectScopes, + OIDCConnectRedirectURL: req.OIDCConnectRedirectURL, + OIDCConnectFrontendRedirectURL: req.OIDCConnectFrontendRedirectURL, + OIDCConnectTokenAuthMethod: req.OIDCConnectTokenAuthMethod, + OIDCConnectUsePKCE: req.OIDCConnectUsePKCE, + OIDCConnectValidateIDToken: req.OIDCConnectValidateIDToken, + OIDCConnectAllowedSigningAlgs: req.OIDCConnectAllowedSigningAlgs, + OIDCConnectClockSkewSeconds: req.OIDCConnectClockSkewSeconds, + OIDCConnectRequireEmailVerified: req.OIDCConnectRequireEmailVerified, + OIDCConnectUserInfoEmailPath: req.OIDCConnectUserInfoEmailPath, + OIDCConnectUserInfoIDPath: req.OIDCConnectUserInfoIDPath, + OIDCConnectUserInfoUsernamePath: req.OIDCConnectUserInfoUsernamePath, SiteName: req.SiteName, SiteLogo: req.SiteLogo, SiteSubtitle: req.SiteSubtitle, @@ -669,6 +862,28 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID, LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured, LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL, + OIDCConnectEnabled: updatedSettings.OIDCConnectEnabled, + OIDCConnectProviderName: updatedSettings.OIDCConnectProviderName, + OIDCConnectClientID: updatedSettings.OIDCConnectClientID, + OIDCConnectClientSecretConfigured: updatedSettings.OIDCConnectClientSecretConfigured, + OIDCConnectIssuerURL: updatedSettings.OIDCConnectIssuerURL, + OIDCConnectDiscoveryURL: updatedSettings.OIDCConnectDiscoveryURL, + OIDCConnectAuthorizeURL: updatedSettings.OIDCConnectAuthorizeURL, + OIDCConnectTokenURL: updatedSettings.OIDCConnectTokenURL, + OIDCConnectUserInfoURL: updatedSettings.OIDCConnectUserInfoURL, + OIDCConnectJWKSURL: updatedSettings.OIDCConnectJWKSURL, + OIDCConnectScopes: updatedSettings.OIDCConnectScopes, + OIDCConnectRedirectURL: updatedSettings.OIDCConnectRedirectURL, + OIDCConnectFrontendRedirectURL: updatedSettings.OIDCConnectFrontendRedirectURL, + OIDCConnectTokenAuthMethod: updatedSettings.OIDCConnectTokenAuthMethod, + OIDCConnectUsePKCE: updatedSettings.OIDCConnectUsePKCE, + OIDCConnectValidateIDToken: updatedSettings.OIDCConnectValidateIDToken, + OIDCConnectAllowedSigningAlgs: updatedSettings.OIDCConnectAllowedSigningAlgs, + OIDCConnectClockSkewSeconds: updatedSettings.OIDCConnectClockSkewSeconds, + OIDCConnectRequireEmailVerified: updatedSettings.OIDCConnectRequireEmailVerified, + OIDCConnectUserInfoEmailPath: updatedSettings.OIDCConnectUserInfoEmailPath, + OIDCConnectUserInfoIDPath: updatedSettings.OIDCConnectUserInfoIDPath, + OIDCConnectUserInfoUsernamePath: updatedSettings.OIDCConnectUserInfoUsernamePath, SiteName: updatedSettings.SiteName, SiteLogo: updatedSettings.SiteLogo, SiteSubtitle: updatedSettings.SiteSubtitle, @@ -787,6 +1002,72 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.LinuxDoConnectRedirectURL != after.LinuxDoConnectRedirectURL { changed = append(changed, "linuxdo_connect_redirect_url") } + if before.OIDCConnectEnabled != after.OIDCConnectEnabled { + changed = append(changed, "oidc_connect_enabled") + } + if before.OIDCConnectProviderName != after.OIDCConnectProviderName { + changed = append(changed, "oidc_connect_provider_name") + } + if before.OIDCConnectClientID != after.OIDCConnectClientID { + changed = append(changed, "oidc_connect_client_id") + } + if req.OIDCConnectClientSecret != "" { + changed = append(changed, "oidc_connect_client_secret") + } + if before.OIDCConnectIssuerURL != after.OIDCConnectIssuerURL { + changed = append(changed, "oidc_connect_issuer_url") + } + if before.OIDCConnectDiscoveryURL != after.OIDCConnectDiscoveryURL { + changed = append(changed, "oidc_connect_discovery_url") + } + if before.OIDCConnectAuthorizeURL != after.OIDCConnectAuthorizeURL { + changed = append(changed, "oidc_connect_authorize_url") + } + if before.OIDCConnectTokenURL != after.OIDCConnectTokenURL { + changed = append(changed, "oidc_connect_token_url") + } + if before.OIDCConnectUserInfoURL != after.OIDCConnectUserInfoURL { + changed = append(changed, "oidc_connect_userinfo_url") + } + if before.OIDCConnectJWKSURL != after.OIDCConnectJWKSURL { + changed = append(changed, "oidc_connect_jwks_url") + } + if before.OIDCConnectScopes != after.OIDCConnectScopes { + changed = append(changed, "oidc_connect_scopes") + } + if before.OIDCConnectRedirectURL != after.OIDCConnectRedirectURL { + changed = append(changed, "oidc_connect_redirect_url") + } + if before.OIDCConnectFrontendRedirectURL != after.OIDCConnectFrontendRedirectURL { + changed = append(changed, "oidc_connect_frontend_redirect_url") + } + if before.OIDCConnectTokenAuthMethod != after.OIDCConnectTokenAuthMethod { + changed = append(changed, "oidc_connect_token_auth_method") + } + if before.OIDCConnectUsePKCE != after.OIDCConnectUsePKCE { + changed = append(changed, "oidc_connect_use_pkce") + } + if before.OIDCConnectValidateIDToken != after.OIDCConnectValidateIDToken { + changed = append(changed, "oidc_connect_validate_id_token") + } + if before.OIDCConnectAllowedSigningAlgs != after.OIDCConnectAllowedSigningAlgs { + changed = append(changed, "oidc_connect_allowed_signing_algs") + } + if before.OIDCConnectClockSkewSeconds != after.OIDCConnectClockSkewSeconds { + changed = append(changed, "oidc_connect_clock_skew_seconds") + } + if before.OIDCConnectRequireEmailVerified != after.OIDCConnectRequireEmailVerified { + changed = append(changed, "oidc_connect_require_email_verified") + } + if before.OIDCConnectUserInfoEmailPath != after.OIDCConnectUserInfoEmailPath { + changed = append(changed, "oidc_connect_userinfo_email_path") + } + if before.OIDCConnectUserInfoIDPath != after.OIDCConnectUserInfoIDPath { + changed = append(changed, "oidc_connect_userinfo_id_path") + } + if before.OIDCConnectUserInfoUsernamePath != after.OIDCConnectUserInfoUsernamePath { + changed = append(changed, "oidc_connect_userinfo_username_path") + } if before.SiteName != after.SiteName { changed = append(changed, "site_name") } diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go new file mode 100644 index 00000000..f46fb850 --- /dev/null +++ b/backend/internal/handler/auth_oidc_oauth.go @@ -0,0 +1,865 @@ +package handler + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "log" + "math/big" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" + "github.com/imroc/req/v3" + "github.com/tidwall/gjson" +) + +const ( + oidcOAuthCookiePath = "/api/v1/auth/oauth/oidc" + oidcOAuthStateCookieName = "oidc_oauth_state" + oidcOAuthVerifierCookie = "oidc_oauth_verifier" + oidcOAuthRedirectCookie = "oidc_oauth_redirect" + oidcOAuthNonceCookie = "oidc_oauth_nonce" + oidcOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes + oidcOAuthDefaultRedirectTo = "/dashboard" + oidcOAuthDefaultFrontendCB = "/auth/oidc/callback" +) + +type oidcTokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int64 `json:"expires_in"` + RefreshToken string `json:"refresh_token,omitempty"` + Scope string `json:"scope,omitempty"` + IDToken string `json:"id_token,omitempty"` +} + +type oidcTokenExchangeError struct { + StatusCode int + ProviderError string + ProviderDescription string + Body string +} + +func (e *oidcTokenExchangeError) Error() string { + if e == nil { + return "" + } + parts := []string{fmt.Sprintf("token exchange status=%d", e.StatusCode)} + if strings.TrimSpace(e.ProviderError) != "" { + parts = append(parts, "error="+strings.TrimSpace(e.ProviderError)) + } + if strings.TrimSpace(e.ProviderDescription) != "" { + parts = append(parts, "error_description="+strings.TrimSpace(e.ProviderDescription)) + } + return strings.Join(parts, " ") +} + +type oidcIDTokenClaims struct { + Email string `json:"email,omitempty"` + EmailVerified *bool `json:"email_verified,omitempty"` + PreferredUsername string `json:"preferred_username,omitempty"` + Name string `json:"name,omitempty"` + Nonce string `json:"nonce,omitempty"` + Azp string `json:"azp,omitempty"` + jwt.RegisteredClaims +} + +type oidcUserInfoClaims struct { + Email string + Username string + Subject string + EmailVerified *bool +} + +type oidcJWKSet struct { + Keys []oidcJWK `json:"keys"` +} + +type oidcJWK struct { + Kty string `json:"kty"` + Kid string `json:"kid"` + Use string `json:"use"` + Alg string `json:"alg"` + + N string `json:"n"` + E string `json:"e"` + + Crv string `json:"crv"` + X string `json:"x"` + Y string `json:"y"` +} + +// OIDCOAuthStart 启动通用 OIDC OAuth 登录流程。 +// GET /api/v1/auth/oauth/oidc/start?redirect=/dashboard +func (h *AuthHandler) OIDCOAuthStart(c *gin.Context) { + cfg, err := h.getOIDCOAuthConfig(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + + state, err := oauth.GenerateState() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err)) + return + } + + redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect")) + if redirectTo == "" { + redirectTo = oidcOAuthDefaultRedirectTo + } + + secureCookie := isRequestHTTPS(c) + oidcSetCookie(c, oidcOAuthStateCookieName, encodeCookieValue(state), oidcOAuthCookieMaxAgeSec, secureCookie) + oidcSetCookie(c, oidcOAuthRedirectCookie, encodeCookieValue(redirectTo), oidcOAuthCookieMaxAgeSec, secureCookie) + + codeChallenge := "" + if cfg.UsePKCE { + verifier, genErr := oauth.GenerateCodeVerifier() + if genErr != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(genErr)) + return + } + codeChallenge = oauth.GenerateCodeChallenge(verifier) + oidcSetCookie(c, oidcOAuthVerifierCookie, encodeCookieValue(verifier), oidcOAuthCookieMaxAgeSec, secureCookie) + } + + nonce := "" + if cfg.ValidateIDToken { + nonce, err = oauth.GenerateState() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_NONCE_GEN_FAILED", "failed to generate oauth nonce").WithCause(err)) + return + } + oidcSetCookie(c, oidcOAuthNonceCookie, encodeCookieValue(nonce), oidcOAuthCookieMaxAgeSec, secureCookie) + } + + redirectURI := strings.TrimSpace(cfg.RedirectURL) + if redirectURI == "" { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured")) + return + } + + authURL, err := buildOIDCAuthorizeURL(cfg, state, nonce, codeChallenge, redirectURI) + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err)) + return + } + + c.Redirect(http.StatusFound, authURL) +} + +// OIDCOAuthCallback 处理 OIDC 回调:校验 id_token、创建/登录用户并重定向到前端。 +// GET /api/v1/auth/oauth/oidc/callback?code=...&state=... +func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { + cfg, cfgErr := h.getOIDCOAuthConfig(c.Request.Context()) + if cfgErr != nil { + response.ErrorFrom(c, cfgErr) + return + } + + frontendCallback := strings.TrimSpace(cfg.FrontendRedirectURL) + if frontendCallback == "" { + frontendCallback = oidcOAuthDefaultFrontendCB + } + + if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" { + redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description")) + return + } + + code := strings.TrimSpace(c.Query("code")) + state := strings.TrimSpace(c.Query("state")) + if code == "" || state == "" { + redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "") + return + } + + secureCookie := isRequestHTTPS(c) + defer func() { + oidcClearCookie(c, oidcOAuthStateCookieName, secureCookie) + oidcClearCookie(c, oidcOAuthVerifierCookie, secureCookie) + oidcClearCookie(c, oidcOAuthRedirectCookie, secureCookie) + oidcClearCookie(c, oidcOAuthNonceCookie, secureCookie) + }() + + expectedState, err := readCookieDecoded(c, oidcOAuthStateCookieName) + if err != nil || expectedState == "" || state != expectedState { + redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "") + return + } + + redirectTo, _ := readCookieDecoded(c, oidcOAuthRedirectCookie) + redirectTo = sanitizeFrontendRedirectPath(redirectTo) + if redirectTo == "" { + redirectTo = oidcOAuthDefaultRedirectTo + } + + codeVerifier := "" + if cfg.UsePKCE { + codeVerifier, _ = readCookieDecoded(c, oidcOAuthVerifierCookie) + if codeVerifier == "" { + redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "") + return + } + } + + expectedNonce := "" + if cfg.ValidateIDToken { + expectedNonce, _ = readCookieDecoded(c, oidcOAuthNonceCookie) + if expectedNonce == "" { + redirectOAuthError(c, frontendCallback, "missing_nonce", "missing oauth nonce", "") + return + } + } + + redirectURI := strings.TrimSpace(cfg.RedirectURL) + if redirectURI == "" { + redirectOAuthError(c, frontendCallback, "config_error", "oauth redirect url not configured", "") + return + } + + tokenResp, err := oidcExchangeCode(c.Request.Context(), cfg, code, redirectURI, codeVerifier) + if err != nil { + description := "" + var exchangeErr *oidcTokenExchangeError + if errors.As(err, &exchangeErr) && exchangeErr != nil { + log.Printf( + "[OIDC OAuth] token exchange failed: status=%d provider_error=%q provider_description=%q body=%s", + exchangeErr.StatusCode, + exchangeErr.ProviderError, + exchangeErr.ProviderDescription, + truncateLogValue(exchangeErr.Body, 2048), + ) + description = exchangeErr.Error() + } else { + log.Printf("[OIDC OAuth] token exchange failed: %v", err) + description = err.Error() + } + redirectOAuthError(c, frontendCallback, "token_exchange_failed", "failed to exchange oauth code", singleLine(description)) + return + } + + if cfg.ValidateIDToken && strings.TrimSpace(tokenResp.IDToken) == "" { + redirectOAuthError(c, frontendCallback, "missing_id_token", "missing id_token", "") + return + } + + idClaims, err := oidcParseAndValidateIDToken(c.Request.Context(), cfg, tokenResp.IDToken, expectedNonce) + if err != nil { + log.Printf("[OIDC OAuth] id_token validation failed: %v", err) + redirectOAuthError(c, frontendCallback, "invalid_id_token", "failed to validate id_token", "") + return + } + + userInfoClaims, err := oidcFetchUserInfo(c.Request.Context(), cfg, tokenResp) + if err != nil { + log.Printf("[OIDC OAuth] userinfo fetch failed: %v", err) + redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "") + return + } + + subject := strings.TrimSpace(idClaims.Subject) + if subject == "" { + subject = strings.TrimSpace(userInfoClaims.Subject) + } + if subject == "" { + redirectOAuthError(c, frontendCallback, "missing_subject", "missing subject claim", "") + return + } + issuer := strings.TrimSpace(idClaims.Issuer) + if issuer == "" { + issuer = strings.TrimSpace(cfg.IssuerURL) + } + if issuer == "" { + redirectOAuthError(c, frontendCallback, "missing_issuer", "missing issuer claim", "") + return + } + + emailVerified := userInfoClaims.EmailVerified + if emailVerified == nil { + emailVerified = idClaims.EmailVerified + } + if cfg.RequireEmailVerified { + if emailVerified == nil || !*emailVerified { + redirectOAuthError(c, frontendCallback, "email_not_verified", "email is not verified", "") + return + } + } + + identityKey := oidcIdentityKey(issuer, subject) + email := oidcSyntheticEmailFromIdentityKey(identityKey) + username := firstNonEmpty( + userInfoClaims.Username, + idClaims.PreferredUsername, + idClaims.Name, + oidcFallbackUsername(subject), + ) + + // 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired + tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") + if err != nil { + if errors.Is(err, service.ErrOAuthInvitationRequired) { + pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username) + if tokenErr != nil { + redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "") + return + } + fragment := url.Values{} + fragment.Set("error", "invitation_required") + fragment.Set("pending_oauth_token", pendingToken) + fragment.Set("redirect", redirectTo) + redirectWithFragment(c, frontendCallback, fragment) + return + } + redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + + fragment := url.Values{} + fragment.Set("access_token", tokenPair.AccessToken) + fragment.Set("refresh_token", tokenPair.RefreshToken) + fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn)) + fragment.Set("token_type", "Bearer") + fragment.Set("redirect", redirectTo) + redirectWithFragment(c, frontendCallback, fragment) +} + +type completeOIDCOAuthRequest struct { + PendingOAuthToken string `json:"pending_oauth_token" binding:"required"` + InvitationCode string `json:"invitation_code" binding:"required"` +} + +// CompleteOIDCOAuthRegistration completes a pending OAuth registration by validating +// the invitation code and creating the user account. +// POST /api/v1/auth/oauth/oidc/complete-registration +func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) { + var req completeOIDCOAuthRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()}) + return + } + + email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"}) + return + } + + tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) + if err != nil { + response.ErrorFrom(c, err) + return + } + + c.JSON(http.StatusOK, gin.H{ + "access_token": tokenPair.AccessToken, + "refresh_token": tokenPair.RefreshToken, + "expires_in": tokenPair.ExpiresIn, + "token_type": "Bearer", + }) +} + +func (h *AuthHandler) getOIDCOAuthConfig(ctx context.Context) (config.OIDCConnectConfig, error) { + if h != nil && h.settingSvc != nil { + return h.settingSvc.GetOIDCConnectOAuthConfig(ctx) + } + if h == nil || h.cfg == nil { + return config.OIDCConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded") + } + if !h.cfg.OIDC.Enabled { + return config.OIDCConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled") + } + return h.cfg.OIDC, nil +} + +func oidcExchangeCode( + ctx context.Context, + cfg config.OIDCConnectConfig, + code string, + redirectURI string, + codeVerifier string, +) (*oidcTokenResponse, error) { + client := req.C().SetTimeout(30 * time.Second) + + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("client_id", cfg.ClientID) + form.Set("code", code) + form.Set("redirect_uri", redirectURI) + if cfg.UsePKCE { + form.Set("code_verifier", codeVerifier) + } + + r := client.R(). + SetContext(ctx). + SetHeader("Accept", "application/json") + + switch strings.ToLower(strings.TrimSpace(cfg.TokenAuthMethod)) { + case "", "client_secret_post": + form.Set("client_secret", cfg.ClientSecret) + case "client_secret_basic": + r.SetBasicAuth(cfg.ClientID, cfg.ClientSecret) + case "none": + default: + return nil, fmt.Errorf("unsupported token_auth_method: %s", cfg.TokenAuthMethod) + } + + resp, err := r.SetFormDataFromValues(form).Post(cfg.TokenURL) + if err != nil { + return nil, fmt.Errorf("request token: %w", err) + } + body := strings.TrimSpace(resp.String()) + if !resp.IsSuccessState() { + providerErr, providerDesc := parseOAuthProviderError(body) + return nil, &oidcTokenExchangeError{ + StatusCode: resp.StatusCode, + ProviderError: providerErr, + ProviderDescription: providerDesc, + Body: body, + } + } + + tokenResp, ok := oidcParseTokenResponse(body) + if !ok { + return nil, &oidcTokenExchangeError{StatusCode: resp.StatusCode, Body: body} + } + if strings.TrimSpace(tokenResp.TokenType) == "" { + tokenResp.TokenType = "Bearer" + } + if strings.TrimSpace(tokenResp.AccessToken) == "" && strings.TrimSpace(tokenResp.IDToken) == "" { + return nil, &oidcTokenExchangeError{StatusCode: resp.StatusCode, Body: body} + } + return tokenResp, nil +} + +func oidcParseTokenResponse(body string) (*oidcTokenResponse, bool) { + body = strings.TrimSpace(body) + if body == "" { + return nil, false + } + + accessToken := strings.TrimSpace(getGJSON(body, "access_token")) + idToken := strings.TrimSpace(getGJSON(body, "id_token")) + if accessToken != "" || idToken != "" { + tokenType := strings.TrimSpace(getGJSON(body, "token_type")) + refreshToken := strings.TrimSpace(getGJSON(body, "refresh_token")) + scope := strings.TrimSpace(getGJSON(body, "scope")) + expiresIn := gjson.Get(body, "expires_in").Int() + return &oidcTokenResponse{ + AccessToken: accessToken, + TokenType: tokenType, + ExpiresIn: expiresIn, + RefreshToken: refreshToken, + Scope: scope, + IDToken: idToken, + }, true + } + + values, err := url.ParseQuery(body) + if err != nil { + return nil, false + } + accessToken = strings.TrimSpace(values.Get("access_token")) + idToken = strings.TrimSpace(values.Get("id_token")) + if accessToken == "" && idToken == "" { + return nil, false + } + expiresIn := int64(0) + if raw := strings.TrimSpace(values.Get("expires_in")); raw != "" { + if v, parseErr := strconv.ParseInt(raw, 10, 64); parseErr == nil { + expiresIn = v + } + } + return &oidcTokenResponse{ + AccessToken: accessToken, + TokenType: strings.TrimSpace(values.Get("token_type")), + ExpiresIn: expiresIn, + RefreshToken: strings.TrimSpace(values.Get("refresh_token")), + Scope: strings.TrimSpace(values.Get("scope")), + IDToken: idToken, + }, true +} + +func oidcFetchUserInfo( + ctx context.Context, + cfg config.OIDCConnectConfig, + token *oidcTokenResponse, +) (*oidcUserInfoClaims, error) { + if strings.TrimSpace(cfg.UserInfoURL) == "" { + return &oidcUserInfoClaims{}, nil + } + if token == nil || strings.TrimSpace(token.AccessToken) == "" { + return nil, errors.New("missing access_token for userinfo request") + } + + client := req.C().SetTimeout(30 * time.Second) + authorization, err := buildBearerAuthorization(token.TokenType, token.AccessToken) + if err != nil { + return nil, fmt.Errorf("invalid token for userinfo request: %w", err) + } + + resp, err := client.R(). + SetContext(ctx). + SetHeader("Accept", "application/json"). + SetHeader("Authorization", authorization). + Get(cfg.UserInfoURL) + if err != nil { + return nil, fmt.Errorf("request userinfo: %w", err) + } + if !resp.IsSuccessState() { + return nil, fmt.Errorf("userinfo status=%d", resp.StatusCode) + } + + return oidcParseUserInfo(resp.String(), cfg), nil +} + +func oidcParseUserInfo(body string, cfg config.OIDCConnectConfig) *oidcUserInfoClaims { + claims := &oidcUserInfoClaims{} + claims.Email = firstNonEmpty( + getGJSON(body, cfg.UserInfoEmailPath), + getGJSON(body, "email"), + getGJSON(body, "user.email"), + getGJSON(body, "data.email"), + getGJSON(body, "attributes.email"), + ) + claims.Username = firstNonEmpty( + getGJSON(body, cfg.UserInfoUsernamePath), + getGJSON(body, "preferred_username"), + getGJSON(body, "username"), + getGJSON(body, "name"), + getGJSON(body, "user.username"), + getGJSON(body, "user.name"), + ) + claims.Subject = firstNonEmpty( + getGJSON(body, cfg.UserInfoIDPath), + getGJSON(body, "sub"), + getGJSON(body, "id"), + getGJSON(body, "user_id"), + getGJSON(body, "uid"), + getGJSON(body, "user.id"), + ) + if verified, ok := getGJSONBool(body, "email_verified"); ok { + claims.EmailVerified = &verified + } + claims.Email = strings.TrimSpace(claims.Email) + claims.Username = strings.TrimSpace(claims.Username) + claims.Subject = strings.TrimSpace(claims.Subject) + return claims +} + +func getGJSONBool(body string, path string) (bool, bool) { + path = strings.TrimSpace(path) + if path == "" { + return false, false + } + res := gjson.Get(body, path) + if !res.Exists() { + return false, false + } + return res.Bool(), true +} + +func buildOIDCAuthorizeURL(cfg config.OIDCConnectConfig, state, nonce, codeChallenge, redirectURI string) (string, error) { + u, err := url.Parse(cfg.AuthorizeURL) + if err != nil { + return "", fmt.Errorf("parse authorize_url: %w", err) + } + + q := u.Query() + q.Set("response_type", "code") + q.Set("client_id", cfg.ClientID) + q.Set("redirect_uri", redirectURI) + if strings.TrimSpace(cfg.Scopes) != "" { + q.Set("scope", cfg.Scopes) + } + q.Set("state", state) + if strings.TrimSpace(nonce) != "" { + q.Set("nonce", nonce) + } + if cfg.UsePKCE { + q.Set("code_challenge", codeChallenge) + q.Set("code_challenge_method", "S256") + } + + u.RawQuery = q.Encode() + return u.String(), nil +} + +func oidcParseAndValidateIDToken(ctx context.Context, cfg config.OIDCConnectConfig, idToken string, expectedNonce string) (*oidcIDTokenClaims, error) { + idToken = strings.TrimSpace(idToken) + if idToken == "" { + return nil, errors.New("missing id_token") + } + allowed := oidcAllowedSigningAlgs(cfg.AllowedSigningAlgs) + if len(allowed) == 0 { + return nil, errors.New("empty allowed signing algorithms") + } + + jwks, err := oidcFetchJWKSet(ctx, cfg.JWKSURL) + if err != nil { + return nil, err + } + leeway := time.Duration(cfg.ClockSkewSeconds) * time.Second + claims := &oidcIDTokenClaims{} + + parsed, err := jwt.ParseWithClaims( + idToken, + claims, + func(token *jwt.Token) (any, error) { + alg := strings.TrimSpace(token.Method.Alg()) + if !containsString(allowed, alg) { + return nil, fmt.Errorf("unexpected signing algorithm: %s", alg) + } + kid, _ := token.Header["kid"].(string) + return oidcFindPublicKey(jwks, strings.TrimSpace(kid), alg) + }, + jwt.WithValidMethods(allowed), + jwt.WithAudience(cfg.ClientID), + jwt.WithIssuer(cfg.IssuerURL), + jwt.WithLeeway(leeway), + ) + if err != nil { + return nil, err + } + if !parsed.Valid { + return nil, errors.New("id_token invalid") + } + if strings.TrimSpace(claims.Subject) == "" { + return nil, errors.New("id_token missing sub") + } + if expectedNonce != "" && strings.TrimSpace(claims.Nonce) != strings.TrimSpace(expectedNonce) { + return nil, errors.New("id_token nonce mismatch") + } + if len(claims.Audience) > 1 { + if strings.TrimSpace(claims.Azp) == "" || strings.TrimSpace(claims.Azp) != strings.TrimSpace(cfg.ClientID) { + return nil, errors.New("id_token azp mismatch") + } + } + return claims, nil +} + +func oidcAllowedSigningAlgs(raw string) []string { + if strings.TrimSpace(raw) == "" { + return []string{"RS256", "ES256", "PS256"} + } + seen := make(map[string]struct{}) + out := make([]string, 0, 4) + for _, part := range strings.Split(raw, ",") { + alg := strings.ToUpper(strings.TrimSpace(part)) + if alg == "" { + continue + } + if _, ok := seen[alg]; ok { + continue + } + seen[alg] = struct{}{} + out = append(out, alg) + } + return out +} + +func oidcFetchJWKSet(ctx context.Context, jwksURL string) (*oidcJWKSet, error) { + jwksURL = strings.TrimSpace(jwksURL) + if jwksURL == "" { + return nil, errors.New("missing jwks_url") + } + resp, err := req.C(). + SetTimeout(30*time.Second). + R(). + SetContext(ctx). + SetHeader("Accept", "application/json"). + Get(jwksURL) + if err != nil { + return nil, fmt.Errorf("request jwks: %w", err) + } + if !resp.IsSuccessState() { + return nil, fmt.Errorf("jwks status=%d", resp.StatusCode) + } + set := &oidcJWKSet{} + if err := json.Unmarshal(resp.Bytes(), set); err != nil { + return nil, fmt.Errorf("parse jwks: %w", err) + } + if len(set.Keys) == 0 { + return nil, errors.New("jwks empty keys") + } + return set, nil +} + +func oidcFindPublicKey(set *oidcJWKSet, kid, alg string) (any, error) { + if set == nil { + return nil, errors.New("jwks not loaded") + } + alg = strings.ToUpper(strings.TrimSpace(alg)) + kid = strings.TrimSpace(kid) + + var lastErr error + for i := range set.Keys { + k := set.Keys[i] + if strings.TrimSpace(k.Use) != "" && !strings.EqualFold(strings.TrimSpace(k.Use), "sig") { + continue + } + if kid != "" && strings.TrimSpace(k.Kid) != kid { + continue + } + if strings.TrimSpace(k.Alg) != "" && !strings.EqualFold(strings.TrimSpace(k.Alg), alg) { + continue + } + pk, err := k.publicKey() + if err != nil { + lastErr = err + continue + } + if pk != nil { + return pk, nil + } + } + if lastErr != nil { + return nil, lastErr + } + if kid != "" { + return nil, fmt.Errorf("jwk not found for kid=%s", kid) + } + return nil, errors.New("jwk not found") +} + +func (k oidcJWK) publicKey() (any, error) { + switch strings.ToUpper(strings.TrimSpace(k.Kty)) { + case "RSA": + n, err := decodeBase64URLBigInt(k.N) + if err != nil { + return nil, fmt.Errorf("decode rsa n: %w", err) + } + eBytes, err := base64.RawURLEncoding.DecodeString(strings.TrimSpace(k.E)) + if err != nil { + return nil, fmt.Errorf("decode rsa e: %w", err) + } + if len(eBytes) == 0 { + return nil, errors.New("empty rsa e") + } + e := 0 + for _, b := range eBytes { + e = (e << 8) | int(b) + } + if e <= 0 { + return nil, errors.New("invalid rsa exponent") + } + if n.Sign() <= 0 { + return nil, errors.New("invalid rsa modulus") + } + return &rsa.PublicKey{N: n, E: e}, nil + case "EC": + var curve elliptic.Curve + switch strings.TrimSpace(k.Crv) { + case "P-256": + curve = elliptic.P256() + case "P-384": + curve = elliptic.P384() + case "P-521": + curve = elliptic.P521() + default: + return nil, fmt.Errorf("unsupported ec curve: %s", k.Crv) + } + x, err := decodeBase64URLBigInt(k.X) + if err != nil { + return nil, fmt.Errorf("decode ec x: %w", err) + } + y, err := decodeBase64URLBigInt(k.Y) + if err != nil { + return nil, fmt.Errorf("decode ec y: %w", err) + } + if !curve.IsOnCurve(x, y) { + return nil, errors.New("ec point is not on curve") + } + return &ecdsa.PublicKey{Curve: curve, X: x, Y: y}, nil + default: + return nil, fmt.Errorf("unsupported jwk kty: %s", k.Kty) + } +} + +func decodeBase64URLBigInt(raw string) (*big.Int, error) { + buf, err := base64.RawURLEncoding.DecodeString(strings.TrimSpace(raw)) + if err != nil { + return nil, err + } + if len(buf) == 0 { + return nil, errors.New("empty value") + } + return new(big.Int).SetBytes(buf), nil +} + +func containsString(values []string, target string) bool { + target = strings.TrimSpace(target) + for _, v := range values { + if strings.EqualFold(strings.TrimSpace(v), target) { + return true + } + } + return false +} + +func oidcIdentityKey(issuer, subject string) string { + issuer = strings.TrimSpace(strings.ToLower(issuer)) + subject = strings.TrimSpace(subject) + return issuer + "\x1f" + subject +} + +func oidcSyntheticEmailFromIdentityKey(identityKey string) string { + identityKey = strings.TrimSpace(identityKey) + if identityKey == "" { + return "" + } + sum := sha256.Sum256([]byte(identityKey)) + return "oidc-" + hex.EncodeToString(sum[:16]) + service.OIDCConnectSyntheticEmailDomain +} + +func oidcFallbackUsername(subject string) string { + subject = strings.TrimSpace(subject) + if subject == "" { + return "oidc_user" + } + sum := sha256.Sum256([]byte(subject)) + return "oidc_" + hex.EncodeToString(sum[:])[:12] +} + +func oidcSetCookie(c *gin.Context, name, value string, maxAgeSec int, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: value, + Path: oidcOAuthCookiePath, + MaxAge: maxAgeSec, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func oidcClearCookie(c *gin.Context, name string, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: "", + Path: oidcOAuthCookiePath, + MaxAge: -1, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go new file mode 100644 index 00000000..1f50dd49 --- /dev/null +++ b/backend/internal/handler/auth_oidc_oauth_test.go @@ -0,0 +1,106 @@ +package handler + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "math/big" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" +) + +func TestOIDCSyntheticEmailStableAndDistinct(t *testing.T) { + k1 := oidcIdentityKey("https://issuer.example.com", "subject-a") + k2 := oidcIdentityKey("https://issuer.example.com", "subject-b") + + e1 := oidcSyntheticEmailFromIdentityKey(k1) + e1Again := oidcSyntheticEmailFromIdentityKey(k1) + e2 := oidcSyntheticEmailFromIdentityKey(k2) + + require.Equal(t, e1, e1Again) + require.NotEqual(t, e1, e2) + require.Contains(t, e1, "@oidc-connect.invalid") +} + +func TestBuildOIDCAuthorizeURLIncludesNonceAndPKCE(t *testing.T) { + cfg := config.OIDCConnectConfig{ + AuthorizeURL: "https://issuer.example.com/auth", + ClientID: "cid", + Scopes: "openid email profile", + UsePKCE: true, + } + + u, err := buildOIDCAuthorizeURL(cfg, "state123", "nonce123", "challenge123", "https://app.example.com/callback") + require.NoError(t, err) + require.Contains(t, u, "nonce=nonce123") + require.Contains(t, u, "code_challenge=challenge123") + require.Contains(t, u, "code_challenge_method=S256") + require.Contains(t, u, "scope=openid+email+profile") +} + +func TestOIDCParseAndValidateIDToken(t *testing.T) { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + kid := "kid-1" + jwks := oidcJWKSet{Keys: []oidcJWK{buildRSAJWK(kid, &priv.PublicKey)}} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.NoError(t, json.NewEncoder(w).Encode(jwks)) + })) + defer srv.Close() + + now := time.Now() + claims := oidcIDTokenClaims{ + Nonce: "nonce-ok", + Azp: "client-1", + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "https://issuer.example.com", + Subject: "subject-1", + Audience: jwt.ClaimStrings{"client-1", "another-aud"}, + IssuedAt: jwt.NewNumericDate(now), + NotBefore: jwt.NewNumericDate(now.Add(-30 * time.Second)), + ExpiresAt: jwt.NewNumericDate(now.Add(5 * time.Minute)), + }, + } + tok := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tok.Header["kid"] = kid + signed, err := tok.SignedString(priv) + require.NoError(t, err) + + cfg := config.OIDCConnectConfig{ + ClientID: "client-1", + IssuerURL: "https://issuer.example.com", + JWKSURL: srv.URL, + AllowedSigningAlgs: "RS256", + ClockSkewSeconds: 120, + } + + parsed, err := oidcParseAndValidateIDToken(context.Background(), cfg, signed, "nonce-ok") + require.NoError(t, err) + require.Equal(t, "subject-1", parsed.Subject) + require.Equal(t, "https://issuer.example.com", parsed.Issuer) + + _, err = oidcParseAndValidateIDToken(context.Background(), cfg, signed, "bad-nonce") + require.Error(t, err) +} + +func buildRSAJWK(kid string, pub *rsa.PublicKey) oidcJWK { + n := base64.RawURLEncoding.EncodeToString(pub.N.Bytes()) + e := base64.RawURLEncoding.EncodeToString(big.NewInt(int64(pub.E)).Bytes()) + return oidcJWK{ + Kty: "RSA", + Kid: kid, + Use: "sig", + Alg: "RS256", + N: n, + E: e, + } +} diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 73707f79..c8fc3b5d 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -51,6 +51,29 @@ type SystemSettings struct { LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"` LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` + OIDCConnectEnabled bool `json:"oidc_connect_enabled"` + OIDCConnectProviderName string `json:"oidc_connect_provider_name"` + OIDCConnectClientID string `json:"oidc_connect_client_id"` + OIDCConnectClientSecretConfigured bool `json:"oidc_connect_client_secret_configured"` + OIDCConnectIssuerURL string `json:"oidc_connect_issuer_url"` + OIDCConnectDiscoveryURL string `json:"oidc_connect_discovery_url"` + OIDCConnectAuthorizeURL string `json:"oidc_connect_authorize_url"` + OIDCConnectTokenURL string `json:"oidc_connect_token_url"` + OIDCConnectUserInfoURL string `json:"oidc_connect_userinfo_url"` + OIDCConnectJWKSURL string `json:"oidc_connect_jwks_url"` + OIDCConnectScopes string `json:"oidc_connect_scopes"` + OIDCConnectRedirectURL string `json:"oidc_connect_redirect_url"` + OIDCConnectFrontendRedirectURL string `json:"oidc_connect_frontend_redirect_url"` + OIDCConnectTokenAuthMethod string `json:"oidc_connect_token_auth_method"` + OIDCConnectUsePKCE bool `json:"oidc_connect_use_pkce"` + OIDCConnectValidateIDToken bool `json:"oidc_connect_validate_id_token"` + OIDCConnectAllowedSigningAlgs string `json:"oidc_connect_allowed_signing_algs"` + OIDCConnectClockSkewSeconds int `json:"oidc_connect_clock_skew_seconds"` + OIDCConnectRequireEmailVerified bool `json:"oidc_connect_require_email_verified"` + OIDCConnectUserInfoEmailPath string `json:"oidc_connect_userinfo_email_path"` + OIDCConnectUserInfoIDPath string `json:"oidc_connect_userinfo_id_path"` + OIDCConnectUserInfoUsernamePath string `json:"oidc_connect_userinfo_username_path"` + SiteName string `json:"site_name"` SiteLogo string `json:"site_logo"` SiteSubtitle string `json:"site_subtitle"` @@ -128,6 +151,9 @@ type PublicSettings struct { CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` CustomEndpoints []CustomEndpoint `json:"custom_endpoints"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` + OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"` + OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"` + SoraClientEnabled bool `json:"sora_client_enabled"` BackendModeEnabled bool `json:"backend_mode_enabled"` Version string `json:"version"` } diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 977c2301..8b536877 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -54,6 +54,9 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems), CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, + OIDCOAuthEnabled: settings.OIDCOAuthEnabled, + OIDCOAuthProviderName: settings.OIDCOAuthProviderName, + SoraClientEnabled: settings.SoraClientEnabled, BackendModeEnabled: settings.BackendModeEnabled, Version: h.version, }) diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 24f60f27..fee879a4 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -462,6 +462,28 @@ func TestAPIContracts(t *testing.T) { service.SettingKeyTurnstileSiteKey: "site-key", service.SettingKeyTurnstileSecretKey: "secret-key", + service.SettingKeyOIDCConnectEnabled: "false", + service.SettingKeyOIDCConnectProviderName: "OIDC", + service.SettingKeyOIDCConnectClientID: "", + service.SettingKeyOIDCConnectIssuerURL: "", + service.SettingKeyOIDCConnectDiscoveryURL: "", + service.SettingKeyOIDCConnectAuthorizeURL: "", + service.SettingKeyOIDCConnectTokenURL: "", + service.SettingKeyOIDCConnectUserInfoURL: "", + service.SettingKeyOIDCConnectJWKSURL: "", + service.SettingKeyOIDCConnectScopes: "openid email profile", + service.SettingKeyOIDCConnectRedirectURL: "", + service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback", + service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post", + service.SettingKeyOIDCConnectUsePKCE: "false", + service.SettingKeyOIDCConnectValidateIDToken: "true", + service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256", + service.SettingKeyOIDCConnectClockSkewSeconds: "120", + service.SettingKeyOIDCConnectRequireEmailVerified: "false", + service.SettingKeyOIDCConnectUserInfoEmailPath: "", + service.SettingKeyOIDCConnectUserInfoIDPath: "", + service.SettingKeyOIDCConnectUserInfoUsernamePath: "", + service.SettingKeySiteName: "Sub2API", service.SettingKeySiteLogo: "", service.SettingKeySiteSubtitle: "Subtitle", @@ -503,10 +525,32 @@ func TestAPIContracts(t *testing.T) { "turnstile_enabled": true, "turnstile_site_key": "site-key", "turnstile_secret_key_configured": true, - "linuxdo_connect_enabled": false, + "linuxdo_connect_enabled": false, "linuxdo_connect_client_id": "", "linuxdo_connect_client_secret_configured": false, "linuxdo_connect_redirect_url": "", + "oidc_connect_enabled": false, + "oidc_connect_provider_name": "OIDC", + "oidc_connect_client_id": "", + "oidc_connect_client_secret_configured": false, + "oidc_connect_issuer_url": "", + "oidc_connect_discovery_url": "", + "oidc_connect_authorize_url": "", + "oidc_connect_token_url": "", + "oidc_connect_userinfo_url": "", + "oidc_connect_jwks_url": "", + "oidc_connect_scopes": "openid email profile", + "oidc_connect_redirect_url": "", + "oidc_connect_frontend_redirect_url": "/auth/oidc/callback", + "oidc_connect_token_auth_method": "client_secret_post", + "oidc_connect_use_pkce": false, + "oidc_connect_validate_id_token": true, + "oidc_connect_allowed_signing_algs": "RS256,ES256,PS256", + "oidc_connect_clock_skew_seconds": 120, + "oidc_connect_require_email_verified": false, + "oidc_connect_userinfo_email_path": "", + "oidc_connect_userinfo_id_path": "", + "oidc_connect_userinfo_username_path": "", "ops_monitoring_enabled": false, "ops_realtime_monitoring_enabled": true, "ops_query_mode_default": "auto", diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index a6c0ecf5..c143b030 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -70,6 +70,14 @@ func RegisterAuthRoutes( }), h.Auth.CompleteLinuxDoOAuthRegistration, ) + auth.GET("/oauth/oidc/start", h.Auth.OIDCOAuthStart) + auth.GET("/oauth/oidc/callback", h.Auth.OIDCOAuthCallback) + auth.POST("/oauth/oidc/complete-registration", + rateLimiter.LimitWithOptions("oauth-oidc-complete", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CompleteOIDCOAuthRegistration, + ) } // 公开设置(无需认证) diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 6e524fb9..fd28cd42 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -833,7 +833,8 @@ func randomHexString(byteLength int) (string, error) { func isReservedEmail(email string) bool { normalized := strings.ToLower(strings.TrimSpace(email)) - return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) + return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) || + strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain) } // GenerateToken 生成JWT access token diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 92be3e06..e194f921 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -71,6 +71,9 @@ const ( // LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。 const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid" +// OIDCConnectSyntheticEmailDomain 是 OIDC 用户的合成邮箱后缀(RFC 保留域名)。 +const OIDCConnectSyntheticEmailDomain = "@oidc-connect.invalid" + // Setting keys const ( // 注册设置 @@ -105,6 +108,30 @@ const ( SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret" SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url" + // Generic OIDC OAuth 登录设置 + SettingKeyOIDCConnectEnabled = "oidc_connect_enabled" + SettingKeyOIDCConnectProviderName = "oidc_connect_provider_name" + SettingKeyOIDCConnectClientID = "oidc_connect_client_id" + SettingKeyOIDCConnectClientSecret = "oidc_connect_client_secret" + SettingKeyOIDCConnectIssuerURL = "oidc_connect_issuer_url" + SettingKeyOIDCConnectDiscoveryURL = "oidc_connect_discovery_url" + SettingKeyOIDCConnectAuthorizeURL = "oidc_connect_authorize_url" + SettingKeyOIDCConnectTokenURL = "oidc_connect_token_url" + SettingKeyOIDCConnectUserInfoURL = "oidc_connect_userinfo_url" + SettingKeyOIDCConnectJWKSURL = "oidc_connect_jwks_url" + SettingKeyOIDCConnectScopes = "oidc_connect_scopes" + SettingKeyOIDCConnectRedirectURL = "oidc_connect_redirect_url" + SettingKeyOIDCConnectFrontendRedirectURL = "oidc_connect_frontend_redirect_url" + SettingKeyOIDCConnectTokenAuthMethod = "oidc_connect_token_auth_method" + SettingKeyOIDCConnectUsePKCE = "oidc_connect_use_pkce" + SettingKeyOIDCConnectValidateIDToken = "oidc_connect_validate_id_token" + SettingKeyOIDCConnectAllowedSigningAlgs = "oidc_connect_allowed_signing_algs" + SettingKeyOIDCConnectClockSkewSeconds = "oidc_connect_clock_skew_seconds" + SettingKeyOIDCConnectRequireEmailVerified = "oidc_connect_require_email_verified" + SettingKeyOIDCConnectUserInfoEmailPath = "oidc_connect_userinfo_email_path" + SettingKeyOIDCConnectUserInfoIDPath = "oidc_connect_userinfo_id_path" + SettingKeyOIDCConnectUserInfoUsernamePath = "oidc_connect_userinfo_username_path" + // OEM设置 SettingKeySiteName = "site_name" // 网站名称 SettingKeySiteLogo = "site_logo" // 网站Logo (base64) diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 7d0ef5bd..37677fa5 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -16,6 +16,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/imroc/req/v3" "golang.org/x/sync/singleflight" ) @@ -164,6 +165,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyCustomEndpoints, SettingKeyLinuxDoConnectEnabled, SettingKeyBackendModeEnabled, + SettingKeyOIDCConnectEnabled, + SettingKeyOIDCConnectProviderName, } settings, err := s.settingRepo.GetMultiple(ctx, keys) @@ -177,6 +180,19 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings } else { linuxDoEnabled = s.cfg != nil && s.cfg.LinuxDo.Enabled } + oidcEnabled := false + if raw, ok := settings[SettingKeyOIDCConnectEnabled]; ok { + oidcEnabled = raw == "true" + } else { + oidcEnabled = s.cfg != nil && s.cfg.OIDC.Enabled + } + oidcProviderName := strings.TrimSpace(settings[SettingKeyOIDCConnectProviderName]) + if oidcProviderName == "" && s.cfg != nil { + oidcProviderName = strings.TrimSpace(s.cfg.OIDC.ProviderName) + } + if oidcProviderName == "" { + oidcProviderName = "OIDC" + } // Password reset requires email verification to be enabled emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true" @@ -209,6 +225,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings CustomEndpoints: settings[SettingKeyCustomEndpoints], LinuxDoOAuthEnabled: linuxDoEnabled, BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", + OIDCOAuthEnabled: oidcEnabled, + OIDCOAuthProviderName: oidcProviderName, }, nil } @@ -256,6 +274,8 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any CustomEndpoints json.RawMessage `json:"custom_endpoints"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` BackendModeEnabled bool `json:"backend_mode_enabled"` + OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"` + OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"` Version string `json:"version,omitempty"` }{ RegistrationEnabled: settings.RegistrationEnabled, @@ -281,6 +301,8 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, BackendModeEnabled: settings.BackendModeEnabled, + OIDCOAuthEnabled: settings.OIDCOAuthEnabled, + OIDCOAuthProviderName: settings.OIDCOAuthProviderName, Version: s.version, }, nil } @@ -460,6 +482,32 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyLinuxDoConnectClientSecret] = settings.LinuxDoConnectClientSecret } + // Generic OIDC OAuth 登录 + updates[SettingKeyOIDCConnectEnabled] = strconv.FormatBool(settings.OIDCConnectEnabled) + updates[SettingKeyOIDCConnectProviderName] = settings.OIDCConnectProviderName + updates[SettingKeyOIDCConnectClientID] = settings.OIDCConnectClientID + updates[SettingKeyOIDCConnectIssuerURL] = settings.OIDCConnectIssuerURL + updates[SettingKeyOIDCConnectDiscoveryURL] = settings.OIDCConnectDiscoveryURL + updates[SettingKeyOIDCConnectAuthorizeURL] = settings.OIDCConnectAuthorizeURL + updates[SettingKeyOIDCConnectTokenURL] = settings.OIDCConnectTokenURL + updates[SettingKeyOIDCConnectUserInfoURL] = settings.OIDCConnectUserInfoURL + updates[SettingKeyOIDCConnectJWKSURL] = settings.OIDCConnectJWKSURL + updates[SettingKeyOIDCConnectScopes] = settings.OIDCConnectScopes + updates[SettingKeyOIDCConnectRedirectURL] = settings.OIDCConnectRedirectURL + updates[SettingKeyOIDCConnectFrontendRedirectURL] = settings.OIDCConnectFrontendRedirectURL + updates[SettingKeyOIDCConnectTokenAuthMethod] = settings.OIDCConnectTokenAuthMethod + updates[SettingKeyOIDCConnectUsePKCE] = strconv.FormatBool(settings.OIDCConnectUsePKCE) + updates[SettingKeyOIDCConnectValidateIDToken] = strconv.FormatBool(settings.OIDCConnectValidateIDToken) + updates[SettingKeyOIDCConnectAllowedSigningAlgs] = settings.OIDCConnectAllowedSigningAlgs + updates[SettingKeyOIDCConnectClockSkewSeconds] = strconv.Itoa(settings.OIDCConnectClockSkewSeconds) + updates[SettingKeyOIDCConnectRequireEmailVerified] = strconv.FormatBool(settings.OIDCConnectRequireEmailVerified) + updates[SettingKeyOIDCConnectUserInfoEmailPath] = settings.OIDCConnectUserInfoEmailPath + updates[SettingKeyOIDCConnectUserInfoIDPath] = settings.OIDCConnectUserInfoIDPath + updates[SettingKeyOIDCConnectUserInfoUsernamePath] = settings.OIDCConnectUserInfoUsernamePath + if settings.OIDCConnectClientSecret != "" { + updates[SettingKeyOIDCConnectClientSecret] = settings.OIDCConnectClientSecret + } + // OEM设置 updates[SettingKeySiteName] = settings.SiteName updates[SettingKeySiteLogo] = settings.SiteLogo @@ -826,6 +874,8 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyPurchaseSubscriptionURL: "", SettingKeyCustomMenuItems: "[]", SettingKeyCustomEndpoints: "[]", + SettingKeyOIDCConnectEnabled: "false", + SettingKeyOIDCConnectProviderName: "OIDC", SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), SettingKeyDefaultSubscriptions: "[]", @@ -951,6 +1001,138 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin } result.LinuxDoConnectClientSecretConfigured = result.LinuxDoConnectClientSecret != "" + // Generic OIDC 设置: + // - 兼容 config.yaml/env + // - 支持后台系统设置覆盖并持久化(存储于 DB) + oidcBase := config.OIDCConnectConfig{} + if s.cfg != nil { + oidcBase = s.cfg.OIDC + } + + if raw, ok := settings[SettingKeyOIDCConnectEnabled]; ok { + result.OIDCConnectEnabled = raw == "true" + } else { + result.OIDCConnectEnabled = oidcBase.Enabled + } + + if v, ok := settings[SettingKeyOIDCConnectProviderName]; ok && strings.TrimSpace(v) != "" { + result.OIDCConnectProviderName = strings.TrimSpace(v) + } else { + result.OIDCConnectProviderName = strings.TrimSpace(oidcBase.ProviderName) + } + if result.OIDCConnectProviderName == "" { + result.OIDCConnectProviderName = "OIDC" + } + + if v, ok := settings[SettingKeyOIDCConnectClientID]; ok && strings.TrimSpace(v) != "" { + result.OIDCConnectClientID = strings.TrimSpace(v) + } else { + result.OIDCConnectClientID = strings.TrimSpace(oidcBase.ClientID) + } + if v, ok := settings[SettingKeyOIDCConnectIssuerURL]; ok && strings.TrimSpace(v) != "" { + result.OIDCConnectIssuerURL = strings.TrimSpace(v) + } else { + result.OIDCConnectIssuerURL = strings.TrimSpace(oidcBase.IssuerURL) + } + if v, ok := settings[SettingKeyOIDCConnectDiscoveryURL]; ok && strings.TrimSpace(v) != "" { + result.OIDCConnectDiscoveryURL = strings.TrimSpace(v) + } else { + result.OIDCConnectDiscoveryURL = strings.TrimSpace(oidcBase.DiscoveryURL) + } + if v, ok := settings[SettingKeyOIDCConnectAuthorizeURL]; ok && strings.TrimSpace(v) != "" { + result.OIDCConnectAuthorizeURL = strings.TrimSpace(v) + } else { + result.OIDCConnectAuthorizeURL = strings.TrimSpace(oidcBase.AuthorizeURL) + } + if v, ok := settings[SettingKeyOIDCConnectTokenURL]; ok && strings.TrimSpace(v) != "" { + result.OIDCConnectTokenURL = strings.TrimSpace(v) + } else { + result.OIDCConnectTokenURL = strings.TrimSpace(oidcBase.TokenURL) + } + if v, ok := settings[SettingKeyOIDCConnectUserInfoURL]; ok && strings.TrimSpace(v) != "" { + result.OIDCConnectUserInfoURL = strings.TrimSpace(v) + } else { + result.OIDCConnectUserInfoURL = strings.TrimSpace(oidcBase.UserInfoURL) + } + if v, ok := settings[SettingKeyOIDCConnectJWKSURL]; ok && strings.TrimSpace(v) != "" { + result.OIDCConnectJWKSURL = strings.TrimSpace(v) + } else { + result.OIDCConnectJWKSURL = strings.TrimSpace(oidcBase.JWKSURL) + } + if v, ok := settings[SettingKeyOIDCConnectScopes]; ok && strings.TrimSpace(v) != "" { + result.OIDCConnectScopes = strings.TrimSpace(v) + } else { + result.OIDCConnectScopes = strings.TrimSpace(oidcBase.Scopes) + } + if v, ok := settings[SettingKeyOIDCConnectRedirectURL]; ok && strings.TrimSpace(v) != "" { + result.OIDCConnectRedirectURL = strings.TrimSpace(v) + } else { + result.OIDCConnectRedirectURL = strings.TrimSpace(oidcBase.RedirectURL) + } + if v, ok := settings[SettingKeyOIDCConnectFrontendRedirectURL]; ok && strings.TrimSpace(v) != "" { + result.OIDCConnectFrontendRedirectURL = strings.TrimSpace(v) + } else { + result.OIDCConnectFrontendRedirectURL = strings.TrimSpace(oidcBase.FrontendRedirectURL) + } + if v, ok := settings[SettingKeyOIDCConnectTokenAuthMethod]; ok && strings.TrimSpace(v) != "" { + result.OIDCConnectTokenAuthMethod = strings.ToLower(strings.TrimSpace(v)) + } else { + result.OIDCConnectTokenAuthMethod = strings.ToLower(strings.TrimSpace(oidcBase.TokenAuthMethod)) + } + if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok { + result.OIDCConnectUsePKCE = raw == "true" + } else { + result.OIDCConnectUsePKCE = oidcBase.UsePKCE + } + if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok { + result.OIDCConnectValidateIDToken = raw == "true" + } else { + result.OIDCConnectValidateIDToken = oidcBase.ValidateIDToken + } + if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" { + result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(v) + } else { + result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(oidcBase.AllowedSigningAlgs) + } + clockSkewSet := false + if raw, ok := settings[SettingKeyOIDCConnectClockSkewSeconds]; ok && strings.TrimSpace(raw) != "" { + if parsed, err := strconv.Atoi(strings.TrimSpace(raw)); err == nil { + result.OIDCConnectClockSkewSeconds = parsed + clockSkewSet = true + } + } + if !clockSkewSet { + result.OIDCConnectClockSkewSeconds = oidcBase.ClockSkewSeconds + } + if !clockSkewSet && result.OIDCConnectClockSkewSeconds == 0 { + result.OIDCConnectClockSkewSeconds = 120 + } + if raw, ok := settings[SettingKeyOIDCConnectRequireEmailVerified]; ok { + result.OIDCConnectRequireEmailVerified = raw == "true" + } else { + result.OIDCConnectRequireEmailVerified = oidcBase.RequireEmailVerified + } + if v, ok := settings[SettingKeyOIDCConnectUserInfoEmailPath]; ok { + result.OIDCConnectUserInfoEmailPath = strings.TrimSpace(v) + } else { + result.OIDCConnectUserInfoEmailPath = strings.TrimSpace(oidcBase.UserInfoEmailPath) + } + if v, ok := settings[SettingKeyOIDCConnectUserInfoIDPath]; ok { + result.OIDCConnectUserInfoIDPath = strings.TrimSpace(v) + } else { + result.OIDCConnectUserInfoIDPath = strings.TrimSpace(oidcBase.UserInfoIDPath) + } + if v, ok := settings[SettingKeyOIDCConnectUserInfoUsernamePath]; ok { + result.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(v) + } else { + result.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(oidcBase.UserInfoUsernamePath) + } + result.OIDCConnectClientSecret = strings.TrimSpace(settings[SettingKeyOIDCConnectClientSecret]) + if result.OIDCConnectClientSecret == "" { + result.OIDCConnectClientSecret = strings.TrimSpace(oidcBase.ClientSecret) + } + result.OIDCConnectClientSecretConfigured = result.OIDCConnectClientSecret != "" + // Model fallback settings result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true" result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022") @@ -1323,6 +1505,282 @@ func (s *SettingService) SetOverloadCooldownSettings(ctx context.Context, settin return s.settingRepo.Set(ctx, SettingKeyOverloadCooldownSettings, string(data)) } +// GetOIDCConnectOAuthConfig 返回用于登录的“最终生效” OIDC 配置。 +// +// 优先级: +// - 若对应系统设置键存在,则覆盖 config.yaml/env 的值 +// - 否则回退到 config.yaml/env 的值 +func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config.OIDCConnectConfig, error) { + if s == nil || s.cfg == nil { + return config.OIDCConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded") + } + + effective := s.cfg.OIDC + + keys := []string{ + SettingKeyOIDCConnectEnabled, + SettingKeyOIDCConnectProviderName, + SettingKeyOIDCConnectClientID, + SettingKeyOIDCConnectClientSecret, + SettingKeyOIDCConnectIssuerURL, + SettingKeyOIDCConnectDiscoveryURL, + SettingKeyOIDCConnectAuthorizeURL, + SettingKeyOIDCConnectTokenURL, + SettingKeyOIDCConnectUserInfoURL, + SettingKeyOIDCConnectJWKSURL, + SettingKeyOIDCConnectScopes, + SettingKeyOIDCConnectRedirectURL, + SettingKeyOIDCConnectFrontendRedirectURL, + SettingKeyOIDCConnectTokenAuthMethod, + SettingKeyOIDCConnectUsePKCE, + SettingKeyOIDCConnectValidateIDToken, + SettingKeyOIDCConnectAllowedSigningAlgs, + SettingKeyOIDCConnectClockSkewSeconds, + SettingKeyOIDCConnectRequireEmailVerified, + SettingKeyOIDCConnectUserInfoEmailPath, + SettingKeyOIDCConnectUserInfoIDPath, + SettingKeyOIDCConnectUserInfoUsernamePath, + } + settings, err := s.settingRepo.GetMultiple(ctx, keys) + if err != nil { + return config.OIDCConnectConfig{}, fmt.Errorf("get oidc connect settings: %w", err) + } + + if raw, ok := settings[SettingKeyOIDCConnectEnabled]; ok { + effective.Enabled = raw == "true" + } + if v, ok := settings[SettingKeyOIDCConnectProviderName]; ok && strings.TrimSpace(v) != "" { + effective.ProviderName = strings.TrimSpace(v) + } + if v, ok := settings[SettingKeyOIDCConnectClientID]; ok && strings.TrimSpace(v) != "" { + effective.ClientID = strings.TrimSpace(v) + } + if v, ok := settings[SettingKeyOIDCConnectClientSecret]; ok && strings.TrimSpace(v) != "" { + effective.ClientSecret = strings.TrimSpace(v) + } + if v, ok := settings[SettingKeyOIDCConnectIssuerURL]; ok && strings.TrimSpace(v) != "" { + effective.IssuerURL = strings.TrimSpace(v) + } + if v, ok := settings[SettingKeyOIDCConnectDiscoveryURL]; ok && strings.TrimSpace(v) != "" { + effective.DiscoveryURL = strings.TrimSpace(v) + } + if v, ok := settings[SettingKeyOIDCConnectAuthorizeURL]; ok && strings.TrimSpace(v) != "" { + effective.AuthorizeURL = strings.TrimSpace(v) + } + if v, ok := settings[SettingKeyOIDCConnectTokenURL]; ok && strings.TrimSpace(v) != "" { + effective.TokenURL = strings.TrimSpace(v) + } + if v, ok := settings[SettingKeyOIDCConnectUserInfoURL]; ok && strings.TrimSpace(v) != "" { + effective.UserInfoURL = strings.TrimSpace(v) + } + if v, ok := settings[SettingKeyOIDCConnectJWKSURL]; ok && strings.TrimSpace(v) != "" { + effective.JWKSURL = strings.TrimSpace(v) + } + if v, ok := settings[SettingKeyOIDCConnectScopes]; ok && strings.TrimSpace(v) != "" { + effective.Scopes = strings.TrimSpace(v) + } + if v, ok := settings[SettingKeyOIDCConnectRedirectURL]; ok && strings.TrimSpace(v) != "" { + effective.RedirectURL = strings.TrimSpace(v) + } + if v, ok := settings[SettingKeyOIDCConnectFrontendRedirectURL]; ok && strings.TrimSpace(v) != "" { + effective.FrontendRedirectURL = strings.TrimSpace(v) + } + if v, ok := settings[SettingKeyOIDCConnectTokenAuthMethod]; ok && strings.TrimSpace(v) != "" { + effective.TokenAuthMethod = strings.ToLower(strings.TrimSpace(v)) + } + if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok { + effective.UsePKCE = raw == "true" + } + if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok { + effective.ValidateIDToken = raw == "true" + } + if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" { + effective.AllowedSigningAlgs = strings.TrimSpace(v) + } + if raw, ok := settings[SettingKeyOIDCConnectClockSkewSeconds]; ok && strings.TrimSpace(raw) != "" { + if parsed, parseErr := strconv.Atoi(strings.TrimSpace(raw)); parseErr == nil { + effective.ClockSkewSeconds = parsed + } + } + if raw, ok := settings[SettingKeyOIDCConnectRequireEmailVerified]; ok { + effective.RequireEmailVerified = raw == "true" + } + if v, ok := settings[SettingKeyOIDCConnectUserInfoEmailPath]; ok { + effective.UserInfoEmailPath = strings.TrimSpace(v) + } + if v, ok := settings[SettingKeyOIDCConnectUserInfoIDPath]; ok { + effective.UserInfoIDPath = strings.TrimSpace(v) + } + if v, ok := settings[SettingKeyOIDCConnectUserInfoUsernamePath]; ok { + effective.UserInfoUsernamePath = strings.TrimSpace(v) + } + + if !effective.Enabled { + return config.OIDCConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled") + } + if strings.TrimSpace(effective.ProviderName) == "" { + effective.ProviderName = "OIDC" + } + if strings.TrimSpace(effective.ClientID) == "" { + return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client id not configured") + } + if strings.TrimSpace(effective.IssuerURL) == "" { + return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth issuer url not configured") + } + if strings.TrimSpace(effective.RedirectURL) == "" { + return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured") + } + if strings.TrimSpace(effective.FrontendRedirectURL) == "" { + return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth frontend redirect url not configured") + } + if !scopesContainOpenID(effective.Scopes) { + return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth scopes must contain openid") + } + if effective.ClockSkewSeconds < 0 || effective.ClockSkewSeconds > 600 { + return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth clock skew must be between 0 and 600") + } + + if err := config.ValidateAbsoluteHTTPURL(effective.IssuerURL); err != nil { + return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth issuer url invalid") + } + + discoveryURL := strings.TrimSpace(effective.DiscoveryURL) + if discoveryURL == "" { + discoveryURL = oidcDefaultDiscoveryURL(effective.IssuerURL) + effective.DiscoveryURL = discoveryURL + } + if discoveryURL != "" { + if err := config.ValidateAbsoluteHTTPURL(discoveryURL); err != nil { + return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth discovery url invalid") + } + } + + needsDiscovery := strings.TrimSpace(effective.AuthorizeURL) == "" || + strings.TrimSpace(effective.TokenURL) == "" || + (effective.ValidateIDToken && strings.TrimSpace(effective.JWKSURL) == "") + if needsDiscovery && discoveryURL != "" { + metadata, resolveErr := oidcResolveProviderMetadata(ctx, discoveryURL) + if resolveErr != nil { + return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth discovery resolve failed").WithCause(resolveErr) + } + if strings.TrimSpace(effective.AuthorizeURL) == "" { + effective.AuthorizeURL = strings.TrimSpace(metadata.AuthorizationEndpoint) + } + if strings.TrimSpace(effective.TokenURL) == "" { + effective.TokenURL = strings.TrimSpace(metadata.TokenEndpoint) + } + if strings.TrimSpace(effective.UserInfoURL) == "" { + effective.UserInfoURL = strings.TrimSpace(metadata.UserInfoEndpoint) + } + if strings.TrimSpace(effective.JWKSURL) == "" { + effective.JWKSURL = strings.TrimSpace(metadata.JWKSURI) + } + } + + if strings.TrimSpace(effective.AuthorizeURL) == "" { + return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth authorize url not configured") + } + if strings.TrimSpace(effective.TokenURL) == "" { + return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token url not configured") + } + if err := config.ValidateAbsoluteHTTPURL(effective.AuthorizeURL); err != nil { + return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth authorize url invalid") + } + if err := config.ValidateAbsoluteHTTPURL(effective.TokenURL); err != nil { + return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token url invalid") + } + if v := strings.TrimSpace(effective.UserInfoURL); v != "" { + if err := config.ValidateAbsoluteHTTPURL(v); err != nil { + return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth userinfo url invalid") + } + } + if effective.ValidateIDToken { + if strings.TrimSpace(effective.JWKSURL) == "" { + return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth jwks url not configured") + } + if strings.TrimSpace(effective.AllowedSigningAlgs) == "" { + return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth signing algs not configured") + } + } + if v := strings.TrimSpace(effective.JWKSURL); v != "" { + if err := config.ValidateAbsoluteHTTPURL(v); err != nil { + return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth jwks url invalid") + } + } + if err := config.ValidateAbsoluteHTTPURL(effective.RedirectURL); err != nil { + return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url invalid") + } + if err := config.ValidateFrontendRedirectURL(effective.FrontendRedirectURL); err != nil { + return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth frontend redirect url invalid") + } + + method := strings.ToLower(strings.TrimSpace(effective.TokenAuthMethod)) + switch method { + case "", "client_secret_post", "client_secret_basic": + if strings.TrimSpace(effective.ClientSecret) == "" { + return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured") + } + case "none": + if !effective.UsePKCE { + return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none") + } + default: + return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid") + } + + return effective, nil +} + +func scopesContainOpenID(scopes string) bool { + for _, scope := range strings.Fields(strings.ToLower(strings.TrimSpace(scopes))) { + if scope == "openid" { + return true + } + } + return false +} + +type oidcProviderMetadata struct { + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + UserInfoEndpoint string `json:"userinfo_endpoint"` + JWKSURI string `json:"jwks_uri"` +} + +func oidcDefaultDiscoveryURL(issuerURL string) string { + issuerURL = strings.TrimSpace(issuerURL) + if issuerURL == "" { + return "" + } + return strings.TrimRight(issuerURL, "/") + "/.well-known/openid-configuration" +} + +func oidcResolveProviderMetadata(ctx context.Context, discoveryURL string) (*oidcProviderMetadata, error) { + discoveryURL = strings.TrimSpace(discoveryURL) + if discoveryURL == "" { + return nil, fmt.Errorf("discovery url is empty") + } + + resp, err := req.C(). + SetTimeout(15*time.Second). + R(). + SetContext(ctx). + SetHeader("Accept", "application/json"). + Get(discoveryURL) + if err != nil { + return nil, fmt.Errorf("request discovery document: %w", err) + } + if !resp.IsSuccessState() { + return nil, fmt.Errorf("discovery request failed: status=%d", resp.StatusCode) + } + + metadata := &oidcProviderMetadata{} + if err := json.Unmarshal(resp.Bytes(), metadata); err != nil { + return nil, fmt.Errorf("parse discovery document: %w", err) + } + return metadata, nil +} + // GetStreamTimeoutSettings 获取流超时处理配置 func (s *SettingService) GetStreamTimeoutSettings(ctx context.Context) (*StreamTimeoutSettings, error) { value, err := s.settingRepo.GetValue(ctx, SettingKeyStreamTimeoutSettings) diff --git a/backend/internal/service/setting_service_oidc_config_test.go b/backend/internal/service/setting_service_oidc_config_test.go new file mode 100644 index 00000000..3809b332 --- /dev/null +++ b/backend/internal/service/setting_service_oidc_config_test.go @@ -0,0 +1,103 @@ +//go:build unit + +package service + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type settingOIDCRepoStub struct { + values map[string]string +} + +func (s *settingOIDCRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + panic("unexpected Get call") +} + +func (s *settingOIDCRepoStub) GetValue(ctx context.Context, key string) (string, error) { + panic("unexpected GetValue call") +} + +func (s *settingOIDCRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *settingOIDCRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *settingOIDCRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *settingOIDCRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *settingOIDCRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +func TestGetOIDCConnectOAuthConfig_ResolvesEndpointsFromIssuerDiscovery(t *testing.T) { + var discoveryHits int + var baseURL string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/issuer/.well-known/openid-configuration" { + http.NotFound(w, r) + return + } + discoveryHits++ + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(fmt.Sprintf(`{ + "authorization_endpoint":"%s/issuer/protocol/openid-connect/auth", + "token_endpoint":"%s/issuer/protocol/openid-connect/token", + "userinfo_endpoint":"%s/issuer/protocol/openid-connect/userinfo", + "jwks_uri":"%s/issuer/protocol/openid-connect/certs" + }`, baseURL, baseURL, baseURL, baseURL))) + })) + defer srv.Close() + baseURL = srv.URL + + cfg := &config.Config{ + OIDC: config.OIDCConnectConfig{ + Enabled: true, + ProviderName: "OIDC", + ClientID: "oidc-client", + ClientSecret: "oidc-secret", + IssuerURL: srv.URL + "/issuer", + RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback", + FrontendRedirectURL: "/auth/oidc/callback", + Scopes: "openid email profile", + TokenAuthMethod: "client_secret_post", + ValidateIDToken: true, + AllowedSigningAlgs: "RS256", + ClockSkewSeconds: 120, + }, + } + + repo := &settingOIDCRepoStub{values: map[string]string{}} + svc := NewSettingService(repo, cfg) + + got, err := svc.GetOIDCConnectOAuthConfig(context.Background()) + require.NoError(t, err) + require.Equal(t, 1, discoveryHits) + require.Equal(t, srv.URL+"/issuer/.well-known/openid-configuration", got.DiscoveryURL) + require.Equal(t, srv.URL+"/issuer/protocol/openid-connect/auth", got.AuthorizeURL) + require.Equal(t, srv.URL+"/issuer/protocol/openid-connect/token", got.TokenURL) + require.Equal(t, srv.URL+"/issuer/protocol/openid-connect/userinfo", got.UserInfoURL) + require.Equal(t, srv.URL+"/issuer/protocol/openid-connect/certs", got.JWKSURL) +} diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index fedb3f2f..80932e9d 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -31,6 +31,31 @@ type SystemSettings struct { LinuxDoConnectClientSecretConfigured bool LinuxDoConnectRedirectURL string + // Generic OIDC OAuth 登录 + OIDCConnectEnabled bool + OIDCConnectProviderName string + OIDCConnectClientID string + OIDCConnectClientSecret string + OIDCConnectClientSecretConfigured bool + OIDCConnectIssuerURL string + OIDCConnectDiscoveryURL string + OIDCConnectAuthorizeURL string + OIDCConnectTokenURL string + OIDCConnectUserInfoURL string + OIDCConnectJWKSURL string + OIDCConnectScopes string + OIDCConnectRedirectURL string + OIDCConnectFrontendRedirectURL string + OIDCConnectTokenAuthMethod string + OIDCConnectUsePKCE bool + OIDCConnectValidateIDToken bool + OIDCConnectAllowedSigningAlgs string + OIDCConnectClockSkewSeconds int + OIDCConnectRequireEmailVerified bool + OIDCConnectUserInfoEmailPath string + OIDCConnectUserInfoIDPath string + OIDCConnectUserInfoUsernamePath string + SiteName string SiteLogo string SiteSubtitle string @@ -110,9 +135,11 @@ type PublicSettings struct { CustomMenuItems string // JSON array of custom menu items CustomEndpoints string // JSON array of custom endpoints - LinuxDoOAuthEnabled bool - BackendModeEnabled bool - Version string + LinuxDoOAuthEnabled bool + BackendModeEnabled bool + OIDCOAuthEnabled bool + OIDCOAuthProviderName string + Version string } // StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制) diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 8f60acd5..cd6e7e3f 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -820,6 +820,46 @@ linuxdo_connect: userinfo_id_path: "" userinfo_username_path: "" +# ============================================================================= +# Generic OIDC OAuth Login (SSO) +# 通用 OIDC OAuth 登录(用于 Sub2API 用户登录) +# ============================================================================= +oidc_connect: + enabled: false + provider_name: "OIDC" + client_id: "" + client_secret: "" + # 例如: "https://keycloak.example.com/realms/myrealm" + issuer_url: "" + # 可选: OIDC Discovery URL。为空时可手动填写 authorize/token/userinfo/jwks + discovery_url: "" + authorize_url: "" + token_url: "" + # 可选(仅补充 email/username,不用于 sub 可信绑定) + userinfo_url: "" + # validate_id_token=true 时必填 + jwks_url: "" + scopes: "openid email profile" + # 示例: "https://your-domain.com/api/v1/auth/oauth/oidc/callback" + redirect_url: "" + # 安全提示: + # - 建议使用同源相对路径(以 / 开头),避免把 token 重定向到意外的第三方域名 + # - 该地址不应包含 #fragment(本实现使用 URL fragment 传递 access_token) + frontend_redirect_url: "/auth/oidc/callback" + token_auth_method: "client_secret_post" # client_secret_post | client_secret_basic | none + # 注意:当 token_auth_method=none(public client)时,必须启用 PKCE + use_pkce: false + # 开启后强制校验 id_token 的签名和 claims(推荐) + validate_id_token: true + allowed_signing_algs: "RS256,ES256,PS256" + # 允许的时钟偏移(秒) + clock_skew_seconds: 120 + # 若 Provider 返回 email_verified=false,是否拒绝登录 + require_email_verified: false + userinfo_email_path: "" + userinfo_id_path: "" + userinfo_username_path: "" + # ============================================================================= # Default Settings # 默认设置 diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index b7ee6be5..9916f1ab 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -62,6 +62,30 @@ export interface SystemSettings { linuxdo_connect_client_secret_configured: boolean linuxdo_connect_redirect_url: string + // Generic OIDC OAuth settings + oidc_connect_enabled: boolean + oidc_connect_provider_name: string + oidc_connect_client_id: string + oidc_connect_client_secret_configured: boolean + oidc_connect_issuer_url: string + oidc_connect_discovery_url: string + oidc_connect_authorize_url: string + oidc_connect_token_url: string + oidc_connect_userinfo_url: string + oidc_connect_jwks_url: string + oidc_connect_scopes: string + oidc_connect_redirect_url: string + oidc_connect_frontend_redirect_url: string + oidc_connect_token_auth_method: string + oidc_connect_use_pkce: boolean + oidc_connect_validate_id_token: boolean + oidc_connect_allowed_signing_algs: string + oidc_connect_clock_skew_seconds: number + oidc_connect_require_email_verified: boolean + oidc_connect_userinfo_email_path: string + oidc_connect_userinfo_id_path: string + oidc_connect_userinfo_username_path: string + // Model fallback configuration enable_model_fallback: boolean fallback_model_anthropic: string @@ -131,6 +155,28 @@ export interface UpdateSettingsRequest { linuxdo_connect_client_id?: string linuxdo_connect_client_secret?: string linuxdo_connect_redirect_url?: string + oidc_connect_enabled?: boolean + oidc_connect_provider_name?: string + oidc_connect_client_id?: string + oidc_connect_client_secret?: string + oidc_connect_issuer_url?: string + oidc_connect_discovery_url?: string + oidc_connect_authorize_url?: string + oidc_connect_token_url?: string + oidc_connect_userinfo_url?: string + oidc_connect_jwks_url?: string + oidc_connect_scopes?: string + oidc_connect_redirect_url?: string + oidc_connect_frontend_redirect_url?: string + oidc_connect_token_auth_method?: string + oidc_connect_use_pkce?: boolean + oidc_connect_validate_id_token?: boolean + oidc_connect_allowed_signing_algs?: string + oidc_connect_clock_skew_seconds?: number + oidc_connect_require_email_verified?: boolean + oidc_connect_userinfo_email_path?: string + oidc_connect_userinfo_id_path?: string + oidc_connect_userinfo_username_path?: string enable_model_fallback?: boolean fallback_model_anthropic?: string fallback_model_openai?: string diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index c5e1f35d..837c4f4c 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -357,6 +357,28 @@ export async function completeLinuxDoOAuthRegistration( return data } +/** + * Complete OIDC OAuth registration by supplying an invitation code + * @param pendingOAuthToken - Short-lived JWT from the OAuth callback + * @param invitationCode - Invitation code entered by the user + * @returns Token pair on success + */ +export async function completeOIDCOAuthRegistration( + pendingOAuthToken: string, + invitationCode: string +): Promise<{ access_token: string; refresh_token: string; expires_in: number; token_type: string }> { + const { data } = await apiClient.post<{ + access_token: string + refresh_token: string + expires_in: number + token_type: string + }>('/auth/oauth/oidc/complete-registration', { + pending_oauth_token: pendingOAuthToken, + invitation_code: invitationCode + }) + return data +} + export const authAPI = { login, login2FA, @@ -380,7 +402,8 @@ export const authAPI = { resetPassword, refreshToken, revokeAllSessions, - completeLinuxDoOAuthRegistration + completeLinuxDoOAuthRegistration, + completeOIDCOAuthRegistration } export default authAPI diff --git a/frontend/src/components/auth/LinuxDoOAuthSection.vue b/frontend/src/components/auth/LinuxDoOAuthSection.vue index 8012b101..c740d06f 100644 --- a/frontend/src/components/auth/LinuxDoOAuthSection.vue +++ b/frontend/src/components/auth/LinuxDoOAuthSection.vue @@ -29,10 +29,10 @@ {{ t('auth.linuxdo.signIn') }} -
+
- {{ t('auth.linuxdo.orContinue') }} + {{ t('auth.oauthOrContinue') }}
@@ -43,9 +43,12 @@ import { useRoute } from 'vue-router' import { useI18n } from 'vue-i18n' -defineProps<{ +withDefaults(defineProps<{ disabled?: boolean -}>() + showDivider?: boolean +}>(), { + showDivider: true +}) const route = useRoute() const { t } = useI18n() @@ -58,4 +61,3 @@ function startLogin(): void { window.location.href = startURL } - diff --git a/frontend/src/components/auth/OidcOAuthSection.vue b/frontend/src/components/auth/OidcOAuthSection.vue new file mode 100644 index 00000000..f7cc7fa3 --- /dev/null +++ b/frontend/src/components/auth/OidcOAuthSection.vue @@ -0,0 +1,53 @@ + + + diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index d3b16d4a..475d8f33 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -428,6 +428,7 @@ export default { invitationCodeInvalid: 'Invalid or used invitation code', invitationCodeValidating: 'Validating invitation code...', invitationCodeInvalidCannotRegister: 'Invalid invitation code. Please check and try again', + oauthOrContinue: 'or continue with email', linuxdo: { signIn: 'Continue with Linux.do', orContinue: 'or continue with email', @@ -442,6 +443,20 @@ export default { completing: 'Completing registration…', completeRegistrationFailed: 'Registration failed. Please check your invitation code and try again.' }, + oidc: { + signIn: 'Continue with {providerName}', + callbackTitle: 'Signing you in with {providerName}', + callbackProcessing: 'Completing login with {providerName}, please wait...', + callbackHint: 'If you are not redirected automatically, go back to the login page and try again.', + callbackMissingToken: 'Missing login token, please try again.', + backToLogin: 'Back to Login', + invitationRequired: + 'This {providerName} account is not yet registered. The site requires an invitation code — please enter one to complete registration.', + invalidPendingToken: 'The registration token has expired. Please sign in again.', + completeRegistration: 'Complete Registration', + completing: 'Completing registration…', + completeRegistrationFailed: 'Registration failed. Please check your invitation code and try again.' + }, oauth: { code: 'Code', state: 'State', @@ -4227,6 +4242,57 @@ export default { quickSetCopy: 'Generate & Copy (current site)', redirectUrlSetAndCopied: 'Redirect URL generated and copied to clipboard' }, + oidc: { + title: 'OIDC Login', + description: 'Configure a standard OIDC provider (for example Keycloak)', + enable: 'Enable OIDC Login', + enableHint: 'Show OIDC login on the login/register pages', + providerName: 'Provider Name', + providerNamePlaceholder: 'for example Keycloak', + clientId: 'Client ID', + clientIdPlaceholder: 'OIDC client id', + clientSecret: 'Client Secret', + clientSecretPlaceholder: '********', + clientSecretHint: 'Used by backend to exchange tokens (keep it secret)', + clientSecretConfiguredPlaceholder: '********', + clientSecretConfiguredHint: 'Secret configured. Leave empty to keep the current value.', + issuerUrl: 'Issuer URL', + issuerUrlPlaceholder: 'https://id.example.com/realms/main', + discoveryUrl: 'Discovery URL', + discoveryUrlPlaceholder: 'Optional, leave empty to auto-derive from issuer', + authorizeUrl: 'Authorize URL', + authorizeUrlPlaceholder: 'Optional, can be discovered automatically', + tokenUrl: 'Token URL', + tokenUrlPlaceholder: 'Optional, can be discovered automatically', + userinfoUrl: 'UserInfo URL', + userinfoUrlPlaceholder: 'Optional, can be discovered automatically', + jwksUrl: 'JWKS URL', + jwksUrlPlaceholder: 'Optional, required when strict ID token validation is enabled', + scopes: 'Scopes', + scopesPlaceholder: 'openid email profile', + scopesHint: 'Must include openid', + redirectUrl: 'Backend Redirect URL', + redirectUrlPlaceholder: 'https://your-domain.com/api/v1/auth/oauth/oidc/callback', + redirectUrlHint: 'Must match the callback URL configured in the OIDC provider', + quickSetCopy: 'Generate & Copy (current site)', + redirectUrlSetAndCopied: 'Redirect URL generated and copied to clipboard', + frontendRedirectUrl: 'Frontend Callback Path', + frontendRedirectUrlPlaceholder: '/auth/oidc/callback', + frontendRedirectUrlHint: 'Frontend route used after backend callback', + tokenAuthMethod: 'Token Auth Method', + clockSkewSeconds: 'Clock Skew (seconds)', + allowedSigningAlgs: 'Allowed Signing Algs', + allowedSigningAlgsPlaceholder: 'RS256,ES256,PS256', + usePkce: 'Use PKCE', + validateIdToken: 'Validate ID Token', + requireEmailVerified: 'Require Email Verified', + userinfoEmailPath: 'UserInfo Email Path', + userinfoEmailPathPlaceholder: 'for example data.email', + userinfoIdPath: 'UserInfo ID Path', + userinfoIdPathPlaceholder: 'for example data.id', + userinfoUsernamePath: 'UserInfo Username Path', + userinfoUsernamePathPlaceholder: 'for example data.username' + }, defaults: { title: 'Default User Settings', description: 'Default values for new users', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index fcaaf5ab..eee2725c 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -427,6 +427,7 @@ export default { invitationCodeInvalid: '邀请码无效或已被使用', invitationCodeValidating: '正在验证邀请码...', invitationCodeInvalidCannotRegister: '邀请码无效,请检查后重试', + oauthOrContinue: '或使用邮箱密码继续', linuxdo: { signIn: '使用 Linux.do 登录', orContinue: '或使用邮箱密码继续', @@ -441,6 +442,19 @@ export default { completing: '正在完成注册...', completeRegistrationFailed: '注册失败,请检查邀请码后重试。' }, + oidc: { + signIn: '使用 {providerName} 登录', + callbackTitle: '正在完成 {providerName} 登录', + callbackProcessing: '正在验证 {providerName} 登录信息,请稍候...', + callbackHint: '如果页面未自动跳转,请返回登录页重试。', + callbackMissingToken: '登录信息缺失,请返回重试。', + backToLogin: '返回登录', + invitationRequired: '该 {providerName} 账号尚未注册,站点已开启邀请码注册,请输入邀请码以完成注册。', + invalidPendingToken: '注册凭证已失效,请重新登录。', + completeRegistration: '完成注册', + completing: '正在完成注册...', + completeRegistrationFailed: '注册失败,请检查邀请码后重试。' + }, oauth: { code: '授权码', state: '状态', @@ -4393,6 +4407,57 @@ export default { quickSetCopy: '使用当前站点生成并复制', redirectUrlSetAndCopied: '已使用当前站点生成回调地址并复制到剪贴板' }, + oidc: { + title: 'OIDC 登录', + description: '配置标准 OIDC Provider(例如 Keycloak)', + enable: '启用 OIDC 登录', + enableHint: '在登录/注册页面显示 OIDC 登录入口', + providerName: 'Provider 名称', + providerNamePlaceholder: '例如 Keycloak', + clientId: 'Client ID', + clientIdPlaceholder: 'OIDC client id', + clientSecret: 'Client Secret', + clientSecretPlaceholder: '********', + clientSecretHint: '用于后端交换 token(请保密)', + clientSecretConfiguredPlaceholder: '********', + clientSecretConfiguredHint: '密钥已配置,留空以保留当前值。', + issuerUrl: 'Issuer URL', + issuerUrlPlaceholder: 'https://id.example.com/realms/main', + discoveryUrl: 'Discovery URL', + discoveryUrlPlaceholder: '可选,留空将基于 issuer 自动推导', + authorizeUrl: 'Authorize URL', + authorizeUrlPlaceholder: '可选,可通过 discovery 自动获取', + tokenUrl: 'Token URL', + tokenUrlPlaceholder: '可选,可通过 discovery 自动获取', + userinfoUrl: 'UserInfo URL', + userinfoUrlPlaceholder: '可选,可通过 discovery 自动获取', + jwksUrl: 'JWKS URL', + jwksUrlPlaceholder: '可选;启用严格 ID Token 校验时必填', + scopes: 'Scopes', + scopesPlaceholder: 'openid email profile', + scopesHint: '必须包含 openid', + redirectUrl: '后端回调地址(Redirect URL)', + redirectUrlPlaceholder: 'https://your-domain.com/api/v1/auth/oauth/oidc/callback', + redirectUrlHint: '必须与 OIDC Provider 中配置的回调地址一致', + quickSetCopy: '使用当前站点生成并复制', + redirectUrlSetAndCopied: '已使用当前站点生成回调地址并复制到剪贴板', + frontendRedirectUrl: '前端回调路径', + frontendRedirectUrlPlaceholder: '/auth/oidc/callback', + frontendRedirectUrlHint: '后端回调完成后重定向到此前端路径', + tokenAuthMethod: 'Token 鉴权方式', + clockSkewSeconds: '时钟偏移(秒)', + allowedSigningAlgs: '允许的签名算法', + allowedSigningAlgsPlaceholder: 'RS256,ES256,PS256', + usePkce: '启用 PKCE', + validateIdToken: '校验 ID Token', + requireEmailVerified: '要求邮箱已验证', + userinfoEmailPath: 'UserInfo 邮箱字段路径', + userinfoEmailPathPlaceholder: '例如 data.email', + userinfoIdPath: 'UserInfo ID 字段路径', + userinfoIdPathPlaceholder: '例如 data.id', + userinfoUsernamePath: 'UserInfo 用户名字段路径', + userinfoUsernamePathPlaceholder: '例如 data.username' + }, defaults: { title: '用户默认设置', description: '新用户的默认值', diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index 6faf6f59..9bc6115f 100644 --- a/frontend/src/router/index.ts +++ b/frontend/src/router/index.ts @@ -83,6 +83,15 @@ const routes: RouteRecordRaw[] = [ title: 'LinuxDo OAuth Callback' } }, + { + path: '/auth/oidc/callback', + name: 'OIDCOAuthCallback', + component: () => import('@/views/auth/OidcCallbackView.vue'), + meta: { + requiresAuth: false, + title: 'OIDC OAuth Callback' + } + }, { path: '/forgot-password', name: 'ForgotPassword', diff --git a/frontend/src/stores/app.ts b/frontend/src/stores/app.ts index 24057136..edeb3bb1 100644 --- a/frontend/src/stores/app.ts +++ b/frontend/src/stores/app.ts @@ -332,6 +332,9 @@ export const useAppStore = defineStore('app', () => { custom_menu_items: [], custom_endpoints: [], linuxdo_oauth_enabled: false, + oidc_oauth_enabled: false, + oidc_oauth_provider_name: 'OIDC', + sora_client_enabled: false, backend_mode_enabled: false, version: siteVersion.value } diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 580126c8..b4f8f79a 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -109,6 +109,9 @@ export interface PublicSettings { custom_menu_items: CustomMenuItem[] custom_endpoints: CustomEndpoint[] linuxdo_oauth_enabled: boolean + oidc_oauth_enabled: boolean + oidc_oauth_provider_name: string + sora_client_enabled: boolean backend_mode_enabled: boolean version: string } diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index f43140ab..dd934235 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -1124,7 +1124,327 @@
- + + +
+
+

+ {{ t('admin.settings.oidc.title') }} +

+

+ {{ t('admin.settings.oidc.description') }} +

+
+
+
+
+ +

+ {{ t('admin.settings.oidc.enableHint') }} +

+
+ +
+ +
+
+
+ + +
+ +
+ + +
+ +
+ + +

+ {{ + form.oidc_connect_client_secret_configured + ? t('admin.settings.oidc.clientSecretConfiguredHint') + : t('admin.settings.oidc.clientSecretHint') + }} +

+
+
+ +
+
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+
+ +
+
+ + +

+ {{ t('admin.settings.oidc.scopesHint') }} +

+
+ +
+ + +
+ + + {{ oidcRedirectUrlSuggestion }} + +
+

+ {{ t('admin.settings.oidc.redirectUrlHint') }} +

+
+ +
+ + +

+ {{ t('admin.settings.oidc.frontendRedirectUrlHint') }} +

+
+
+ +
+
+ + +
+ +
+ + +
+ +
+ + +
+
+ +
+
+
+ +
+ +
+ +
+
+ +
+ +
+ +
+
+ +
+ +
+
+ +
+
+ + +
+ +
+ + +
+ +
+ + +
+
+
+
+
+
@@ -2193,6 +2513,7 @@ type SettingsForm = SystemSettings & { smtp_password: string turnstile_secret_key: string linuxdo_connect_client_secret: string + oidc_connect_client_secret: string } const form = reactive({ @@ -2240,6 +2561,30 @@ const form = reactive({ linuxdo_connect_client_secret: '', linuxdo_connect_client_secret_configured: false, linuxdo_connect_redirect_url: '', + // Generic OIDC OAuth 登录 + oidc_connect_enabled: false, + oidc_connect_provider_name: 'OIDC', + oidc_connect_client_id: '', + oidc_connect_client_secret: '', + oidc_connect_client_secret_configured: false, + oidc_connect_issuer_url: '', + oidc_connect_discovery_url: '', + oidc_connect_authorize_url: '', + oidc_connect_token_url: '', + oidc_connect_userinfo_url: '', + oidc_connect_jwks_url: '', + oidc_connect_scopes: 'openid email profile', + oidc_connect_redirect_url: '', + oidc_connect_frontend_redirect_url: '/auth/oidc/callback', + oidc_connect_token_auth_method: 'client_secret_post', + oidc_connect_use_pkce: false, + oidc_connect_validate_id_token: true, + oidc_connect_allowed_signing_algs: 'RS256,ES256,PS256', + oidc_connect_clock_skew_seconds: 120, + oidc_connect_require_email_verified: false, + oidc_connect_userinfo_email_path: '', + oidc_connect_userinfo_id_path: '', + oidc_connect_userinfo_username_path: '', // Model fallback enable_model_fallback: false, fallback_model_anthropic: 'claude-3-5-sonnet-20241022', @@ -2360,6 +2705,21 @@ async function setAndCopyLinuxdoRedirectUrl() { await copyToClipboard(url, t('admin.settings.linuxdo.redirectUrlSetAndCopied')) } +const oidcRedirectUrlSuggestion = computed(() => { + if (typeof window === 'undefined') return '' + const origin = + window.location.origin || `${window.location.protocol}//${window.location.host}` + return `${origin}/api/v1/auth/oauth/oidc/callback` +}) + +async function setAndCopyOIDCRedirectUrl() { + const url = oidcRedirectUrlSuggestion.value + if (!url) return + + form.oidc_connect_redirect_url = url + await copyToClipboard(url, t('admin.settings.oidc.redirectUrlSetAndCopied')) +} + // Custom menu item management function addMenuItem() { form.custom_menu_items.push({ @@ -2425,6 +2785,7 @@ async function loadSettings() { smtpPasswordManuallyEdited.value = false form.turnstile_secret_key = '' form.linuxdo_connect_client_secret = '' + form.oidc_connect_client_secret = '' } catch (error: any) { loadFailed.value = true appStore.showError( @@ -2559,6 +2920,28 @@ async function saveSettings() { linuxdo_connect_client_id: form.linuxdo_connect_client_id, linuxdo_connect_client_secret: form.linuxdo_connect_client_secret || undefined, linuxdo_connect_redirect_url: form.linuxdo_connect_redirect_url, + oidc_connect_enabled: form.oidc_connect_enabled, + oidc_connect_provider_name: form.oidc_connect_provider_name, + oidc_connect_client_id: form.oidc_connect_client_id, + oidc_connect_client_secret: form.oidc_connect_client_secret || undefined, + oidc_connect_issuer_url: form.oidc_connect_issuer_url, + oidc_connect_discovery_url: form.oidc_connect_discovery_url, + oidc_connect_authorize_url: form.oidc_connect_authorize_url, + oidc_connect_token_url: form.oidc_connect_token_url, + oidc_connect_userinfo_url: form.oidc_connect_userinfo_url, + oidc_connect_jwks_url: form.oidc_connect_jwks_url, + oidc_connect_scopes: form.oidc_connect_scopes, + oidc_connect_redirect_url: form.oidc_connect_redirect_url, + oidc_connect_frontend_redirect_url: form.oidc_connect_frontend_redirect_url, + oidc_connect_token_auth_method: form.oidc_connect_token_auth_method, + oidc_connect_use_pkce: form.oidc_connect_use_pkce, + oidc_connect_validate_id_token: form.oidc_connect_validate_id_token, + oidc_connect_allowed_signing_algs: form.oidc_connect_allowed_signing_algs, + oidc_connect_clock_skew_seconds: form.oidc_connect_clock_skew_seconds, + oidc_connect_require_email_verified: form.oidc_connect_require_email_verified, + oidc_connect_userinfo_email_path: form.oidc_connect_userinfo_email_path, + oidc_connect_userinfo_id_path: form.oidc_connect_userinfo_id_path, + oidc_connect_userinfo_username_path: form.oidc_connect_userinfo_username_path, enable_model_fallback: form.enable_model_fallback, fallback_model_anthropic: form.fallback_model_anthropic, fallback_model_openai: form.fallback_model_openai, @@ -2583,6 +2966,7 @@ async function saveSettings() { smtpPasswordManuallyEdited.value = false form.turnstile_secret_key = '' form.linuxdo_connect_client_secret = '' + form.oidc_connect_client_secret = '' // Refresh cached settings so sidebar/header update immediately await appStore.fetchPublicSettings(true) await adminSettingsStore.fetch(true) diff --git a/frontend/src/views/auth/LoginView.vue b/frontend/src/views/auth/LoginView.vue index 73d2474c..70b64e3f 100644 --- a/frontend/src/views/auth/LoginView.vue +++ b/frontend/src/views/auth/LoginView.vue @@ -11,8 +11,26 @@

- - +
+ + +
+
+ + {{ t('auth.oauthOrContinue') }} + +
+
+
@@ -181,6 +199,7 @@ import { useRouter } from 'vue-router' import { useI18n } from 'vue-i18n' import { AuthLayout } from '@/components/layout' import LinuxDoOAuthSection from '@/components/auth/LinuxDoOAuthSection.vue' +import OidcOAuthSection from '@/components/auth/OidcOAuthSection.vue' import TotpLoginModal from '@/components/auth/TotpLoginModal.vue' import Icon from '@/components/icons/Icon.vue' import TurnstileWidget from '@/components/TurnstileWidget.vue' @@ -207,6 +226,8 @@ const turnstileEnabled = ref(false) const turnstileSiteKey = ref('') const linuxdoOAuthEnabled = ref(false) const backendModeEnabled = ref(false) +const oidcOAuthEnabled = ref(false) +const oidcOAuthProviderName = ref('OIDC') const passwordResetEnabled = ref(false) // Turnstile @@ -247,6 +268,9 @@ onMounted(async () => { turnstileSiteKey.value = settings.turnstile_site_key || '' linuxdoOAuthEnabled.value = settings.linuxdo_oauth_enabled backendModeEnabled.value = settings.backend_mode_enabled + oidcOAuthEnabled.value = settings.oidc_oauth_enabled + oidcOAuthProviderName.value = settings.oidc_oauth_provider_name || 'OIDC' + backendModeEnabled.value = settings.backend_mode_enabled passwordResetEnabled.value = settings.password_reset_enabled } catch (error) { console.error('Failed to load public settings:', error) diff --git a/frontend/src/views/auth/OidcCallbackView.vue b/frontend/src/views/auth/OidcCallbackView.vue new file mode 100644 index 00000000..a6cb6c12 --- /dev/null +++ b/frontend/src/views/auth/OidcCallbackView.vue @@ -0,0 +1,234 @@ + + + + + diff --git a/frontend/src/views/auth/RegisterView.vue b/frontend/src/views/auth/RegisterView.vue index d1b576d4..bc8b8dce 100644 --- a/frontend/src/views/auth/RegisterView.vue +++ b/frontend/src/views/auth/RegisterView.vue @@ -11,8 +11,26 @@

- - +
+ + +
+
+ + {{ t('auth.oauthOrContinue') }} + +
+
+
(false) const turnstileSiteKey = ref('') const siteName = ref('Sub2API') const linuxdoOAuthEnabled = ref(false) +const oidcOAuthEnabled = ref(false) +const oidcOAuthProviderName = ref('OIDC') const registrationEmailSuffixWhitelist = ref([]) // Turnstile @@ -376,6 +397,8 @@ onMounted(async () => { turnstileSiteKey.value = settings.turnstile_site_key || '' siteName.value = settings.site_name || 'Sub2API' linuxdoOAuthEnabled.value = settings.linuxdo_oauth_enabled + oidcOAuthEnabled.value = settings.oidc_oauth_enabled + oidcOAuthProviderName.value = settings.oidc_oauth_provider_name || 'OIDC' registrationEmailSuffixWhitelist.value = normalizeRegistrationEmailSuffixWhitelist( settings.registration_email_suffix_whitelist || [] ) -- GitLab From 8e1a7bdfff2663f50ebe46013ee241683aee282a Mon Sep 17 00:00:00 2001 From: Glorhop <1150595033@qq.com> Date: Sat, 14 Mar 2026 14:45:43 +0000 Subject: [PATCH 05/52] fix: fixed an issue where OIDC login consistently used a synthetic email address --- backend/internal/handler/auth_oidc_oauth.go | 10 +++++++++- backend/internal/handler/auth_oidc_oauth_test.go | 14 ++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go index f46fb850..9d24df88 100644 --- a/backend/internal/handler/auth_oidc_oauth.go +++ b/backend/internal/handler/auth_oidc_oauth.go @@ -306,7 +306,7 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { } identityKey := oidcIdentityKey(issuer, subject) - email := oidcSyntheticEmailFromIdentityKey(identityKey) + email := oidcSelectLoginEmail(userInfoClaims.Email, idClaims.Email, identityKey) username := firstNonEmpty( userInfoClaims.Username, idClaims.PreferredUsername, @@ -831,6 +831,14 @@ func oidcSyntheticEmailFromIdentityKey(identityKey string) string { return "oidc-" + hex.EncodeToString(sum[:16]) + service.OIDCConnectSyntheticEmailDomain } +func oidcSelectLoginEmail(userInfoEmail, idTokenEmail, identityKey string) string { + email := strings.TrimSpace(firstNonEmpty(userInfoEmail, idTokenEmail)) + if email != "" { + return email + } + return oidcSyntheticEmailFromIdentityKey(identityKey) +} + func oidcFallbackUsername(subject string) string { subject = strings.TrimSpace(subject) if subject == "" { diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go index 1f50dd49..a161aa77 100644 --- a/backend/internal/handler/auth_oidc_oauth_test.go +++ b/backend/internal/handler/auth_oidc_oauth_test.go @@ -30,6 +30,20 @@ func TestOIDCSyntheticEmailStableAndDistinct(t *testing.T) { require.Contains(t, e1, "@oidc-connect.invalid") } +func TestOIDCSelectLoginEmailPrefersRealEmail(t *testing.T) { + identityKey := oidcIdentityKey("https://issuer.example.com", "subject-a") + + email := oidcSelectLoginEmail("user@example.com", "idtoken@example.com", identityKey) + require.Equal(t, "user@example.com", email) + + email = oidcSelectLoginEmail("", "idtoken@example.com", identityKey) + require.Equal(t, "idtoken@example.com", email) + + email = oidcSelectLoginEmail("", "", identityKey) + require.Contains(t, email, "@oidc-connect.invalid") + require.Equal(t, oidcSyntheticEmailFromIdentityKey(identityKey), email) +} + func TestBuildOIDCAuthorizeURLIncludesNonceAndPKCE(t *testing.T) { cfg := config.OIDCConnectConfig{ AuthorizeURL: "https://issuer.example.com/auth", -- GitLab From 311f06745a83d244f8226e947f9e5ef579ea7dd5 Mon Sep 17 00:00:00 2001 From: Glorhop <1150595033@qq.com> Date: Thu, 9 Apr 2026 02:57:00 +0000 Subject: [PATCH 06/52] chore: clean up deprecated Sora settings after rebase --- backend/internal/handler/setting_handler.go | 1 - 1 file changed, 1 deletion(-) diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 8b536877..1db104c1 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -56,7 +56,6 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, OIDCOAuthEnabled: settings.OIDCOAuthEnabled, OIDCOAuthProviderName: settings.OIDCOAuthProviderName, - SoraClientEnabled: settings.SoraClientEnabled, BackendModeEnabled: settings.BackendModeEnabled, Version: h.version, }) -- GitLab From 23c4d592f852ef862d4357d2823934b59233cec7 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Thu, 9 Apr 2026 12:29:28 +0800 Subject: [PATCH 07/52] =?UTF-8?q?feat(group):=20=E5=A2=9E=E5=8A=A0messages?= =?UTF-8?q?=E8=B0=83=E5=BA=A6=E6=A8=A1=E5=9E=8B=E6=98=A0=E5=B0=84=E9=85=8D?= =?UTF-8?q?=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/ent/group.go | 16 ++- backend/ent/group/group.go | 6 + backend/ent/group_create.go | 66 +++++++++++ backend/ent/group_update.go | 35 ++++++ backend/ent/migrate/schema.go | 1 + backend/ent/mutation.go | 56 ++++++++- backend/ent/runtime/runtime.go | 5 + backend/ent/schema/group.go | 4 + .../domain/openai_messages_dispatch.go | 10 ++ .../internal/handler/admin/group_handler.go | 20 ++-- backend/internal/handler/dto/mappers.go | 21 ++-- backend/internal/handler/dto/types.go | 9 +- backend/internal/repository/group_repo.go | 6 +- backend/internal/service/admin_service.go | 24 ++-- .../service/admin_service_group_test.go | 110 ++++++++++++++++++ backend/internal/service/group.go | 13 ++- .../service/openai_messages_dispatch.go | 100 ++++++++++++++++ .../service/openai_messages_dispatch_test.go | 27 +++++ ...d_group_messages_dispatch_model_config.sql | 2 + 19 files changed, 495 insertions(+), 36 deletions(-) create mode 100644 backend/internal/domain/openai_messages_dispatch.go create mode 100644 backend/internal/service/openai_messages_dispatch.go create mode 100644 backend/internal/service/openai_messages_dispatch_test.go create mode 100644 backend/migrations/091_add_group_messages_dispatch_model_config.sql diff --git a/backend/ent/group.go b/backend/ent/group.go index b15ac15d..f10b50c3 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -11,6 +11,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/internal/domain" ) // Group is the model entity for the Group schema. @@ -76,6 +77,8 @@ type Group struct { RequirePrivacySet bool `json:"require_privacy_set,omitempty"` // 默认映射模型 ID,当账号级映射找不到时使用此值 DefaultMappedModel string `json:"default_mapped_model,omitempty"` + // OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型 + MessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the GroupQuery when eager-loading is set. Edges GroupEdges `json:"edges"` @@ -182,7 +185,7 @@ func (*Group) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case group.FieldModelRouting, group.FieldSupportedModelScopes: + case group.FieldModelRouting, group.FieldSupportedModelScopes, group.FieldMessagesDispatchModelConfig: values[i] = new([]byte) case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet: values[i] = new(sql.NullBool) @@ -403,6 +406,14 @@ func (_m *Group) assignValues(columns []string, values []any) error { } else if value.Valid { _m.DefaultMappedModel = value.String } + case group.FieldMessagesDispatchModelConfig: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field messages_dispatch_model_config", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.MessagesDispatchModelConfig); err != nil { + return fmt.Errorf("unmarshal field messages_dispatch_model_config: %w", err) + } + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -585,6 +596,9 @@ func (_m *Group) String() string { builder.WriteString(", ") builder.WriteString("default_mapped_model=") builder.WriteString(_m.DefaultMappedModel) + builder.WriteString(", ") + builder.WriteString("messages_dispatch_model_config=") + builder.WriteString(fmt.Sprintf("%v", _m.MessagesDispatchModelConfig)) builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index 21a7c2cb..b1371630 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -8,6 +8,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/Wei-Shaw/sub2api/internal/domain" ) const ( @@ -73,6 +74,8 @@ const ( FieldRequirePrivacySet = "require_privacy_set" // FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database. FieldDefaultMappedModel = "default_mapped_model" + // FieldMessagesDispatchModelConfig holds the string denoting the messages_dispatch_model_config field in the database. + FieldMessagesDispatchModelConfig = "messages_dispatch_model_config" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. EdgeAPIKeys = "api_keys" // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. @@ -177,6 +180,7 @@ var Columns = []string{ FieldRequireOauthOnly, FieldRequirePrivacySet, FieldDefaultMappedModel, + FieldMessagesDispatchModelConfig, } var ( @@ -252,6 +256,8 @@ var ( DefaultDefaultMappedModel string // DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save. DefaultMappedModelValidator func(string) error + // DefaultMessagesDispatchModelConfig holds the default value on creation for the "messages_dispatch_model_config" field. + DefaultMessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig ) // OrderOption defines the ordering options for the Group queries. diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index a8c30b18..f412fa40 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -18,6 +18,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/usersubscription" + "github.com/Wei-Shaw/sub2api/internal/domain" ) // GroupCreate is the builder for creating a Group entity. @@ -410,6 +411,20 @@ func (_c *GroupCreate) SetNillableDefaultMappedModel(v *string) *GroupCreate { return _c } +// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field. +func (_c *GroupCreate) SetMessagesDispatchModelConfig(v domain.OpenAIMessagesDispatchModelConfig) *GroupCreate { + _c.mutation.SetMessagesDispatchModelConfig(v) + return _c +} + +// SetNillableMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field if the given value is not nil. +func (_c *GroupCreate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMessagesDispatchModelConfig) *GroupCreate { + if v != nil { + _c.SetMessagesDispatchModelConfig(*v) + } + return _c +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate { _c.mutation.AddAPIKeyIDs(ids...) @@ -611,6 +626,10 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultDefaultMappedModel _c.mutation.SetDefaultMappedModel(v) } + if _, ok := _c.mutation.MessagesDispatchModelConfig(); !ok { + v := group.DefaultMessagesDispatchModelConfig + _c.mutation.SetMessagesDispatchModelConfig(v) + } return nil } @@ -695,6 +714,9 @@ func (_c *GroupCreate) check() error { return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)} } } + if _, ok := _c.mutation.MessagesDispatchModelConfig(); !ok { + return &ValidationError{Name: "messages_dispatch_model_config", err: errors.New(`ent: missing required field "Group.messages_dispatch_model_config"`)} + } return nil } @@ -838,6 +860,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value) _node.DefaultMappedModel = value } + if value, ok := _c.mutation.MessagesDispatchModelConfig(); ok { + _spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value) + _node.MessagesDispatchModelConfig = value + } if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1462,6 +1488,18 @@ func (u *GroupUpsert) UpdateDefaultMappedModel() *GroupUpsert { return u } +// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field. +func (u *GroupUpsert) SetMessagesDispatchModelConfig(v domain.OpenAIMessagesDispatchModelConfig) *GroupUpsert { + u.Set(group.FieldMessagesDispatchModelConfig, v) + return u +} + +// UpdateMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field to the value that was provided on create. +func (u *GroupUpsert) UpdateMessagesDispatchModelConfig() *GroupUpsert { + u.SetExcluded(group.FieldMessagesDispatchModelConfig) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -2053,6 +2091,20 @@ func (u *GroupUpsertOne) UpdateDefaultMappedModel() *GroupUpsertOne { }) } +// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field. +func (u *GroupUpsertOne) SetMessagesDispatchModelConfig(v domain.OpenAIMessagesDispatchModelConfig) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetMessagesDispatchModelConfig(v) + }) +} + +// UpdateMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateMessagesDispatchModelConfig() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateMessagesDispatchModelConfig() + }) +} + // Exec executes the query. func (u *GroupUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -2810,6 +2862,20 @@ func (u *GroupUpsertBulk) UpdateDefaultMappedModel() *GroupUpsertBulk { }) } +// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field. +func (u *GroupUpsertBulk) SetMessagesDispatchModelConfig(v domain.OpenAIMessagesDispatchModelConfig) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetMessagesDispatchModelConfig(v) + }) +} + +// UpdateMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateMessagesDispatchModelConfig() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateMessagesDispatchModelConfig() + }) +} + // Exec executes the query. func (u *GroupUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index aa1a83d4..7b6d6193 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -20,6 +20,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/usersubscription" + "github.com/Wei-Shaw/sub2api/internal/domain" ) // GroupUpdate is the builder for updating Group entities. @@ -552,6 +553,20 @@ func (_u *GroupUpdate) SetNillableDefaultMappedModel(v *string) *GroupUpdate { return _u } +// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field. +func (_u *GroupUpdate) SetMessagesDispatchModelConfig(v domain.OpenAIMessagesDispatchModelConfig) *GroupUpdate { + _u.mutation.SetMessagesDispatchModelConfig(v) + return _u +} + +// SetNillableMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMessagesDispatchModelConfig) *GroupUpdate { + if v != nil { + _u.SetMessagesDispatchModelConfig(*v) + } + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate { _u.mutation.AddAPIKeyIDs(ids...) @@ -1012,6 +1027,9 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.DefaultMappedModel(); ok { _spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value) } + if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok { + _spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1843,6 +1861,20 @@ func (_u *GroupUpdateOne) SetNillableDefaultMappedModel(v *string) *GroupUpdateO return _u } +// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field. +func (_u *GroupUpdateOne) SetMessagesDispatchModelConfig(v domain.OpenAIMessagesDispatchModelConfig) *GroupUpdateOne { + _u.mutation.SetMessagesDispatchModelConfig(v) + return _u +} + +// SetNillableMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMessagesDispatchModelConfig) *GroupUpdateOne { + if v != nil { + _u.SetMessagesDispatchModelConfig(*v) + } + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne { _u.mutation.AddAPIKeyIDs(ids...) @@ -2333,6 +2365,9 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if value, ok := _u.mutation.DefaultMappedModel(); ok { _spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value) } + if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok { + _spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 5400bf93..a7ae4af0 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -407,6 +407,7 @@ var ( {Name: "require_oauth_only", Type: field.TypeBool, Default: false}, {Name: "require_privacy_set", Type: field.TypeBool, Default: false}, {Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""}, + {Name: "messages_dispatch_model_config", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, } // GroupsTable holds the schema information for the "groups" table. GroupsTable = &schema.Table{ diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index d206039a..594e5199 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -8246,6 +8246,7 @@ type GroupMutation struct { require_oauth_only *bool require_privacy_set *bool default_mapped_model *string + messages_dispatch_model_config *domain.OpenAIMessagesDispatchModelConfig clearedFields map[string]struct{} api_keys map[int64]struct{} removedapi_keys map[int64]struct{} @@ -9798,6 +9799,42 @@ func (m *GroupMutation) ResetDefaultMappedModel() { m.default_mapped_model = nil } +// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field. +func (m *GroupMutation) SetMessagesDispatchModelConfig(damdmc domain.OpenAIMessagesDispatchModelConfig) { + m.messages_dispatch_model_config = &damdmc +} + +// MessagesDispatchModelConfig returns the value of the "messages_dispatch_model_config" field in the mutation. +func (m *GroupMutation) MessagesDispatchModelConfig() (r domain.OpenAIMessagesDispatchModelConfig, exists bool) { + v := m.messages_dispatch_model_config + if v == nil { + return + } + return *v, true +} + +// OldMessagesDispatchModelConfig returns the old "messages_dispatch_model_config" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldMessagesDispatchModelConfig(ctx context.Context) (v domain.OpenAIMessagesDispatchModelConfig, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMessagesDispatchModelConfig is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMessagesDispatchModelConfig requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMessagesDispatchModelConfig: %w", err) + } + return oldValue.MessagesDispatchModelConfig, nil +} + +// ResetMessagesDispatchModelConfig resets all changes to the "messages_dispatch_model_config" field. +func (m *GroupMutation) ResetMessagesDispatchModelConfig() { + m.messages_dispatch_model_config = nil +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) { if m.api_keys == nil { @@ -10156,7 +10193,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 29) + fields := make([]string, 0, 30) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -10244,6 +10281,9 @@ func (m *GroupMutation) Fields() []string { if m.default_mapped_model != nil { fields = append(fields, group.FieldDefaultMappedModel) } + if m.messages_dispatch_model_config != nil { + fields = append(fields, group.FieldMessagesDispatchModelConfig) + } return fields } @@ -10310,6 +10350,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.RequirePrivacySet() case group.FieldDefaultMappedModel: return m.DefaultMappedModel() + case group.FieldMessagesDispatchModelConfig: + return m.MessagesDispatchModelConfig() } return nil, false } @@ -10377,6 +10419,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldRequirePrivacySet(ctx) case group.FieldDefaultMappedModel: return m.OldDefaultMappedModel(ctx) + case group.FieldMessagesDispatchModelConfig: + return m.OldMessagesDispatchModelConfig(ctx) } return nil, fmt.Errorf("unknown Group field %s", name) } @@ -10589,6 +10633,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetDefaultMappedModel(v) return nil + case group.FieldMessagesDispatchModelConfig: + v, ok := value.(domain.OpenAIMessagesDispatchModelConfig) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMessagesDispatchModelConfig(v) + return nil } return fmt.Errorf("unknown Group field %s", name) } @@ -10929,6 +10980,9 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldDefaultMappedModel: m.ResetDefaultMappedModel() return nil + case group.FieldMessagesDispatchModelConfig: + m.ResetMessagesDispatchModelConfig() + return nil } return fmt.Errorf("unknown Group field %s", name) } diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 803b7bc2..792f0566 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -28,6 +28,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/userattributedefinition" "github.com/Wei-Shaw/sub2api/ent/userattributevalue" "github.com/Wei-Shaw/sub2api/ent/usersubscription" + "github.com/Wei-Shaw/sub2api/internal/domain" ) // The init function reads all schema descriptors with runtime code @@ -468,6 +469,10 @@ func init() { group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string) // group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save. group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error) + // groupDescMessagesDispatchModelConfig is the schema descriptor for messages_dispatch_model_config field. + groupDescMessagesDispatchModelConfig := groupFields[26].Descriptor() + // group.DefaultMessagesDispatchModelConfig holds the default value on creation for the messages_dispatch_model_config field. + group.DefaultMessagesDispatchModelConfig = groupDescMessagesDispatchModelConfig.Default.(domain.OpenAIMessagesDispatchModelConfig) idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin() idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields() _ = idempotencyrecordMixinFields0 diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index 0eb89c18..d78a6898 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -141,6 +141,10 @@ func (Group) Fields() []ent.Field { MaxLen(100). Default(""). Comment("默认映射模型 ID,当账号级映射找不到时使用此值"), + field.JSON("messages_dispatch_model_config", domain.OpenAIMessagesDispatchModelConfig{}). + Default(domain.OpenAIMessagesDispatchModelConfig{}). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}). + Comment("OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型"), } } diff --git a/backend/internal/domain/openai_messages_dispatch.go b/backend/internal/domain/openai_messages_dispatch.go new file mode 100644 index 00000000..6b018f1c --- /dev/null +++ b/backend/internal/domain/openai_messages_dispatch.go @@ -0,0 +1,10 @@ +package domain + +// OpenAIMessagesDispatchModelConfig controls how Anthropic /v1/messages +// requests are mapped onto OpenAI/Codex models. +type OpenAIMessagesDispatchModelConfig struct { + OpusMappedModel string `json:"opus_mapped_model,omitempty"` + SonnetMappedModel string `json:"sonnet_mapped_model,omitempty"` + HaikuMappedModel string `json:"haiku_mapped_model,omitempty"` + ExactModelMappings map[string]string `json:"exact_model_mappings,omitempty"` +} diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index 458ed35d..8b6b056d 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -105,10 +105,11 @@ type CreateGroupRequest struct { // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes []string `json:"supported_model_scopes"` // OpenAI Messages 调度配置(仅 openai 平台使用) - AllowMessagesDispatch bool `json:"allow_messages_dispatch"` - RequireOAuthOnly bool `json:"require_oauth_only"` - RequirePrivacySet bool `json:"require_privacy_set"` - DefaultMappedModel string `json:"default_mapped_model"` + AllowMessagesDispatch bool `json:"allow_messages_dispatch"` + RequireOAuthOnly bool `json:"require_oauth_only"` + RequirePrivacySet bool `json:"require_privacy_set"` + DefaultMappedModel string `json:"default_mapped_model"` + MessagesDispatchModelConfig service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"` // 从指定分组复制账号(创建后自动绑定) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` } @@ -139,10 +140,11 @@ type UpdateGroupRequest struct { // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes *[]string `json:"supported_model_scopes"` // OpenAI Messages 调度配置(仅 openai 平台使用) - AllowMessagesDispatch *bool `json:"allow_messages_dispatch"` - RequireOAuthOnly *bool `json:"require_oauth_only"` - RequirePrivacySet *bool `json:"require_privacy_set"` - DefaultMappedModel *string `json:"default_mapped_model"` + AllowMessagesDispatch *bool `json:"allow_messages_dispatch"` + RequireOAuthOnly *bool `json:"require_oauth_only"` + RequirePrivacySet *bool `json:"require_privacy_set"` + DefaultMappedModel *string `json:"default_mapped_model"` + MessagesDispatchModelConfig *service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"` // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` } @@ -257,6 +259,7 @@ func (h *GroupHandler) Create(c *gin.Context) { RequireOAuthOnly: req.RequireOAuthOnly, RequirePrivacySet: req.RequirePrivacySet, DefaultMappedModel: req.DefaultMappedModel, + MessagesDispatchModelConfig: req.MessagesDispatchModelConfig, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { @@ -307,6 +310,7 @@ func (h *GroupHandler) Update(c *gin.Context) { RequireOAuthOnly: req.RequireOAuthOnly, RequirePrivacySet: req.RequirePrivacySet, DefaultMappedModel: req.DefaultMappedModel, + MessagesDispatchModelConfig: req.MessagesDispatchModelConfig, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 2eab670e..478600eb 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -133,16 +133,17 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup { return nil } out := &AdminGroup{ - Group: groupFromServiceBase(g), - ModelRouting: g.ModelRouting, - ModelRoutingEnabled: g.ModelRoutingEnabled, - MCPXMLInject: g.MCPXMLInject, - DefaultMappedModel: g.DefaultMappedModel, - SupportedModelScopes: g.SupportedModelScopes, - AccountCount: g.AccountCount, - ActiveAccountCount: g.ActiveAccountCount, - RateLimitedAccountCount: g.RateLimitedAccountCount, - SortOrder: g.SortOrder, + Group: groupFromServiceBase(g), + ModelRouting: g.ModelRouting, + ModelRoutingEnabled: g.ModelRoutingEnabled, + MCPXMLInject: g.MCPXMLInject, + DefaultMappedModel: g.DefaultMappedModel, + MessagesDispatchModelConfig: g.MessagesDispatchModelConfig, + SupportedModelScopes: g.SupportedModelScopes, + AccountCount: g.AccountCount, + ActiveAccountCount: g.ActiveAccountCount, + RateLimitedAccountCount: g.RateLimitedAccountCount, + SortOrder: g.SortOrder, } if len(g.AccountGroups) > 0 { out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups)) diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 82065deb..e026ca65 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -1,6 +1,10 @@ package dto -import "time" +import ( + "time" + + "github.com/Wei-Shaw/sub2api/internal/domain" +) type User struct { ID int64 `json:"id"` @@ -112,7 +116,8 @@ type AdminGroup struct { MCPXMLInject bool `json:"mcp_xml_inject"` // OpenAI Messages 调度配置(仅 openai 平台使用) - DefaultMappedModel string `json:"default_mapped_model"` + DefaultMappedModel string `json:"default_mapped_model"` + MessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"` // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes []string `json:"supported_model_scopes"` diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index a075b586..1803cf30 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -58,7 +58,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch). SetRequireOauthOnly(groupIn.RequireOAuthOnly). SetRequirePrivacySet(groupIn.RequirePrivacySet). - SetDefaultMappedModel(groupIn.DefaultMappedModel) + SetDefaultMappedModel(groupIn.DefaultMappedModel). + SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig) // 设置模型路由配置 if groupIn.ModelRouting != nil { @@ -124,7 +125,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch). SetRequireOauthOnly(groupIn.RequireOAuthOnly). SetRequirePrivacySet(groupIn.RequirePrivacySet). - SetDefaultMappedModel(groupIn.DefaultMappedModel) + SetDefaultMappedModel(groupIn.DefaultMappedModel). + SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig) // 显式处理可空字段:nil 需要 clear,非 nil 需要 set。 if groupIn.DailyLimitUSD != nil { diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 8032f871..c2553eee 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -152,10 +152,11 @@ type CreateGroupInput struct { // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes []string // OpenAI Messages 调度配置(仅 openai 平台使用) - AllowMessagesDispatch bool - DefaultMappedModel string - RequireOAuthOnly bool - RequirePrivacySet bool + AllowMessagesDispatch bool + DefaultMappedModel string + RequireOAuthOnly bool + RequirePrivacySet bool + MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig // 从指定分组复制账号(创建分组后在同一事务内绑定) CopyAccountsFromGroupIDs []int64 } @@ -186,10 +187,11 @@ type UpdateGroupInput struct { // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes *[]string // OpenAI Messages 调度配置(仅 openai 平台使用) - AllowMessagesDispatch *bool - DefaultMappedModel *string - RequireOAuthOnly *bool - RequirePrivacySet *bool + AllowMessagesDispatch *bool + DefaultMappedModel *string + RequireOAuthOnly *bool + RequirePrivacySet *bool + MessagesDispatchModelConfig *OpenAIMessagesDispatchModelConfig // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 } @@ -908,7 +910,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn RequireOAuthOnly: input.RequireOAuthOnly, RequirePrivacySet: input.RequirePrivacySet, DefaultMappedModel: input.DefaultMappedModel, + MessagesDispatchModelConfig: normalizeOpenAIMessagesDispatchModelConfig(input.MessagesDispatchModelConfig), } + sanitizeGroupMessagesDispatchFields(group) if err := s.groupRepo.Create(ctx, group); err != nil { return nil, err } @@ -1135,6 +1139,10 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if input.DefaultMappedModel != nil { group.DefaultMappedModel = *input.DefaultMappedModel } + if input.MessagesDispatchModelConfig != nil { + group.MessagesDispatchModelConfig = normalizeOpenAIMessagesDispatchModelConfig(*input.MessagesDispatchModelConfig) + } + sanitizeGroupMessagesDispatchFields(group) if err := s.groupRepo.Update(ctx, group); err != nil { return nil, err diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go index 536be0b5..fa676601 100644 --- a/backend/internal/service/admin_service_group_test.go +++ b/backend/internal/service/admin_service_group_test.go @@ -245,6 +245,116 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) { require.Nil(t, repo.updated.ImagePrice4K) } +func TestAdminService_CreateGroup_NormalizesMessagesDispatchModelConfig(t *testing.T) { + repo := &groupRepoStubForAdmin{} + svc := &adminServiceImpl{groupRepo: repo} + + group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ + Name: "dispatch-group", + Description: "dispatch config", + Platform: PlatformOpenAI, + RateMultiplier: 1.0, + MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{ + OpusMappedModel: " gpt-5.4-high ", + SonnetMappedModel: " gpt-5.3-codex ", + HaikuMappedModel: " gpt-5.4-mini-medium ", + ExactModelMappings: map[string]string{ + " claude-sonnet-4-5-20250929 ": " gpt-5.2-high ", + }, + }, + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.created) + require.Equal(t, OpenAIMessagesDispatchModelConfig{ + OpusMappedModel: "gpt-5.4", + SonnetMappedModel: "gpt-5.3-codex", + HaikuMappedModel: "gpt-5.4-mini", + ExactModelMappings: map[string]string{ + "claude-sonnet-4-5-20250929": "gpt-5.2", + }, + }, repo.created.MessagesDispatchModelConfig) +} + +func TestAdminService_UpdateGroup_NormalizesMessagesDispatchModelConfig(t *testing.T) { + existingGroup := &Group{ + ID: 1, + Name: "existing-group", + Platform: PlatformOpenAI, + Status: StatusActive, + } + repo := &groupRepoStubForAdmin{getByID: existingGroup} + svc := &adminServiceImpl{groupRepo: repo} + + group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{ + MessagesDispatchModelConfig: &OpenAIMessagesDispatchModelConfig{ + SonnetMappedModel: " gpt-5.4-medium ", + ExactModelMappings: map[string]string{ + " claude-haiku-4-5-20251001 ": " gpt-5.4-mini-high ", + }, + }, + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.updated) + require.Equal(t, OpenAIMessagesDispatchModelConfig{ + SonnetMappedModel: "gpt-5.4", + ExactModelMappings: map[string]string{ + "claude-haiku-4-5-20251001": "gpt-5.4-mini", + }, + }, repo.updated.MessagesDispatchModelConfig) +} + +func TestAdminService_CreateGroup_ClearsMessagesDispatchFieldsForNonOpenAIPlatform(t *testing.T) { + repo := &groupRepoStubForAdmin{} + svc := &adminServiceImpl{groupRepo: repo} + + group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ + Name: "anthropic-group", + Description: "non-openai", + Platform: PlatformAnthropic, + RateMultiplier: 1.0, + AllowMessagesDispatch: true, + DefaultMappedModel: "gpt-5.4", + MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{ + OpusMappedModel: "gpt-5.4", + }, + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.created) + require.False(t, repo.created.AllowMessagesDispatch) + require.Empty(t, repo.created.DefaultMappedModel) + require.Equal(t, OpenAIMessagesDispatchModelConfig{}, repo.created.MessagesDispatchModelConfig) +} + +func TestAdminService_UpdateGroup_ClearsMessagesDispatchFieldsWhenPlatformChangesAwayFromOpenAI(t *testing.T) { + existingGroup := &Group{ + ID: 1, + Name: "existing-openai-group", + Platform: PlatformOpenAI, + Status: StatusActive, + AllowMessagesDispatch: true, + DefaultMappedModel: "gpt-5.4", + MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{ + SonnetMappedModel: "gpt-5.3-codex", + }, + } + repo := &groupRepoStubForAdmin{getByID: existingGroup} + svc := &adminServiceImpl{groupRepo: repo} + + group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{ + Platform: ptrString(PlatformAnthropic), + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.updated) + require.Equal(t, PlatformAnthropic, repo.updated.Platform) + require.False(t, repo.updated.AllowMessagesDispatch) + require.Empty(t, repo.updated.DefaultMappedModel) + require.Equal(t, OpenAIMessagesDispatchModelConfig{}, repo.updated.MessagesDispatchModelConfig) +} + func TestAdminService_ListGroups_WithSearch(t *testing.T) { // 测试: // 1. search 参数正常传递到 repository 层 diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index d59af9e1..12262613 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -3,8 +3,12 @@ package service import ( "strings" "time" + + "github.com/Wei-Shaw/sub2api/internal/domain" ) +type OpenAIMessagesDispatchModelConfig = domain.OpenAIMessagesDispatchModelConfig + type Group struct { ID int64 Name string @@ -49,10 +53,11 @@ type Group struct { SortOrder int // OpenAI Messages 调度配置(仅 openai 平台使用) - AllowMessagesDispatch bool - RequireOAuthOnly bool // 仅允许非 apikey 类型账号关联(OpenAI/Antigravity/Anthropic/Gemini) - RequirePrivacySet bool // 调度时仅允许 privacy 已成功设置的账号(OpenAI/Antigravity/Anthropic/Gemini) - DefaultMappedModel string + AllowMessagesDispatch bool + RequireOAuthOnly bool // 仅允许非 apikey 类型账号关联(OpenAI/Antigravity/Anthropic/Gemini) + RequirePrivacySet bool // 调度时仅允许 privacy 已成功设置的账号(OpenAI/Antigravity/Anthropic/Gemini) + DefaultMappedModel string + MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig CreatedAt time.Time UpdatedAt time.Time diff --git a/backend/internal/service/openai_messages_dispatch.go b/backend/internal/service/openai_messages_dispatch.go new file mode 100644 index 00000000..f2c1ad3c --- /dev/null +++ b/backend/internal/service/openai_messages_dispatch.go @@ -0,0 +1,100 @@ +package service + +import "strings" + +const ( + defaultOpenAIMessagesDispatchOpusMappedModel = "gpt-5.4" + defaultOpenAIMessagesDispatchSonnetMappedModel = "gpt-5.3-codex" + defaultOpenAIMessagesDispatchHaikuMappedModel = "gpt-5.4-mini" +) + +func normalizeOpenAIMessagesDispatchMappedModel(model string) string { + model = NormalizeOpenAICompatRequestedModel(strings.TrimSpace(model)) + return strings.TrimSpace(model) +} + +func normalizeOpenAIMessagesDispatchModelConfig(cfg OpenAIMessagesDispatchModelConfig) OpenAIMessagesDispatchModelConfig { + out := OpenAIMessagesDispatchModelConfig{ + OpusMappedModel: normalizeOpenAIMessagesDispatchMappedModel(cfg.OpusMappedModel), + SonnetMappedModel: normalizeOpenAIMessagesDispatchMappedModel(cfg.SonnetMappedModel), + HaikuMappedModel: normalizeOpenAIMessagesDispatchMappedModel(cfg.HaikuMappedModel), + } + + if len(cfg.ExactModelMappings) > 0 { + out.ExactModelMappings = make(map[string]string, len(cfg.ExactModelMappings)) + for requestedModel, mappedModel := range cfg.ExactModelMappings { + requestedModel = strings.TrimSpace(requestedModel) + mappedModel = normalizeOpenAIMessagesDispatchMappedModel(mappedModel) + if requestedModel == "" || mappedModel == "" { + continue + } + out.ExactModelMappings[requestedModel] = mappedModel + } + if len(out.ExactModelMappings) == 0 { + out.ExactModelMappings = nil + } + } + + return out +} + +func claudeMessagesDispatchFamily(model string) string { + normalized := strings.ToLower(strings.TrimSpace(model)) + if !strings.HasPrefix(normalized, "claude") { + return "" + } + switch { + case strings.Contains(normalized, "opus"): + return "opus" + case strings.Contains(normalized, "sonnet"): + return "sonnet" + case strings.Contains(normalized, "haiku"): + return "haiku" + default: + return "" + } +} + +func (g *Group) ResolveMessagesDispatchModel(requestedModel string) string { + if g == nil { + return "" + } + requestedModel = strings.TrimSpace(requestedModel) + if requestedModel == "" { + return "" + } + + cfg := normalizeOpenAIMessagesDispatchModelConfig(g.MessagesDispatchModelConfig) + if mappedModel := strings.TrimSpace(cfg.ExactModelMappings[requestedModel]); mappedModel != "" { + return mappedModel + } + + switch claudeMessagesDispatchFamily(requestedModel) { + case "opus": + if mappedModel := strings.TrimSpace(cfg.OpusMappedModel); mappedModel != "" { + return mappedModel + } + return defaultOpenAIMessagesDispatchOpusMappedModel + case "sonnet": + if mappedModel := strings.TrimSpace(cfg.SonnetMappedModel); mappedModel != "" { + return mappedModel + } + return defaultOpenAIMessagesDispatchSonnetMappedModel + case "haiku": + if mappedModel := strings.TrimSpace(cfg.HaikuMappedModel); mappedModel != "" { + return mappedModel + } + return defaultOpenAIMessagesDispatchHaikuMappedModel + default: + return "" + } +} + +func sanitizeGroupMessagesDispatchFields(g *Group) { + if g == nil || g.Platform == PlatformOpenAI { + return + } + g.AllowMessagesDispatch = false + g.DefaultMappedModel = "" + g.MessagesDispatchModelConfig = OpenAIMessagesDispatchModelConfig{} +} diff --git a/backend/internal/service/openai_messages_dispatch_test.go b/backend/internal/service/openai_messages_dispatch_test.go new file mode 100644 index 00000000..a625aadd --- /dev/null +++ b/backend/internal/service/openai_messages_dispatch_test.go @@ -0,0 +1,27 @@ +package service + +import "testing" + +import "github.com/stretchr/testify/require" + +func TestNormalizeOpenAIMessagesDispatchModelConfig(t *testing.T) { + t.Parallel() + + cfg := normalizeOpenAIMessagesDispatchModelConfig(OpenAIMessagesDispatchModelConfig{ + OpusMappedModel: " gpt-5.4-high ", + SonnetMappedModel: "gpt-5.3-codex", + HaikuMappedModel: " gpt-5.4-mini-medium ", + ExactModelMappings: map[string]string{ + " claude-sonnet-4-5-20250929 ": " gpt-5.2-high ", + "": "gpt-5.4", + "claude-opus-4-6": " ", + }, + }) + + require.Equal(t, "gpt-5.4", cfg.OpusMappedModel) + require.Equal(t, "gpt-5.3-codex", cfg.SonnetMappedModel) + require.Equal(t, "gpt-5.4-mini", cfg.HaikuMappedModel) + require.Equal(t, map[string]string{ + "claude-sonnet-4-5-20250929": "gpt-5.2", + }, cfg.ExactModelMappings) +} diff --git a/backend/migrations/091_add_group_messages_dispatch_model_config.sql b/backend/migrations/091_add_group_messages_dispatch_model_config.sql new file mode 100644 index 00000000..8ddfcb0f --- /dev/null +++ b/backend/migrations/091_add_group_messages_dispatch_model_config.sql @@ -0,0 +1,2 @@ +ALTER TABLE groups +ADD COLUMN IF NOT EXISTS messages_dispatch_model_config JSONB NOT NULL DEFAULT '{}'::jsonb; -- GitLab From 4de4823a65bccf199d1507af5dfd4a14655e8236 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Thu, 9 Apr 2026 12:29:49 +0800 Subject: [PATCH 08/52] =?UTF-8?q?feat(openai):=20=E6=94=AF=E6=8C=81message?= =?UTF-8?q?s=E6=A8=A1=E5=9E=8B=E6=98=A0=E5=B0=84=E4=B8=8Einstructions?= =?UTF-8?q?=E6=A8=A1=E6=9D=BF=E6=B3=A8=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/config/config.go | 14 +++ backend/internal/config/config_test.go | 19 ++++ .../handler/openai_gateway_handler.go | 43 +++----- .../handler/openai_gateway_handler_test.go | 41 ++++++- backend/internal/pkg/apicompat/types.go | 6 +- .../openai_codex_instructions_template.go | 55 ++++++++++ .../service/openai_compat_model_test.go | 101 ++++++++++++++++++ .../service/openai_gateway_messages.go | 18 ++++ deploy/codex-instructions.md.tmpl | 5 + deploy/config.example.yaml | 32 ++++-- deploy/docker-compose.yml | 39 +++++-- 11 files changed, 326 insertions(+), 47 deletions(-) create mode 100644 backend/internal/service/openai_codex_instructions_template.go create mode 100644 deploy/codex-instructions.md.tmpl diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index ad023dc1..d3d2dd6d 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -318,6 +318,12 @@ type GatewayConfig struct { // ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。 // 用于网关未透传/改写 User-Agent 时的兼容兜底(默认关闭,避免影响其他客户端)。 ForceCodexCLI bool `mapstructure:"force_codex_cli"` + // ForcedCodexInstructionsTemplateFile: 服务端强制附加到 Codex 顶层 instructions 的模板文件路径。 + // 模板渲染后会直接覆盖最终 instructions;若需要保留客户端 system 转换结果,请在模板中显式引用 {{ .ExistingInstructions }}。 + ForcedCodexInstructionsTemplateFile string `mapstructure:"forced_codex_instructions_template_file"` + // ForcedCodexInstructionsTemplate: 启动时从模板文件读取并缓存的模板内容。 + // 该字段不直接参与配置反序列化,仅用于请求热路径避免重复读盘。 + ForcedCodexInstructionsTemplate string `mapstructure:"-"` // OpenAIPassthroughAllowTimeoutHeaders: OpenAI 透传模式是否放行客户端超时头 // 关闭(默认)可避免 x-stainless-timeout 等头导致上游提前断流。 OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"` @@ -983,6 +989,14 @@ func load(allowMissingJWTSecret bool) (*Config, error) { cfg.Log.Environment = strings.TrimSpace(cfg.Log.Environment) cfg.Log.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel)) cfg.Log.Output.FilePath = strings.TrimSpace(cfg.Log.Output.FilePath) + cfg.Gateway.ForcedCodexInstructionsTemplateFile = strings.TrimSpace(cfg.Gateway.ForcedCodexInstructionsTemplateFile) + if cfg.Gateway.ForcedCodexInstructionsTemplateFile != "" { + content, err := os.ReadFile(cfg.Gateway.ForcedCodexInstructionsTemplateFile) + if err != nil { + return nil, fmt.Errorf("read forced codex instructions template %q: %w", cfg.Gateway.ForcedCodexInstructionsTemplateFile, err) + } + cfg.Gateway.ForcedCodexInstructionsTemplate = string(content) + } // 兼容旧键 gateway.openai_ws.sticky_previous_response_ttl_seconds。 // 新键未配置(<=0)时回退旧键;新键优先。 diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index 2de5451e..8cb23026 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -1,6 +1,8 @@ package config import ( + "os" + "path/filepath" "strings" "testing" "time" @@ -223,6 +225,23 @@ func TestLoadSchedulingConfigFromEnv(t *testing.T) { } } +func TestLoadForcedCodexInstructionsTemplate(t *testing.T) { + resetViperWithJWTSecret(t) + + tempDir := t.TempDir() + templatePath := filepath.Join(tempDir, "codex-instructions.md.tmpl") + configPath := filepath.Join(tempDir, "config.yaml") + + require.NoError(t, os.WriteFile(templatePath, []byte("server-prefix\n\n{{ .ExistingInstructions }}"), 0o644)) + require.NoError(t, os.WriteFile(configPath, []byte("gateway:\n forced_codex_instructions_template_file: \""+templatePath+"\"\n"), 0o644)) + t.Setenv("DATA_DIR", tempDir) + + cfg, err := Load() + require.NoError(t, err) + require.Equal(t, templatePath, cfg.Gateway.ForcedCodexInstructionsTemplateFile) + require.Equal(t, "server-prefix\n\n{{ .ExistingInstructions }}", cfg.Gateway.ForcedCodexInstructionsTemplate) +} + func TestLoadDefaultSecurityToggles(t *testing.T) { resetViperWithJWTSecret(t) diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 4747ccfe..5319b55d 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -47,6 +47,13 @@ func resolveOpenAIForwardDefaultMappedModel(apiKey *service.APIKey, fallbackMode return strings.TrimSpace(apiKey.Group.DefaultMappedModel) } +func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedModel string) string { + if apiKey == nil || apiKey.Group == nil { + return "" + } + return strings.TrimSpace(apiKey.Group.ResolveMessagesDispatchModel(requestedModel)) +} + // NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler func NewOpenAIGatewayHandler( gatewayService *service.OpenAIGatewayService, @@ -551,6 +558,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { } reqModel := modelResult.String() routingModel := service.NormalizeOpenAICompatRequestedModel(reqModel) + preferredMappedModel := resolveOpenAIMessagesDispatchMappedModel(apiKey, reqModel) reqStream := gjson.GetBytes(body, "stream").Bool() reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) @@ -609,17 +617,20 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { failedAccountIDs := make(map[int64]struct{}) sameAccountRetryCount := make(map[int64]int) var lastFailoverErr *service.UpstreamFailoverError + effectiveMappedModel := preferredMappedModel for { - // 清除上一次迭代的降级模型标记,避免残留影响本次迭代 - c.Set("openai_messages_fallback_model", "") + currentRoutingModel := routingModel + if effectiveMappedModel != "" { + currentRoutingModel = effectiveMappedModel + } reqLog.Debug("openai_messages.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( c.Request.Context(), apiKey.GroupID, "", // no previous_response_id sessionHash, - routingModel, + currentRoutingModel, failedAccountIDs, service.OpenAIUpstreamTransportAny, ) @@ -628,29 +639,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs)), ) - // 首次调度失败 + 有默认映射模型 → 用默认模型重试 if len(failedAccountIDs) == 0 { - defaultModel := "" - if apiKey.Group != nil { - defaultModel = apiKey.Group.DefaultMappedModel - } - if defaultModel != "" && defaultModel != routingModel { - reqLog.Info("openai_messages.fallback_to_default_model", - zap.String("default_mapped_model", defaultModel), - ) - selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler( - c.Request.Context(), - apiKey.GroupID, - "", - sessionHash, - defaultModel, - failedAccountIDs, - service.OpenAIUpstreamTransportAny, - ) - if err == nil && selection != nil { - c.Set("openai_messages_fallback_model", defaultModel) - } - } if err != nil { h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) return @@ -682,9 +671,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) forwardStart := time.Now() - // Forward 层需要始终拿到 group 默认映射模型,这样未命中账号级映射的 - // Claude 兼容模型才不会在后续 Codex 规范化中意外退化到 gpt-5.1。 - defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_messages_fallback_model")) + defaultMappedModel := strings.TrimSpace(effectiveMappedModel) // 应用渠道模型映射到请求体 forwardBody := body if channelMappingMsg.Mapped { diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go index 7bbf94ec..d299fb81 100644 --- a/backend/internal/handler/openai_gateway_handler_test.go +++ b/backend/internal/handler/openai_gateway_handler_test.go @@ -360,7 +360,7 @@ func TestResolveOpenAIForwardDefaultMappedModel(t *testing.T) { require.Equal(t, "gpt-5.2", resolveOpenAIForwardDefaultMappedModel(apiKey, " gpt-5.2 ")) }) - t.Run("uses_group_default_on_normal_path", func(t *testing.T) { + t.Run("uses_group_default_when_explicit_fallback_absent", func(t *testing.T) { apiKey := &service.APIKey{ Group: &service.Group{DefaultMappedModel: "gpt-5.4"}, } @@ -376,6 +376,45 @@ func TestResolveOpenAIForwardDefaultMappedModel(t *testing.T) { }) } +func TestResolveOpenAIMessagesDispatchMappedModel(t *testing.T) { + t.Run("exact_claude_model_override_wins", func(t *testing.T) { + apiKey := &service.APIKey{ + Group: &service.Group{ + MessagesDispatchModelConfig: service.OpenAIMessagesDispatchModelConfig{ + SonnetMappedModel: "gpt-5.2", + ExactModelMappings: map[string]string{ + "claude-sonnet-4-5-20250929": "gpt-5.4-mini-high", + }, + }, + }, + } + require.Equal(t, "gpt-5.4-mini", resolveOpenAIMessagesDispatchMappedModel(apiKey, "claude-sonnet-4-5-20250929")) + }) + + t.Run("uses_family_default_when_no_override", func(t *testing.T) { + apiKey := &service.APIKey{Group: &service.Group{}} + require.Equal(t, "gpt-5.4", resolveOpenAIMessagesDispatchMappedModel(apiKey, "claude-opus-4-6")) + require.Equal(t, "gpt-5.3-codex", resolveOpenAIMessagesDispatchMappedModel(apiKey, "claude-sonnet-4-5-20250929")) + require.Equal(t, "gpt-5.4-mini", resolveOpenAIMessagesDispatchMappedModel(apiKey, "claude-haiku-4-5-20251001")) + }) + + t.Run("returns_empty_for_non_claude_or_missing_group", func(t *testing.T) { + require.Empty(t, resolveOpenAIMessagesDispatchMappedModel(nil, "claude-sonnet-4-5-20250929")) + require.Empty(t, resolveOpenAIMessagesDispatchMappedModel(&service.APIKey{}, "claude-sonnet-4-5-20250929")) + require.Empty(t, resolveOpenAIMessagesDispatchMappedModel(&service.APIKey{Group: &service.Group{}}, "gpt-5.4")) + }) + + t.Run("does_not_fall_back_to_group_default_mapped_model", func(t *testing.T) { + apiKey := &service.APIKey{ + Group: &service.Group{ + DefaultMappedModel: "gpt-5.4", + }, + } + require.Empty(t, resolveOpenAIMessagesDispatchMappedModel(apiKey, "gpt-5.4")) + require.Equal(t, "gpt-5.3-codex", resolveOpenAIMessagesDispatchMappedModel(apiKey, "claude-sonnet-4-5-20250929")) + }) +} + func TestOpenAIResponses_MissingDependencies_ReturnsServiceUnavailable(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/pkg/apicompat/types.go b/backend/internal/pkg/apicompat/types.go index b724a5ed..b383f867 100644 --- a/backend/internal/pkg/apicompat/types.go +++ b/backend/internal/pkg/apicompat/types.go @@ -28,7 +28,7 @@ type AnthropicRequest struct { // AnthropicOutputConfig controls output generation parameters. type AnthropicOutputConfig struct { - Effort string `json:"effort,omitempty"` // "low" | "medium" | "high" + Effort string `json:"effort,omitempty"` // "low" | "medium" | "high" | "max" } // AnthropicThinking configures extended thinking in the Anthropic API. @@ -167,7 +167,7 @@ type ResponsesRequest struct { // ResponsesReasoning configures reasoning effort in the Responses API. type ResponsesReasoning struct { - Effort string `json:"effort"` // "low" | "medium" | "high" + Effort string `json:"effort"` // "low" | "medium" | "high" | "xhigh" Summary string `json:"summary,omitempty"` // "auto" | "concise" | "detailed" } @@ -345,7 +345,7 @@ type ChatCompletionsRequest struct { StreamOptions *ChatStreamOptions `json:"stream_options,omitempty"` Tools []ChatTool `json:"tools,omitempty"` ToolChoice json.RawMessage `json:"tool_choice,omitempty"` - ReasoningEffort string `json:"reasoning_effort,omitempty"` // "low" | "medium" | "high" + ReasoningEffort string `json:"reasoning_effort,omitempty"` // "low" | "medium" | "high" | "xhigh" ServiceTier string `json:"service_tier,omitempty"` Stop json.RawMessage `json:"stop,omitempty"` // string or []string diff --git a/backend/internal/service/openai_codex_instructions_template.go b/backend/internal/service/openai_codex_instructions_template.go new file mode 100644 index 00000000..5588c73c --- /dev/null +++ b/backend/internal/service/openai_codex_instructions_template.go @@ -0,0 +1,55 @@ +package service + +import ( + "bytes" + "fmt" + "strings" + "text/template" +) + +type forcedCodexInstructionsTemplateData struct { + ExistingInstructions string + OriginalModel string + NormalizedModel string + BillingModel string + UpstreamModel string +} + +func applyForcedCodexInstructionsTemplate( + reqBody map[string]any, + templateText string, + data forcedCodexInstructionsTemplateData, +) (bool, error) { + rendered, err := renderForcedCodexInstructionsTemplate(templateText, data) + if err != nil { + return false, err + } + if rendered == "" { + return false, nil + } + + existing, _ := reqBody["instructions"].(string) + if strings.TrimSpace(existing) == rendered { + return false, nil + } + + reqBody["instructions"] = rendered + return true, nil +} + +func renderForcedCodexInstructionsTemplate( + templateText string, + data forcedCodexInstructionsTemplateData, +) (string, error) { + tmpl, err := template.New("forced_codex_instructions").Option("missingkey=zero").Parse(templateText) + if err != nil { + return "", fmt.Errorf("parse forced codex instructions template: %w", err) + } + + var buf bytes.Buffer + if err := tmpl.Execute(&buf, data); err != nil { + return "", fmt.Errorf("render forced codex instructions template: %w", err) + } + + return strings.TrimSpace(buf.String()), nil +} diff --git a/backend/internal/service/openai_compat_model_test.go b/backend/internal/service/openai_compat_model_test.go index 32c646d4..4396c15f 100644 --- a/backend/internal/service/openai_compat_model_test.go +++ b/backend/internal/service/openai_compat_model_test.go @@ -6,9 +6,12 @@ import ( "io" "net/http" "net/http/httptest" + "os" + "path/filepath" "strings" "testing" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" @@ -127,3 +130,101 @@ func TestForwardAsAnthropic_NormalizesRoutingAndEffortForGpt54XHigh(t *testing.T t.Logf("upstream body: %s", string(upstream.lastBody)) t.Logf("response body: %s", rec.Body.String()) } + +func TestForwardAsAnthropic_ForcedCodexInstructionsTemplatePrependsRenderedInstructions(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + templateDir := t.TempDir() + templatePath := filepath.Join(templateDir, "codex-instructions.md.tmpl") + require.NoError(t, os.WriteFile(templatePath, []byte("server-prefix\n\n{{ .ExistingInstructions }}"), 0o644)) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt-5.4","max_tokens":16,"system":"client-system","messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_forced"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ + ForcedCodexInstructionsTemplateFile: templatePath, + ForcedCodexInstructionsTemplate: "server-prefix\n\n{{ .ExistingInstructions }}", + }}, + httpUpstream: upstream, + } + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "server-prefix\n\nclient-system", gjson.GetBytes(upstream.lastBody, "instructions").String()) +} + +func TestForwardAsAnthropic_ForcedCodexInstructionsTemplateUsesCachedTemplateContent(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt-5.4","max_tokens":16,"system":"client-system","messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_forced_cached"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ + ForcedCodexInstructionsTemplateFile: "/path/that/should/not/be/read.tmpl", + ForcedCodexInstructionsTemplate: "cached-prefix\n\n{{ .ExistingInstructions }}", + }}, + httpUpstream: upstream, + } + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + }, + } + + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "cached-prefix\n\nclient-system", gjson.GetBytes(upstream.lastBody, "instructions").String()) +} diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go index 6f53928b..7a4862d3 100644 --- a/backend/internal/service/openai_gateway_messages.go +++ b/backend/internal/service/openai_gateway_messages.go @@ -86,6 +86,24 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( return nil, fmt.Errorf("unmarshal for codex transform: %w", err) } codexResult := applyCodexOAuthTransform(reqBody, false, false) + forcedTemplateText := "" + if s.cfg != nil { + forcedTemplateText = s.cfg.Gateway.ForcedCodexInstructionsTemplate + } + templateUpstreamModel := upstreamModel + if codexResult.NormalizedModel != "" { + templateUpstreamModel = codexResult.NormalizedModel + } + existingInstructions, _ := reqBody["instructions"].(string) + if _, err := applyForcedCodexInstructionsTemplate(reqBody, forcedTemplateText, forcedCodexInstructionsTemplateData{ + ExistingInstructions: strings.TrimSpace(existingInstructions), + OriginalModel: originalModel, + NormalizedModel: normalizedModel, + BillingModel: billingModel, + UpstreamModel: templateUpstreamModel, + }); err != nil { + return nil, err + } if codexResult.NormalizedModel != "" { upstreamModel = codexResult.NormalizedModel } diff --git a/deploy/codex-instructions.md.tmpl b/deploy/codex-instructions.md.tmpl new file mode 100644 index 00000000..87ad0a3d --- /dev/null +++ b/deploy/codex-instructions.md.tmpl @@ -0,0 +1,5 @@ +You are Codex, based on GPT-5. You are running as a coding agent in the Codex CLI on a user's computer. + +{{ if .ExistingInstructions }} +{{ .ExistingInstructions }} +{{ end }} diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 45440761..6fd5fb8f 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -202,6 +202,32 @@ gateway: # # 注意:开启后会影响所有客户端的行为(不仅限于 VS Code / Codex CLI),请谨慎开启。 force_codex_cli: false + # Optional: template file used to build the final top-level Codex `instructions`. + # 可选:用于构建最终 Codex 顶层 `instructions` 的模板文件路径。 + # + # This is applied on the `/v1/messages -> Responses/Codex` conversion path, + # after Claude `system` has already been normalized into Codex `instructions`. + # 该模板作用于 `/v1/messages -> Responses/Codex` 转换链路,且发生在 Claude `system` + # 已经被归一化为 Codex `instructions` 之后。 + # + # The template can reference: + # 模板可引用: + # - {{ .ExistingInstructions }} : converted client instructions/system + # - {{ .OriginalModel }} : original requested model + # - {{ .NormalizedModel }} : normalized routing model + # - {{ .BillingModel }} : billing model + # - {{ .UpstreamModel }} : final upstream model + # + # If you want to preserve client system prompts, keep {{ .ExistingInstructions }} + # somewhere in the template. If omitted, the template output fully replaces it. + # 如需保留客户端 system 提示词,请在模板中显式包含 {{ .ExistingInstructions }}。 + # 若省略,则模板输出会完全覆盖它。 + # + # Docker users can mount a host file to /app/data/codex-instructions.md.tmpl + # and point this field there. + # Docker 用户可将宿主机文件挂载到 /app/data/codex-instructions.md.tmpl, + # 然后把本字段指向该路径。 + forced_codex_instructions_template_file: "" # OpenAI 透传模式是否放行客户端超时头(如 x-stainless-timeout) # 默认 false:过滤超时头,降低上游提前断流风险。 openai_passthrough_allow_timeout_headers: false @@ -347,12 +373,6 @@ gateway: # Enable batch load calculation for scheduling # 启用调度批量负载计算 load_batch_enabled: true - # Snapshot bucket MGET chunk size - # 调度快照分桶读取时的 MGET 分块大小 - snapshot_mget_chunk_size: 128 - # Snapshot bucket write chunk size - # 调度快照重建写入时的分块大小 - snapshot_write_chunk_size: 256 # Slot cleanup interval (duration) # 并发槽位清理周期(时间段) slot_cleanup_interval: 30s diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index a0bc1a60..3a714260 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -31,6 +31,10 @@ services: # Optional: Mount custom config.yaml (uncomment and create the file first) # Copy config.example.yaml to config.yaml, modify it, then uncomment: # - ./config.yaml:/app/data/config.yaml + # Optional: Mount a custom Codex instructions template file, then point + # gateway.forced_codex_instructions_template_file at /app/data/codex-instructions.md.tmpl + # in config.yaml. + # - ./codex-instructions.md.tmpl:/app/data/codex-instructions.md.tmpl:ro environment: # ======================================================================= # Auto Setup (REQUIRED for Docker deployment) @@ -146,7 +150,17 @@ services: networks: - sub2api-network healthcheck: - test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health"] + test: + [ + "CMD", + "wget", + "-q", + "-T", + "5", + "-O", + "/dev/null", + "http://localhost:8080/health", + ] interval: 30s timeout: 10s retries: 3 @@ -177,11 +191,17 @@ services: networks: - sub2api-network healthcheck: - test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-sub2api} -d ${POSTGRES_DB:-sub2api}"] + test: + [ + "CMD-SHELL", + "pg_isready -U ${POSTGRES_USER:-sub2api} -d ${POSTGRES_DB:-sub2api}", + ] interval: 10s timeout: 5s retries: 5 start_period: 10s + ports: + - 5432:5432 # 注意:不暴露端口到宿主机,应用通过内部网络连接 # 如需调试,可临时添加:ports: ["127.0.0.1:5433:5432"] @@ -199,12 +219,12 @@ services: volumes: - redis_data:/data command: > - sh -c ' - redis-server - --save 60 1 - --appendonly yes - --appendfsync everysec - ${REDIS_PASSWORD:+--requirepass "$REDIS_PASSWORD"}' + sh -c ' + redis-server + --save 60 1 + --appendonly yes + --appendfsync everysec + ${REDIS_PASSWORD:+--requirepass "$REDIS_PASSWORD"}' environment: - TZ=${TZ:-Asia/Shanghai} # REDISCLI_AUTH is used by redis-cli for authentication (safer than -a flag) @@ -217,7 +237,8 @@ services: timeout: 5s retries: 5 start_period: 5s - + ports: + - 6379:6379 # ============================================================================= # Volumes # ============================================================================= -- GitLab From d765359f4bed38ec1f3535365dd58ea367dfa5be Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Thu, 9 Apr 2026 12:30:06 +0800 Subject: [PATCH 09/52] =?UTF-8?q?test(admin):=20=E5=A2=9E=E5=8A=A0messages?= =?UTF-8?q?=E8=B0=83=E5=BA=A6=E8=A1=A8=E5=8D=95=E7=8A=B6=E6=80=81=E8=BD=AC?= =?UTF-8?q?=E6=8D=A2=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- frontend/src/types/index.ts | 10 ++ .../__tests__/groupsMessagesDispatch.spec.ts | 94 +++++++++++++++++++ .../src/views/admin/groupsMessagesDispatch.ts | 72 ++++++++++++++ 3 files changed, 176 insertions(+) create mode 100644 frontend/src/views/admin/__tests__/groupsMessagesDispatch.spec.ts create mode 100644 frontend/src/views/admin/groupsMessagesDispatch.ts diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 580126c8..9f2c2755 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -366,6 +366,13 @@ export type GroupPlatform = 'anthropic' | 'openai' | 'gemini' | 'antigravity' export type SubscriptionType = 'standard' | 'subscription' +export interface OpenAIMessagesDispatchModelConfig { + opus_mapped_model?: string + sonnet_mapped_model?: string + haiku_mapped_model?: string + exact_model_mappings?: Record +} + export interface Group { id: number name: string @@ -388,6 +395,8 @@ export interface Group { fallback_group_id_on_invalid_request: number | null // OpenAI Messages 调度开关(用户侧需要此字段判断是否展示 Claude Code 教程) allow_messages_dispatch?: boolean + default_mapped_model?: string + messages_dispatch_model_config?: OpenAIMessagesDispatchModelConfig require_oauth_only: boolean require_privacy_set: boolean created_at: string @@ -414,6 +423,7 @@ export interface AdminGroup extends Group { // OpenAI Messages 调度配置(仅 openai 平台使用) default_mapped_model?: string + messages_dispatch_model_config?: OpenAIMessagesDispatchModelConfig // 分组排序 sort_order: number diff --git a/frontend/src/views/admin/__tests__/groupsMessagesDispatch.spec.ts b/frontend/src/views/admin/__tests__/groupsMessagesDispatch.spec.ts new file mode 100644 index 00000000..15f42224 --- /dev/null +++ b/frontend/src/views/admin/__tests__/groupsMessagesDispatch.spec.ts @@ -0,0 +1,94 @@ +import { describe, expect, it } from "vitest"; + +import { + createDefaultMessagesDispatchFormState, + messagesDispatchConfigToFormState, + messagesDispatchFormStateToConfig, + resetMessagesDispatchFormState, +} from "../groupsMessagesDispatch"; + +describe("groupsMessagesDispatch", () => { + it("returns the expected default form state", () => { + expect(createDefaultMessagesDispatchFormState()).toEqual({ + allow_messages_dispatch: false, + opus_mapped_model: "gpt-5.4", + sonnet_mapped_model: "gpt-5.3-codex", + haiku_mapped_model: "gpt-5.4-mini", + exact_model_mappings: [], + }); + }); + + it("sanitizes exact model mapping rows when converting to config", () => { + const config = messagesDispatchFormStateToConfig({ + allow_messages_dispatch: true, + opus_mapped_model: " gpt-5.4 ", + sonnet_mapped_model: "gpt-5.3-codex", + haiku_mapped_model: " gpt-5.4-mini ", + exact_model_mappings: [ + { + claude_model: " claude-sonnet-4-5-20250929 ", + target_model: " gpt-5.2 ", + }, + { claude_model: "", target_model: "gpt-5.4" }, + { claude_model: "claude-opus-4-6", target_model: " " }, + ], + }); + + expect(config).toEqual({ + opus_mapped_model: "gpt-5.4", + sonnet_mapped_model: "gpt-5.3-codex", + haiku_mapped_model: "gpt-5.4-mini", + exact_model_mappings: { + "claude-sonnet-4-5-20250929": "gpt-5.2", + }, + }); + }); + + it("hydrates form state from api config", () => { + expect( + messagesDispatchConfigToFormState({ + opus_mapped_model: "gpt-5.4", + sonnet_mapped_model: "gpt-5.2", + haiku_mapped_model: "gpt-5.4-mini", + exact_model_mappings: { + "claude-opus-4-6": "gpt-5.4", + "claude-haiku-4-5-20251001": "gpt-5.4-mini", + }, + }), + ).toEqual({ + allow_messages_dispatch: false, + opus_mapped_model: "gpt-5.4", + sonnet_mapped_model: "gpt-5.2", + haiku_mapped_model: "gpt-5.4-mini", + exact_model_mappings: [ + { + claude_model: "claude-haiku-4-5-20251001", + target_model: "gpt-5.4-mini", + }, + { claude_model: "claude-opus-4-6", target_model: "gpt-5.4" }, + ], + }); + }); + + it("resets mutable form state when platform switches away from openai", () => { + const state = { + allow_messages_dispatch: true, + opus_mapped_model: "gpt-5.2", + sonnet_mapped_model: "gpt-5.4", + haiku_mapped_model: "gpt-5.1", + exact_model_mappings: [ + { claude_model: "claude-opus-4-6", target_model: "gpt-5.4" }, + ], + }; + + resetMessagesDispatchFormState(state); + + expect(state).toEqual({ + allow_messages_dispatch: false, + opus_mapped_model: "gpt-5.4", + sonnet_mapped_model: "gpt-5.3-codex", + haiku_mapped_model: "gpt-5.4-mini", + exact_model_mappings: [], + }); + }); +}); diff --git a/frontend/src/views/admin/groupsMessagesDispatch.ts b/frontend/src/views/admin/groupsMessagesDispatch.ts new file mode 100644 index 00000000..b367091c --- /dev/null +++ b/frontend/src/views/admin/groupsMessagesDispatch.ts @@ -0,0 +1,72 @@ +import type { OpenAIMessagesDispatchModelConfig } from "@/types"; + +export interface MessagesDispatchMappingRow { + claude_model: string; + target_model: string; +} + +export interface MessagesDispatchFormState { + allow_messages_dispatch: boolean; + opus_mapped_model: string; + sonnet_mapped_model: string; + haiku_mapped_model: string; + exact_model_mappings: MessagesDispatchMappingRow[]; +} + +export function createDefaultMessagesDispatchFormState(): MessagesDispatchFormState { + return { + allow_messages_dispatch: false, + opus_mapped_model: "gpt-5.4", + sonnet_mapped_model: "gpt-5.3-codex", + haiku_mapped_model: "gpt-5.4-mini", + exact_model_mappings: [], + }; +} + +export function messagesDispatchConfigToFormState( + config?: OpenAIMessagesDispatchModelConfig | null, +): MessagesDispatchFormState { + const defaults = createDefaultMessagesDispatchFormState(); + const exactMappings = Object.entries(config?.exact_model_mappings || {}) + .sort(([left], [right]) => left.localeCompare(right)) + .map(([claude_model, target_model]) => ({ claude_model, target_model })); + + return { + allow_messages_dispatch: false, + opus_mapped_model: + config?.opus_mapped_model?.trim() || defaults.opus_mapped_model, + sonnet_mapped_model: + config?.sonnet_mapped_model?.trim() || defaults.sonnet_mapped_model, + haiku_mapped_model: + config?.haiku_mapped_model?.trim() || defaults.haiku_mapped_model, + exact_model_mappings: exactMappings, + }; +} + +export function messagesDispatchFormStateToConfig( + state: MessagesDispatchFormState, +): OpenAIMessagesDispatchModelConfig { + const exactModelMappings = Object.fromEntries( + state.exact_model_mappings + .map((row) => [row.claude_model.trim(), row.target_model.trim()] as const) + .filter(([claudeModel, targetModel]) => claudeModel && targetModel), + ); + + return { + opus_mapped_model: state.opus_mapped_model.trim(), + sonnet_mapped_model: state.sonnet_mapped_model.trim(), + haiku_mapped_model: state.haiku_mapped_model.trim(), + exact_model_mappings: exactModelMappings, + }; +} + +export function resetMessagesDispatchFormState( + target: MessagesDispatchFormState, +): void { + const defaults = createDefaultMessagesDispatchFormState(); + target.allow_messages_dispatch = defaults.allow_messages_dispatch; + target.opus_mapped_model = defaults.opus_mapped_model; + target.sonnet_mapped_model = defaults.sonnet_mapped_model; + target.haiku_mapped_model = defaults.haiku_mapped_model; + target.exact_model_mappings = []; +} -- GitLab From de9b9c9dfb2f7cd67a222551e0ab4765fa15483e Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Thu, 9 Apr 2026 12:30:25 +0800 Subject: [PATCH 10/52] =?UTF-8?q?feat(admin):=20=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E5=88=86=E7=BB=84=20messages=20=E8=B0=83=E5=BA=A6=E6=98=A0?= =?UTF-8?q?=E5=B0=84=E9=85=8D=E7=BD=AE=E7=95=8C=E9=9D=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- frontend/src/views/admin/GroupsView.vue | 2803 ++++++++++++++++------- 1 file changed, 1959 insertions(+), 844 deletions(-) diff --git a/frontend/src/views/admin/GroupsView.vue b/frontend/src/views/admin/GroupsView.vue index 5bfe62c3..99fc7c31 100644 --- a/frontend/src/views/admin/GroupsView.vue +++ b/frontend/src/views/admin/GroupsView.vue @@ -2,7 +2,9 @@ @@ -218,21 +289,23 @@ class="flex flex-col items-center gap-0.5 rounded-lg p-1.5 text-gray-500 transition-colors hover:bg-gray-100 hover:text-primary-600 dark:hover:bg-dark-700 dark:hover:text-primary-400" > - {{ t('common.edit') }} + {{ t("common.edit") }}
@@ -267,9 +340,13 @@ width="normal" @close="closeCreateModal" > - +
- +
- +
- + { + const val = Number((e.target as HTMLSelectElement).value); + if ( + val && + !createForm.copy_accounts_from_group_ids.includes(val) + ) { + createForm.copy_accounts_from_group_ids.push(val); + } + (e.target as HTMLSelectElement).value = ''; } - (e.target as HTMLSelectElement).value = '' - }" + " > - + -

{{ t('admin.groups.copyAccounts.hint') }}

+

{{ t("admin.groups.copyAccounts.hint") }}

- + -

{{ t('admin.groups.rateMultiplierHint') }}

+

{{ t("admin.groups.rateMultiplierHint") }}

-
+
@@ -388,20 +500,32 @@ class="cursor-help text-gray-400 transition-colors hover:text-primary-500 dark:text-gray-500 dark:hover:text-primary-400" /> -
-
-

{{ t('admin.groups.exclusiveTooltip.title') }}

+
+
+

+ {{ t("admin.groups.exclusiveTooltip.title") }} +

- {{ t('admin.groups.exclusiveTooltip.description') }} + {{ t("admin.groups.exclusiveTooltip.description") }}

- {{ t('admin.groups.exclusiveTooltip.example') }} - {{ t('admin.groups.exclusiveTooltip.exampleContent') }} + + {{ t("admin.groups.exclusiveTooltip.example") }} + {{ t("admin.groups.exclusiveTooltip.exampleContent") }}

-
+
@@ -412,18 +536,24 @@ @click="createForm.is_exclusive = !createForm.is_exclusive" :class="[ 'relative inline-flex h-6 w-11 items-center rounded-full transition-colors', - createForm.is_exclusive ? 'bg-primary-500' : 'bg-gray-300 dark:bg-dark-600' + createForm.is_exclusive + ? 'bg-primary-500' + : 'bg-gray-300 dark:bg-dark-600', ]" > - {{ createForm.is_exclusive ? t('admin.groups.exclusive') : t('admin.groups.public') }} + {{ + createForm.is_exclusive + ? t("admin.groups.exclusive") + : t("admin.groups.public") + }}
@@ -431,9 +561,16 @@
- - +

+ {{ t("admin.groups.subscription.typeHint") }} +

@@ -442,7 +579,9 @@ class="space-y-4 border-l-2 border-primary-200 pl-4 dark:border-primary-800" >
- +
- +
- + -
-