Unverified Commit ddf80f5e authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #1799 from IanShaw027/rebuild/auth-identity-foundation

fix(auth,payment,profile): 修复认证身份和支付系统的后续问题
parents 4d0483f5 c048ca80
......@@ -28,6 +28,26 @@ jobs:
working-directory: backend
run: make test-integration
frontend:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- name: Setup pnpm
uses: pnpm/action-setup@v4
with:
version: 9
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: '20'
cache: 'pnpm'
cache-dependency-path: frontend/pnpm-lock.yaml
- name: Install frontend dependencies
working-directory: frontend
run: pnpm install --frozen-lockfile
- name: Frontend typecheck and critical vitest
run: make test-frontend
golangci-lint:
runs-on: ubuntu-latest
steps:
......
.PHONY: build build-backend build-frontend build-datamanagementd test test-backend test-frontend test-datamanagementd secret-scan
.PHONY: build build-backend build-frontend build-datamanagementd test test-backend test-frontend test-frontend-critical test-datamanagementd secret-scan
FRONTEND_CRITICAL_VITEST := \
src/views/auth/__tests__/LinuxDoCallbackView.spec.ts \
src/views/auth/__tests__/WechatCallbackView.spec.ts \
src/views/user/__tests__/PaymentView.spec.ts \
src/views/user/__tests__/PaymentResultView.spec.ts \
src/components/user/profile/__tests__/ProfileInfoCard.spec.ts \
src/views/admin/__tests__/SettingsView.spec.ts
# 一键编译前后端
build: build-backend build-frontend
......@@ -24,6 +32,10 @@ test-backend:
test-frontend:
@pnpm --dir frontend run lint:check
@pnpm --dir frontend run typecheck
@$(MAKE) test-frontend-critical
test-frontend-critical:
@pnpm --dir frontend exec vitest run $(FRONTEND_CRITICAL_VITEST)
test-datamanagementd:
@cd datamanagement && go test ./...
......
......@@ -42,10 +42,18 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
- **Smart Scheduling** - Intelligent account selection with sticky sessions
- **Concurrency Control** - Per-user and per-account concurrency limits
- **Rate Limiting** - Configurable request and token rate limits
- **Built-in Payment System** - Supports EasyPay, Alipay, WeChat Pay, and Stripe for user self-service top-up, no separate payment service needed ([Configuration Guide](docs/PAYMENT.md))
- **Built-in Payment System** - Supports EasyPay, Alipay, WeChat Pay, and Stripe for user self-service top-up, no separate payment service needed ([Payment Setup](#payment))
- **Admin Dashboard** - Web interface for monitoring and management
- **External System Integration** - Embed external systems (e.g. ticketing) via iframe to extend the admin dashboard
## Payment
Sub2API includes the payment system in the main service. No standalone payment service or separate payment guide is required.
- Supported providers: EasyPay, Alipay, WeChat Pay, Stripe
- The frontend keeps user-facing methods unified; admins choose the backing source in `Admin -> Settings -> Payment`
- Callback URLs are generated from the site domain when configuring providers
## ❤️ Sponsors
> [Want to appear here?](mailto:support@pincc.ai)
......@@ -109,7 +117,7 @@ Community projects that extend or integrate with Sub2API:
| Project | Description | Features |
|---------|-------------|----------|
| ~~[Sub2ApiPay](https://github.com/touwaeriol/sub2apipay)~~ | ~~Self-service payment system~~ | **Now Built-in** — Payment is now integrated into Sub2API, no separate deployment needed. See [Payment Configuration Guide](docs/PAYMENT.md) |
| ~~[Sub2ApiPay](https://github.com/touwaeriol/sub2apipay)~~ | ~~Self-service payment system~~ | **Now Built-in** — Payment is now integrated into Sub2API, no separate deployment needed. See [Payment Setup](#payment) |
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | Mobile admin console | Cross-platform app (iOS/Android/Web) for user management, account management, monitoring dashboard, and multi-backend switching; built with Expo + React Native |
## Tech Stack
......
......@@ -41,10 +41,18 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
- **智能调度** - 智能账号选择,支持粘性会话
- **并发控制** - 用户级和账号级并发限制
- **速率限制** - 可配置的请求和 Token 速率限制
- **内置支付系统** - 支持 EasyPay 易支付、支付宝官方、微信官方、Stripe,用户自助充值,无需独立部署支付服务([配置指南](docs/PAYMENT_CN.md))
- **内置支付系统** - 支持 EasyPay 易支付、支付宝官方、微信官方、Stripe,用户自助充值,无需独立部署支付服务([支付说明](#支付))
- **管理后台** - Web 界面进行监控和管理
- **外部系统集成** - 支持通过 iframe 嵌入外部系统(如工单等),扩展管理后台功能
## 支付
Sub2API 已将支付系统集成到主服务中,无需独立支付服务,也不再依赖单独的支付配置文档。
- 支持服务商:EasyPay 易支付、支付宝官方、微信官方、Stripe
- 前台统一展示用户可见支付方式,管理员在 `管理后台 -> 设置 -> 支付` 里选择对应来源
- 添加服务商时会基于站点域名生成回调地址
## ❤️ 赞助商
> [想出现在这里?](mailto:support@pincc.ai)
......@@ -108,7 +116,7 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
| 项目 | 说明 | 功能 |
|------|------|------|
| ~~[Sub2ApiPay](https://github.com/touwaeriol/sub2apipay)~~ | ~~自助支付系统~~ | **已内置** — 支付功能已集成到 Sub2API 中,无需独立部署。详见 [支付配置指南](docs/PAYMENT_CN.md) |
| ~~[Sub2ApiPay](https://github.com/touwaeriol/sub2apipay)~~ | ~~自助支付系统~~ | **已内置** — 支付功能已集成到 Sub2API 中,无需独立部署。详见 [支付说明](#支付) |
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | 移动端管理控制台 | 跨平台应用(iOS/Android/Web),支持用户管理、账号管理、监控看板、多后端切换;基于 Expo + React Native 构建 |
## 技术栈
......
......@@ -42,10 +42,18 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
- **スマートスケジューリング** - スティッキーセッション付きのインテリジェントなアカウント選択
- **同時実行制御** - ユーザーごと・アカウントごとの同時実行数制限
- **レート制限** - 設定可能なリクエスト数およびトークンレート制限
- **内蔵決済システム** - EasyPay、Alipay、WeChat Pay、Stripe に対応。ユーザーのセルフサービスチャージが可能で、別途決済サービスのデプロイは不要([設定ガイド](docs/PAYMENT.md)
- **内蔵決済システム** - EasyPay、Alipay、WeChat Pay、Stripe に対応。ユーザーのセルフサービスチャージが可能で、別途決済サービスのデプロイは不要([決済案内](#決済)
- **管理ダッシュボード** - 監視・管理のための Web インターフェース
- **外部システム連携** - 外部システム(チケット管理など)を iframe 経由で管理ダッシュボードに埋め込み可能
## 決済
Sub2API の決済機能は本体に統合されています。独立した決済サービスや別個の決済ガイドは不要です。
- 対応プロバイダー: EasyPay、Alipay、WeChat Pay、Stripe
- フロントエンドではユーザー向け決済方法を統一表示し、管理者は `管理画面 -> 設定 -> 決済` で実際の接続先を選択します
- プロバイダー設定時のコールバック URL はサイトドメインから自動生成されます
## ❤️ スポンサー
> [こちらに掲載しませんか?](mailto:support@pincc.ai)
......@@ -108,7 +116,7 @@ Sub2API を拡張・統合するコミュニティプロジェクト:
| プロジェクト | 説明 | 機能 |
|---------|-------------|----------|
| ~~[Sub2ApiPay](https://github.com/touwaeriol/sub2apipay)~~ | ~~セルフサービス決済システム~~ | **内蔵済み** — 決済機能は Sub2API に統合されました。別途デプロイは不要です。[決済設定ガイド](docs/PAYMENT.md)をご参照ください |
| ~~[Sub2ApiPay](https://github.com/touwaeriol/sub2apipay)~~ | ~~セルフサービス決済システム~~ | **内蔵済み** — 決済機能は Sub2API に統合されました。別途デプロイは不要です。[決済案内](#決済)をご参照ください |
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | モバイル管理コンソール | ユーザー管理、アカウント管理、監視ダッシュボード、マルチバックエンド切り替えが可能なクロスプラットフォームアプリ(iOS/Android/Web)。Expo + React Native で構築 |
## 技術スタック
......
package migrate
import (
"testing"
"entgo.io/ent/dialect/entsql"
entschema "entgo.io/ent/dialect/sql/schema"
"github.com/stretchr/testify/require"
)
func TestAuthIdentityFoundationForeignKeyOnDeleteActions(t *testing.T) {
require.Equal(
t,
entschema.Cascade,
findForeignKeyBySymbol(t, AuthIdentitiesTable, "auth_identities_users_auth_identities").OnDelete,
)
require.Equal(
t,
entschema.Cascade,
findForeignKeyBySymbol(t, AuthIdentityChannelsTable, "auth_identity_channels_auth_identities_channels").OnDelete,
)
require.Equal(
t,
entschema.Cascade,
findForeignKeyBySymbol(t, IdentityAdoptionDecisionsTable, "identity_adoption_decisions_pending_auth_sessions_adoption_decision").OnDelete,
)
require.Equal(
t,
entschema.SetNull,
findForeignKeyBySymbol(t, PendingAuthSessionsTable, "pending_auth_sessions_users_pending_auth_sessions").OnDelete,
)
require.Equal(
t,
entschema.SetNull,
findForeignKeyBySymbol(t, IdentityAdoptionDecisionsTable, "identity_adoption_decisions_auth_identities_adoption_decisions").OnDelete,
)
}
func TestPaymentOrdersOutTradeNoPartialUniqueIndex(t *testing.T) {
idx := findIndexByName(t, PaymentOrdersTable, "paymentorder_out_trade_no")
require.True(t, idx.Unique)
require.Len(t, idx.Columns, 1)
require.Equal(t, "out_trade_no", idx.Columns[0].Name)
require.NotNil(t, idx.Annotation)
require.Equal(t, (&entsql.IndexAnnotation{Where: "out_trade_no <> ''"}).Where, idx.Annotation.Where)
}
func findForeignKeyBySymbol(t *testing.T, table *entschema.Table, symbol string) *entschema.ForeignKey {
t.Helper()
for _, fk := range table.ForeignKeys {
if fk.Symbol == symbol {
return fk
}
}
require.Failf(t, "missing foreign key", "table %s should include foreign key %s", table.Name, symbol)
return nil
}
func findIndexByName(t *testing.T, table *entschema.Table, name string) *entschema.Index {
t.Helper()
for _, idx := range table.Indexes {
if idx.Name == name {
return idx
}
}
require.Failf(t, "missing index", "table %s should include index %s", table.Name, name)
return nil
}
......@@ -361,7 +361,7 @@ var (
Symbol: "auth_identities_users_auth_identities",
Columns: []*schema.Column{AuthIdentitiesColumns[9]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
OnDelete: schema.Cascade,
},
},
Indexes: []*schema.Index{
......@@ -405,7 +405,7 @@ var (
Symbol: "auth_identity_channels_auth_identities_channels",
Columns: []*schema.Column{AuthIdentityChannelsColumns[9]},
RefColumns: []*schema.Column{AuthIdentitiesColumns[0]},
OnDelete: schema.NoAction,
OnDelete: schema.Cascade,
},
},
Indexes: []*schema.Index{
......@@ -595,7 +595,7 @@ var (
Symbol: "identity_adoption_decisions_pending_auth_sessions_adoption_decision",
Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[7]},
RefColumns: []*schema.Column{PendingAuthSessionsColumns[0]},
OnDelete: schema.NoAction,
OnDelete: schema.Cascade,
},
},
Indexes: []*schema.Index{
......@@ -692,8 +692,11 @@ var (
Indexes: []*schema.Index{
{
Name: "paymentorder_out_trade_no",
Unique: false,
Unique: true,
Columns: []*schema.Column{PaymentOrdersColumns[8]},
Annotation: &entsql.IndexAnnotation{
Where: "out_trade_no <> ''",
},
},
{
Name: "paymentorder_user_id",
......
......@@ -79,7 +79,8 @@ func (AuthIdentity) Edges() []ent.Edge {
Field("user_id").
Required().
Unique(),
edge.To("channels", AuthIdentityChannel.Type),
edge.To("channels", AuthIdentityChannel.Type).
Annotations(entsql.OnDelete(entsql.Cascade)),
edge.To("adoption_decisions", IdentityAdoptionDecision.Type),
}
}
......
......@@ -3,7 +3,9 @@ package schema
import (
"testing"
"entgo.io/ent"
"entgo.io/ent/entc/load"
"entgo.io/ent/schema/field"
"github.com/stretchr/testify/require"
)
......@@ -74,6 +76,17 @@ func TestAuthIdentityFoundationSchemas(t *testing.T) {
userSchema := requireSchema(t, schemas, "User")
requireSchemaFields(t, userSchema, "signup_source", "last_login_at", "last_active_at")
signupSource := requireSchemaField(t, userSchema, "signup_source")
require.Equal(t, field.TypeString, signupSource.Info.Type)
require.True(t, signupSource.Default)
require.Equal(t, "email", signupSource.DefaultValue)
require.Equal(t, 1, signupSource.Validators)
validator := requireStringFieldValidator(t, User{}.Fields(), "signup_source")
for _, value := range []string{"email", "linuxdo", "wechat", "oidc"} {
require.NoError(t, validator(value))
}
require.Error(t, validator("github"))
}
func requireSchema(t *testing.T, schemas map[string]*load.Schema, name string) *load.Schema {
......@@ -98,6 +111,37 @@ func requireSchemaFields(t *testing.T, schema *load.Schema, names ...string) {
}
}
func requireSchemaField(t *testing.T, schema *load.Schema, name string) *load.Field {
t.Helper()
for _, schemaField := range schema.Fields {
if schemaField.Name == name {
return schemaField
}
}
require.Failf(t, "missing schema field", "schema %s should include field %s", schema.Name, name)
return nil
}
func requireStringFieldValidator(t *testing.T, fields []ent.Field, name string) func(string) error {
t.Helper()
for _, entField := range fields {
descriptor := entField.Descriptor()
if descriptor.Name != name {
continue
}
require.NotEmpty(t, descriptor.Validators, "field %s should include a validator", name)
validator, ok := descriptor.Validators[0].(func(string) error)
require.True(t, ok, "field %s validator should be func(string) error", name)
return validator
}
require.Failf(t, "missing field validator", "schema should include field %s", name)
return nil
}
func requireHasUniqueIndex(t *testing.T, schema *load.Schema, fields ...string) {
t.Helper()
......
......@@ -185,7 +185,9 @@ func (PaymentOrder) Edges() []ent.Edge {
func (PaymentOrder) Indexes() []ent.Index {
return []ent.Index{
index.Fields("out_trade_no"),
index.Fields("out_trade_no").
Unique().
Annotations(entsql.IndexWhere("out_trade_no <> ''")),
index.Fields("user_id"),
index.Fields("status"),
index.Fields("expires_at"),
......
......@@ -119,6 +119,7 @@ func (PendingAuthSession) Edges() []ent.Edge {
Field("target_user_id").
Unique(),
edge.To("adoption_decision", IdentityAdoptionDecision.Type).
Annotations(entsql.OnDelete(entsql.Cascade)).
Unique(),
}
}
......
package schema
import (
"fmt"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/internal/domain"
......@@ -73,7 +75,14 @@ func (User) Fields() []ent.Field {
Optional().
Nillable(),
field.String("signup_source").
MaxLen(20).
Validate(func(value string) error {
switch value {
case "email", "linuxdo", "wechat", "oidc":
return nil
default:
return fmt.Errorf("must be one of email, linuxdo, wechat, oidc")
}
}).
Default("email"),
field.Time("last_login_at").
Optional().
......@@ -115,7 +124,8 @@ func (User) Edges() []ent.Edge {
edge.To("attribute_values", UserAttributeValue.Type),
edge.To("promo_code_usages", PromoCodeUsage.Type),
edge.To("payment_orders", PaymentOrder.Type),
edge.To("auth_identities", AuthIdentity.Type),
edge.To("auth_identities", AuthIdentity.Type).
Annotations(entsql.OnDelete(entsql.Cascade)),
edge.To("pending_auth_sessions", PendingAuthSession.Type),
}
}
......
......@@ -70,6 +70,7 @@ type Config struct {
JWT JWTConfig `mapstructure:"jwt"`
Totp TotpConfig `mapstructure:"totp"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
WeChat WeChatConnectConfig `mapstructure:"wechat_connect"`
OIDC OIDCConnectConfig `mapstructure:"oidc_connect"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
......@@ -190,6 +191,25 @@ type LinuxDoConnectConfig struct {
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
}
type WeChatConnectConfig struct {
Enabled bool `mapstructure:"enabled"`
AppID string `mapstructure:"app_id"`
AppSecret string `mapstructure:"app_secret"`
OpenAppID string `mapstructure:"open_app_id"`
OpenAppSecret string `mapstructure:"open_app_secret"`
MPAppID string `mapstructure:"mp_app_id"`
MPAppSecret string `mapstructure:"mp_app_secret"`
MobileAppID string `mapstructure:"mobile_app_id"`
MobileAppSecret string `mapstructure:"mobile_app_secret"`
OpenEnabled bool `mapstructure:"open_enabled"`
MPEnabled bool `mapstructure:"mp_enabled"`
MobileEnabled bool `mapstructure:"mobile_enabled"`
Mode string `mapstructure:"mode"`
Scopes string `mapstructure:"scopes"`
RedirectURL string `mapstructure:"redirect_url"`
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"`
}
type OIDCConnectConfig struct {
Enabled bool `mapstructure:"enabled"`
ProviderName string `mapstructure:"provider_name"` // 显示名: "Keycloak" 等
......@@ -207,6 +227,8 @@ type OIDCConnectConfig struct {
TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none
UsePKCE bool `mapstructure:"use_pkce"`
ValidateIDToken bool `mapstructure:"validate_id_token"`
UsePKCEExplicit bool `mapstructure:"-" yaml:"-"`
ValidateIDTokenExplicit bool `mapstructure:"-" yaml:"-"`
AllowedSigningAlgs string `mapstructure:"allowed_signing_algs"` // 默认 "RS256,ES256,PS256"
ClockSkewSeconds int `mapstructure:"clock_skew_seconds"` // 默认 120
RequireEmailVerified bool `mapstructure:"require_email_verified"` // 默认 false
......@@ -218,6 +240,225 @@ type OIDCConnectConfig struct {
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
}
const (
defaultWeChatConnectMode = "open"
defaultWeChatConnectScopes = "snsapi_login"
defaultWeChatConnectFrontendRedirect = "/auth/wechat/callback"
)
func firstNonEmptyString(values ...string) string {
for _, value := range values {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return trimmed
}
}
return ""
}
func normalizeWeChatConnectMode(raw string) string {
switch strings.ToLower(strings.TrimSpace(raw)) {
case "mp":
return "mp"
case "mobile":
return "mobile"
default:
return defaultWeChatConnectMode
}
}
func normalizeWeChatConnectStoredMode(openEnabled, mpEnabled, mobileEnabled bool, mode string) string {
mode = normalizeWeChatConnectMode(mode)
switch mode {
case "open":
if openEnabled {
return "open"
}
case "mp":
if mpEnabled {
return "mp"
}
case "mobile":
if mobileEnabled {
return "mobile"
}
}
switch {
case openEnabled:
return "open"
case mpEnabled:
return "mp"
case mobileEnabled:
return "mobile"
default:
return mode
}
}
func defaultWeChatConnectScopesForMode(mode string) string {
switch normalizeWeChatConnectMode(mode) {
case "mp":
return "snsapi_userinfo"
case "mobile":
return ""
default:
return defaultWeChatConnectScopes
}
}
func normalizeWeChatConnectScopes(raw, mode string) string {
switch normalizeWeChatConnectMode(mode) {
case "mp":
switch strings.TrimSpace(raw) {
case "snsapi_base":
return "snsapi_base"
case "snsapi_userinfo":
return "snsapi_userinfo"
default:
return defaultWeChatConnectScopesForMode(mode)
}
case "mobile":
return ""
default:
return defaultWeChatConnectScopes
}
}
func shouldApplyLegacyWeChatEnv(configKey, envKey string) bool {
if viper.InConfig(configKey) {
return false
}
_, hasNewEnv := os.LookupEnv(envKey)
return !hasNewEnv
}
func hasExplicitConfigOrEnv(configKey, envKey string) bool {
if viper.InConfig(configKey) {
return true
}
_, ok := os.LookupEnv(envKey)
return ok
}
func applyLegacyWeChatConnectEnvCompatibility(cfg *WeChatConnectConfig) {
if cfg == nil {
return
}
legacyOpenAppID := ""
if shouldApplyLegacyWeChatEnv("wechat_connect.open_app_id", "WECHAT_CONNECT_OPEN_APP_ID") &&
shouldApplyLegacyWeChatEnv("wechat_connect.app_id", "WECHAT_CONNECT_APP_ID") {
legacyOpenAppID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID"))
if legacyOpenAppID != "" {
cfg.OpenAppID = legacyOpenAppID
}
}
legacyOpenAppSecret := ""
if shouldApplyLegacyWeChatEnv("wechat_connect.open_app_secret", "WECHAT_CONNECT_OPEN_APP_SECRET") &&
shouldApplyLegacyWeChatEnv("wechat_connect.app_secret", "WECHAT_CONNECT_APP_SECRET") {
legacyOpenAppSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET"))
if legacyOpenAppSecret != "" {
cfg.OpenAppSecret = legacyOpenAppSecret
}
}
legacyMPAppID := ""
if shouldApplyLegacyWeChatEnv("wechat_connect.mp_app_id", "WECHAT_CONNECT_MP_APP_ID") &&
shouldApplyLegacyWeChatEnv("wechat_connect.app_id", "WECHAT_CONNECT_APP_ID") {
legacyMPAppID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID"))
if legacyMPAppID != "" {
cfg.MPAppID = legacyMPAppID
}
}
legacyMPAppSecret := ""
if shouldApplyLegacyWeChatEnv("wechat_connect.mp_app_secret", "WECHAT_CONNECT_MP_APP_SECRET") &&
shouldApplyLegacyWeChatEnv("wechat_connect.app_secret", "WECHAT_CONNECT_APP_SECRET") {
legacyMPAppSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET"))
if legacyMPAppSecret != "" {
cfg.MPAppSecret = legacyMPAppSecret
}
}
if shouldApplyLegacyWeChatEnv("wechat_connect.frontend_redirect_url", "WECHAT_CONNECT_FRONTEND_REDIRECT_URL") {
if legacyFrontend := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL")); legacyFrontend != "" {
cfg.FrontendRedirectURL = legacyFrontend
}
}
hasLegacyOpen := legacyOpenAppID != "" && legacyOpenAppSecret != ""
hasLegacyMP := legacyMPAppID != "" && legacyMPAppSecret != ""
if shouldApplyLegacyWeChatEnv("wechat_connect.enabled", "WECHAT_CONNECT_ENABLED") && (hasLegacyOpen || hasLegacyMP) {
cfg.Enabled = true
}
if shouldApplyLegacyWeChatEnv("wechat_connect.open_enabled", "WECHAT_CONNECT_OPEN_ENABLED") && hasLegacyOpen {
cfg.OpenEnabled = true
}
if shouldApplyLegacyWeChatEnv("wechat_connect.mp_enabled", "WECHAT_CONNECT_MP_ENABLED") && hasLegacyMP {
cfg.MPEnabled = true
}
if shouldApplyLegacyWeChatEnv("wechat_connect.mode", "WECHAT_CONNECT_MODE") {
switch {
case hasLegacyMP && !hasLegacyOpen:
cfg.Mode = "mp"
case hasLegacyOpen:
cfg.Mode = "open"
}
}
if shouldApplyLegacyWeChatEnv("wechat_connect.scopes", "WECHAT_CONNECT_SCOPES") {
switch {
case hasLegacyMP && !hasLegacyOpen:
cfg.Scopes = defaultWeChatConnectScopesForMode("mp")
case hasLegacyOpen:
cfg.Scopes = defaultWeChatConnectScopesForMode("open")
}
}
}
func normalizeWeChatConnectConfig(cfg *WeChatConnectConfig) {
if cfg == nil {
return
}
cfg.AppID = strings.TrimSpace(cfg.AppID)
cfg.AppSecret = strings.TrimSpace(cfg.AppSecret)
cfg.OpenAppID = strings.TrimSpace(cfg.OpenAppID)
cfg.OpenAppSecret = strings.TrimSpace(cfg.OpenAppSecret)
cfg.MPAppID = strings.TrimSpace(cfg.MPAppID)
cfg.MPAppSecret = strings.TrimSpace(cfg.MPAppSecret)
cfg.MobileAppID = strings.TrimSpace(cfg.MobileAppID)
cfg.MobileAppSecret = strings.TrimSpace(cfg.MobileAppSecret)
cfg.Mode = normalizeWeChatConnectMode(cfg.Mode)
cfg.RedirectURL = strings.TrimSpace(cfg.RedirectURL)
cfg.FrontendRedirectURL = strings.TrimSpace(cfg.FrontendRedirectURL)
cfg.AppID = firstNonEmptyString(cfg.AppID, cfg.OpenAppID, cfg.MPAppID, cfg.MobileAppID)
cfg.AppSecret = firstNonEmptyString(cfg.AppSecret, cfg.OpenAppSecret, cfg.MPAppSecret, cfg.MobileAppSecret)
cfg.OpenAppID = firstNonEmptyString(cfg.OpenAppID, cfg.AppID)
cfg.OpenAppSecret = firstNonEmptyString(cfg.OpenAppSecret, cfg.AppSecret)
cfg.MPAppID = firstNonEmptyString(cfg.MPAppID, cfg.AppID)
cfg.MPAppSecret = firstNonEmptyString(cfg.MPAppSecret, cfg.AppSecret)
cfg.MobileAppID = firstNonEmptyString(cfg.MobileAppID, cfg.AppID)
cfg.MobileAppSecret = firstNonEmptyString(cfg.MobileAppSecret, cfg.AppSecret)
if !cfg.OpenEnabled && !cfg.MPEnabled && !cfg.MobileEnabled && cfg.Enabled {
switch cfg.Mode {
case "mp":
cfg.MPEnabled = true
case "mobile":
cfg.MobileEnabled = true
default:
cfg.OpenEnabled = true
}
}
cfg.Mode = normalizeWeChatConnectStoredMode(cfg.OpenEnabled, cfg.MPEnabled, cfg.MobileEnabled, cfg.Mode)
cfg.Scopes = normalizeWeChatConnectScopes(cfg.Scopes, cfg.Mode)
if cfg.FrontendRedirectURL == "" {
cfg.FrontendRedirectURL = defaultWeChatConnectFrontendRedirect
}
}
// TokenRefreshConfig OAuth token自动刷新配置
type TokenRefreshConfig struct {
// 是否启用自动刷新
......@@ -1012,6 +1253,8 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath)
cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath)
cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath)
applyLegacyWeChatConnectEnvCompatibility(&cfg.WeChat)
normalizeWeChatConnectConfig(&cfg.WeChat)
cfg.OIDC.ProviderName = strings.TrimSpace(cfg.OIDC.ProviderName)
cfg.OIDC.ClientID = strings.TrimSpace(cfg.OIDC.ClientID)
cfg.OIDC.ClientSecret = strings.TrimSpace(cfg.OIDC.ClientSecret)
......@@ -1029,6 +1272,8 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
cfg.OIDC.UserInfoEmailPath = strings.TrimSpace(cfg.OIDC.UserInfoEmailPath)
cfg.OIDC.UserInfoIDPath = strings.TrimSpace(cfg.OIDC.UserInfoIDPath)
cfg.OIDC.UserInfoUsernamePath = strings.TrimSpace(cfg.OIDC.UserInfoUsernamePath)
cfg.OIDC.UsePKCEExplicit = hasExplicitConfigOrEnv("oidc_connect.use_pkce", "OIDC_CONNECT_USE_PKCE")
cfg.OIDC.ValidateIDTokenExplicit = hasExplicitConfigOrEnv("oidc_connect.validate_id_token", "OIDC_CONNECT_VALIDATE_ID_TOKEN")
cfg.Dashboard.KeyPrefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix)
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
......@@ -1207,6 +1452,24 @@ func setDefaults() {
viper.SetDefault("linuxdo_connect.userinfo_id_path", "")
viper.SetDefault("linuxdo_connect.userinfo_username_path", "")
// WeChat Connect OAuth 登录
viper.SetDefault("wechat_connect.enabled", false)
viper.SetDefault("wechat_connect.app_id", "")
viper.SetDefault("wechat_connect.app_secret", "")
viper.SetDefault("wechat_connect.open_app_id", "")
viper.SetDefault("wechat_connect.open_app_secret", "")
viper.SetDefault("wechat_connect.mp_app_id", "")
viper.SetDefault("wechat_connect.mp_app_secret", "")
viper.SetDefault("wechat_connect.mobile_app_id", "")
viper.SetDefault("wechat_connect.mobile_app_secret", "")
viper.SetDefault("wechat_connect.open_enabled", false)
viper.SetDefault("wechat_connect.mp_enabled", false)
viper.SetDefault("wechat_connect.mobile_enabled", false)
viper.SetDefault("wechat_connect.mode", defaultWeChatConnectMode)
viper.SetDefault("wechat_connect.scopes", defaultWeChatConnectScopes)
viper.SetDefault("wechat_connect.redirect_url", "")
viper.SetDefault("wechat_connect.frontend_redirect_url", defaultWeChatConnectFrontendRedirect)
// Generic OIDC OAuth 登录
viper.SetDefault("oidc_connect.enabled", false)
viper.SetDefault("oidc_connect.provider_name", "OIDC")
......@@ -1222,7 +1485,7 @@ func setDefaults() {
viper.SetDefault("oidc_connect.redirect_url", "")
viper.SetDefault("oidc_connect.frontend_redirect_url", "/auth/oidc/callback")
viper.SetDefault("oidc_connect.token_auth_method", "client_secret_post")
viper.SetDefault("oidc_connect.use_pkce", false)
viper.SetDefault("oidc_connect.use_pkce", true)
viper.SetDefault("oidc_connect.validate_id_token", true)
viper.SetDefault("oidc_connect.allowed_signing_algs", "RS256,ES256,PS256")
viper.SetDefault("oidc_connect.clock_skew_seconds", 120)
......@@ -1613,9 +1876,6 @@ func (c *Config) Validate() error {
return fmt.Errorf("security.csp.policy is required when CSP is enabled")
}
if c.LinuxDo.Enabled {
if !c.LinuxDo.UsePKCE {
return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.enabled=true")
}
if strings.TrimSpace(c.LinuxDo.ClientID) == "" {
return fmt.Errorf("linuxdo_connect.client_id is required when linuxdo_connect.enabled=true")
}
......@@ -1667,13 +1927,46 @@ func (c *Config) Validate() error {
warnIfInsecureURL("linuxdo_connect.redirect_url", c.LinuxDo.RedirectURL)
warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL)
}
if c.OIDC.Enabled {
if !c.OIDC.UsePKCE {
return fmt.Errorf("oidc_connect.use_pkce must be true when oidc_connect.enabled=true")
if c.WeChat.Enabled {
weChat := c.WeChat
normalizeWeChatConnectConfig(&weChat)
if weChat.OpenEnabled {
if strings.TrimSpace(weChat.OpenAppID) == "" {
return fmt.Errorf("wechat_connect.open_app_id is required when wechat_connect.open_enabled=true")
}
if strings.TrimSpace(weChat.OpenAppSecret) == "" {
return fmt.Errorf("wechat_connect.open_app_secret is required when wechat_connect.open_enabled=true")
}
if !c.OIDC.ValidateIDToken {
return fmt.Errorf("oidc_connect.validate_id_token must be true when oidc_connect.enabled=true")
}
if weChat.MPEnabled {
if strings.TrimSpace(weChat.MPAppID) == "" {
return fmt.Errorf("wechat_connect.mp_app_id is required when wechat_connect.mp_enabled=true")
}
if strings.TrimSpace(weChat.MPAppSecret) == "" {
return fmt.Errorf("wechat_connect.mp_app_secret is required when wechat_connect.mp_enabled=true")
}
}
if weChat.MobileEnabled {
if strings.TrimSpace(weChat.MobileAppID) == "" {
return fmt.Errorf("wechat_connect.mobile_app_id is required when wechat_connect.mobile_enabled=true")
}
if strings.TrimSpace(weChat.MobileAppSecret) == "" {
return fmt.Errorf("wechat_connect.mobile_app_secret is required when wechat_connect.mobile_enabled=true")
}
}
if v := strings.TrimSpace(weChat.RedirectURL); v != "" {
if err := ValidateAbsoluteHTTPURL(v); err != nil {
return fmt.Errorf("wechat_connect.redirect_url invalid: %w", err)
}
warnIfInsecureURL("wechat_connect.redirect_url", v)
}
if err := ValidateFrontendRedirectURL(weChat.FrontendRedirectURL); err != nil {
return fmt.Errorf("wechat_connect.frontend_redirect_url invalid: %w", err)
}
warnIfInsecureURL("wechat_connect.frontend_redirect_url", weChat.FrontendRedirectURL)
}
if c.OIDC.Enabled {
if strings.TrimSpace(c.OIDC.ClientID) == "" {
return fmt.Errorf("oidc_connect.client_id is required when oidc_connect.enabled=true")
}
......
......@@ -225,6 +225,52 @@ func TestLoadSchedulingConfigFromEnv(t *testing.T) {
}
}
func TestLoadWeChatConnectConfigFromLegacyEnv(t *testing.T) {
resetViperWithJWTSecret(t)
t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
t.Setenv("WECHAT_OAUTH_MP_APP_ID", "wx-mp-app")
t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "wx-mp-secret")
t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/legacy-callback")
cfg, err := Load()
require.NoError(t, err)
require.True(t, cfg.WeChat.Enabled)
require.True(t, cfg.WeChat.OpenEnabled)
require.True(t, cfg.WeChat.MPEnabled)
require.False(t, cfg.WeChat.MobileEnabled)
require.Equal(t, "open", cfg.WeChat.Mode)
require.Equal(t, "wx-open-app", cfg.WeChat.OpenAppID)
require.Equal(t, "wx-open-secret", cfg.WeChat.OpenAppSecret)
require.Equal(t, "wx-mp-app", cfg.WeChat.MPAppID)
require.Equal(t, "wx-mp-secret", cfg.WeChat.MPAppSecret)
require.Equal(t, "/auth/wechat/legacy-callback", cfg.WeChat.FrontendRedirectURL)
}
func TestLoadDefaultOIDCSecurityDefaults(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
require.NoError(t, err)
require.True(t, cfg.OIDC.UsePKCE)
require.True(t, cfg.OIDC.ValidateIDToken)
require.False(t, cfg.OIDC.UsePKCEExplicit)
require.False(t, cfg.OIDC.ValidateIDTokenExplicit)
}
func TestLoadExplicitOIDCSecurityDefaultsFromEnvMarksFlagsExplicit(t *testing.T) {
resetViperWithJWTSecret(t)
t.Setenv("OIDC_CONNECT_USE_PKCE", "false")
t.Setenv("OIDC_CONNECT_VALIDATE_ID_TOKEN", "false")
cfg, err := Load()
require.NoError(t, err)
require.False(t, cfg.OIDC.UsePKCE)
require.False(t, cfg.OIDC.ValidateIDToken)
require.True(t, cfg.OIDC.UsePKCEExplicit)
require.True(t, cfg.OIDC.ValidateIDTokenExplicit)
}
func TestLoadForcedCodexInstructionsTemplate(t *testing.T) {
resetViperWithJWTSecret(t)
......@@ -346,7 +392,7 @@ func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
}
}
func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
func TestValidateLinuxDoAllowsDisablingPKCEForCompatibility(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
......@@ -363,11 +409,8 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
cfg.LinuxDo.UsePKCE = false
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error when token_auth_method=none and use_pkce=false, got nil")
}
if !strings.Contains(err.Error(), "linuxdo_connect.use_pkce") {
t.Fatalf("Validate() expected use_pkce error, got: %v", err)
if err != nil {
t.Fatalf("Validate() expected LinuxDo config without PKCE to pass for compatibility, got: %v", err)
}
}
......@@ -427,6 +470,35 @@ func TestValidateOIDCAllowsIssuerOnlyEndpointsWithDiscoveryFallback(t *testing.T
}
}
func TestValidateOIDCAllowsExplicitCompatibilityOverridesForPKCEAndIDTokenValidation(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.OIDC.Enabled = true
cfg.OIDC.ClientID = "oidc-client"
cfg.OIDC.ClientSecret = "oidc-secret"
cfg.OIDC.IssuerURL = "https://issuer.example.com"
cfg.OIDC.AuthorizeURL = "https://issuer.example.com/auth"
cfg.OIDC.TokenURL = "https://issuer.example.com/token"
cfg.OIDC.UserInfoURL = "https://issuer.example.com/userinfo"
cfg.OIDC.RedirectURL = "https://example.com/api/v1/auth/oauth/oidc/callback"
cfg.OIDC.FrontendRedirectURL = "/auth/oidc/callback"
cfg.OIDC.Scopes = "openid email profile"
cfg.OIDC.UsePKCE = false
cfg.OIDC.ValidateIDToken = false
cfg.OIDC.JWKSURL = ""
cfg.OIDC.AllowedSigningAlgs = ""
err = cfg.Validate()
if err != nil {
t.Fatalf("Validate() expected OIDC config without PKCE/id_token validation to pass for compatibility, got: %v", err)
}
}
func TestLoadDefaultDashboardCacheConfig(t *testing.T) {
resetViperWithJWTSecret(t)
......
......@@ -304,8 +304,8 @@ type UpdateSettingsRequest struct {
OIDCConnectRedirectURL string `json:"oidc_connect_redirect_url"`
OIDCConnectFrontendRedirectURL string `json:"oidc_connect_frontend_redirect_url"`
OIDCConnectTokenAuthMethod string `json:"oidc_connect_token_auth_method"`
OIDCConnectUsePKCE bool `json:"oidc_connect_use_pkce"`
OIDCConnectValidateIDToken bool `json:"oidc_connect_validate_id_token"`
OIDCConnectUsePKCE *bool `json:"oidc_connect_use_pkce"`
OIDCConnectValidateIDToken *bool `json:"oidc_connect_validate_id_token"`
OIDCConnectAllowedSigningAlgs string `json:"oidc_connect_allowed_signing_algs"`
OIDCConnectClockSkewSeconds int `json:"oidc_connect_clock_skew_seconds"`
OIDCConnectRequireEmailVerified bool `json:"oidc_connect_require_email_verified"`
......@@ -565,6 +565,15 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req.WeChatConnectScopes = strings.TrimSpace(req.WeChatConnectScopes)
req.WeChatConnectRedirectURL = strings.TrimSpace(req.WeChatConnectRedirectURL)
req.WeChatConnectFrontendRedirectURL = strings.TrimSpace(req.WeChatConnectFrontendRedirectURL)
req.WeChatConnectAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectAppID, previousSettings.WeChatConnectAppID))
req.WeChatConnectRedirectURL = strings.TrimSpace(firstNonEmpty(req.WeChatConnectRedirectURL, previousSettings.WeChatConnectRedirectURL))
req.WeChatConnectFrontendRedirectURL = strings.TrimSpace(firstNonEmpty(req.WeChatConnectFrontendRedirectURL, previousSettings.WeChatConnectFrontendRedirectURL))
if req.WeChatConnectMode == "" {
req.WeChatConnectMode = strings.ToLower(strings.TrimSpace(previousSettings.WeChatConnectMode))
}
if req.WeChatConnectScopes == "" {
req.WeChatConnectScopes = strings.TrimSpace(previousSettings.WeChatConnectScopes)
}
if req.WeChatConnectMPEnabled && req.WeChatConnectMobileEnabled {
response.BadRequest(c, "WeChat Official Account and Mobile App cannot be enabled at the same time")
......@@ -598,9 +607,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
req.WeChatConnectOpenAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectOpenAppID, req.WeChatConnectAppID))
req.WeChatConnectMPAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectMPAppID, req.WeChatConnectAppID))
req.WeChatConnectMobileAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectMobileAppID, req.WeChatConnectAppID))
req.WeChatConnectOpenAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectOpenAppID, req.WeChatConnectAppID, previousSettings.WeChatConnectOpenAppID, previousSettings.WeChatConnectAppID))
req.WeChatConnectMPAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectMPAppID, req.WeChatConnectAppID, previousSettings.WeChatConnectMPAppID, previousSettings.WeChatConnectAppID))
req.WeChatConnectMobileAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectMobileAppID, req.WeChatConnectAppID, previousSettings.WeChatConnectMobileAppID, previousSettings.WeChatConnectAppID))
if req.WeChatConnectOpenAppSecret == "" {
req.WeChatConnectOpenAppSecret = strings.TrimSpace(firstNonEmpty(previousSettings.WeChatConnectOpenAppSecret, previousSettings.WeChatConnectAppSecret, req.WeChatConnectAppSecret))
......@@ -653,8 +662,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req.WeChatConnectScopes = service.DefaultWeChatConnectScopesForMode(req.WeChatConnectMode)
}
}
if req.WeChatConnectOpenEnabled || req.WeChatConnectMPEnabled {
if req.WeChatConnectRedirectURL == "" {
response.BadRequest(c, "WeChat Redirect URL is required when enabled")
response.BadRequest(c, "WeChat Redirect URL is required when web oauth is enabled")
return
}
if err := config.ValidateAbsoluteHTTPURL(req.WeChatConnectRedirectURL); err != nil {
......@@ -669,8 +679,14 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
return
}
}
}
// Generic OIDC 参数验证
oidcUsePKCE, oidcValidateIDToken, err := h.settingService.OIDCSecurityWriteDefaults(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
if req.OIDCConnectEnabled {
req.OIDCConnectProviderName = strings.TrimSpace(req.OIDCConnectProviderName)
req.OIDCConnectClientID = strings.TrimSpace(req.OIDCConnectClientID)
......@@ -689,10 +705,35 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req.OIDCConnectUserInfoEmailPath = strings.TrimSpace(req.OIDCConnectUserInfoEmailPath)
req.OIDCConnectUserInfoIDPath = strings.TrimSpace(req.OIDCConnectUserInfoIDPath)
req.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(req.OIDCConnectUserInfoUsernamePath)
if req.OIDCConnectProviderName == "" {
req.OIDCConnectProviderName = "OIDC"
req.OIDCConnectProviderName = strings.TrimSpace(firstNonEmpty(req.OIDCConnectProviderName, previousSettings.OIDCConnectProviderName, "OIDC"))
req.OIDCConnectClientID = strings.TrimSpace(firstNonEmpty(req.OIDCConnectClientID, previousSettings.OIDCConnectClientID))
req.OIDCConnectIssuerURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectIssuerURL, previousSettings.OIDCConnectIssuerURL))
req.OIDCConnectDiscoveryURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectDiscoveryURL, previousSettings.OIDCConnectDiscoveryURL))
req.OIDCConnectAuthorizeURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectAuthorizeURL, previousSettings.OIDCConnectAuthorizeURL))
req.OIDCConnectTokenURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectTokenURL, previousSettings.OIDCConnectTokenURL))
req.OIDCConnectUserInfoURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoURL, previousSettings.OIDCConnectUserInfoURL))
req.OIDCConnectJWKSURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectJWKSURL, previousSettings.OIDCConnectJWKSURL))
req.OIDCConnectScopes = strings.TrimSpace(firstNonEmpty(req.OIDCConnectScopes, previousSettings.OIDCConnectScopes, "openid email profile"))
req.OIDCConnectRedirectURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectRedirectURL, previousSettings.OIDCConnectRedirectURL))
req.OIDCConnectFrontendRedirectURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectFrontendRedirectURL, previousSettings.OIDCConnectFrontendRedirectURL, "/auth/oidc/callback"))
req.OIDCConnectTokenAuthMethod = strings.ToLower(strings.TrimSpace(firstNonEmpty(req.OIDCConnectTokenAuthMethod, previousSettings.OIDCConnectTokenAuthMethod, "client_secret_post")))
req.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(firstNonEmpty(req.OIDCConnectAllowedSigningAlgs, previousSettings.OIDCConnectAllowedSigningAlgs, "RS256,ES256,PS256"))
req.OIDCConnectUserInfoEmailPath = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoEmailPath, previousSettings.OIDCConnectUserInfoEmailPath))
req.OIDCConnectUserInfoIDPath = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoIDPath, previousSettings.OIDCConnectUserInfoIDPath))
req.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoUsernamePath, previousSettings.OIDCConnectUserInfoUsernamePath))
if req.OIDCConnectUsePKCE != nil {
oidcUsePKCE = *req.OIDCConnectUsePKCE
}
if req.OIDCConnectValidateIDToken != nil {
oidcValidateIDToken = *req.OIDCConnectValidateIDToken
}
if req.OIDCConnectClockSkewSeconds == 0 {
req.OIDCConnectClockSkewSeconds = previousSettings.OIDCConnectClockSkewSeconds
if req.OIDCConnectClockSkewSeconds == 0 {
req.OIDCConnectClockSkewSeconds = 120
}
}
if req.OIDCConnectClientID == "" {
response.BadRequest(c, "OIDC Client ID is required when enabled")
return
......@@ -749,14 +790,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.BadRequest(c, "OIDC scopes must contain openid")
return
}
if !req.OIDCConnectUsePKCE {
response.BadRequest(c, "OIDC PKCE must be enabled")
return
}
if !req.OIDCConnectValidateIDToken {
response.BadRequest(c, "OIDC ID Token validation must be enabled")
return
}
switch req.OIDCConnectTokenAuthMethod {
case "", "client_secret_post", "client_secret_basic", "none":
default:
......@@ -767,7 +800,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.BadRequest(c, "OIDC clock skew seconds must be between 0 and 600")
return
}
if req.OIDCConnectAllowedSigningAlgs == "" {
if oidcValidateIDToken && req.OIDCConnectAllowedSigningAlgs == "" {
response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true")
return
}
......@@ -1048,8 +1081,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
OIDCConnectRedirectURL: req.OIDCConnectRedirectURL,
OIDCConnectFrontendRedirectURL: req.OIDCConnectFrontendRedirectURL,
OIDCConnectTokenAuthMethod: req.OIDCConnectTokenAuthMethod,
OIDCConnectUsePKCE: req.OIDCConnectUsePKCE,
OIDCConnectValidateIDToken: req.OIDCConnectValidateIDToken,
OIDCConnectUsePKCE: oidcUsePKCE,
OIDCConnectValidateIDToken: oidcValidateIDToken,
OIDCConnectAllowedSigningAlgs: req.OIDCConnectAllowedSigningAlgs,
OIDCConnectClockSkewSeconds: req.OIDCConnectClockSkewSeconds,
OIDCConnectRequireEmailVerified: req.OIDCConnectRequireEmailVerified,
......
......@@ -247,6 +247,163 @@ func TestSettingHandler_UpdateSettings_PersistsPaymentVisibleMethodsAndAdvancedS
require.Equal(t, true, data["openai_advanced_scheduler_enabled"])
}
func TestSettingHandler_UpdateSettings_PreservesLegacyBlankPaymentVisibleMethodSource(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &settingHandlerRepoStub{
values: map[string]string{
service.SettingKeyPromoCodeEnabled: "true",
service.SettingPaymentVisibleMethodAlipayEnabled: "true",
service.SettingPaymentVisibleMethodAlipaySource: "",
service.SettingPaymentVisibleMethodWxpayEnabled: "false",
service.SettingPaymentVisibleMethodWxpaySource: "",
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
body := map[string]any{
"promo_code_enabled": false,
}
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "", repo.values[service.SettingPaymentVisibleMethodAlipaySource])
require.Equal(t, "true", repo.values[service.SettingPaymentVisibleMethodAlipayEnabled])
}
func TestSettingHandler_UpdateSettings_PersistsExplicitFalseOIDCCompatibilityFlags(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &settingHandlerRepoStub{
values: map[string]string{
service.SettingKeyPromoCodeEnabled: "true",
service.SettingKeyOIDCConnectEnabled: "true",
service.SettingKeyOIDCConnectProviderName: "OIDC",
service.SettingKeyOIDCConnectClientID: "oidc-client",
service.SettingKeyOIDCConnectClientSecret: "oidc-secret",
service.SettingKeyOIDCConnectIssuerURL: "https://issuer.example.com",
service.SettingKeyOIDCConnectAuthorizeURL: "https://issuer.example.com/auth",
service.SettingKeyOIDCConnectTokenURL: "https://issuer.example.com/token",
service.SettingKeyOIDCConnectUserInfoURL: "https://issuer.example.com/userinfo",
service.SettingKeyOIDCConnectJWKSURL: "https://issuer.example.com/jwks",
service.SettingKeyOIDCConnectScopes: "openid email profile",
service.SettingKeyOIDCConnectRedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
service.SettingKeyOIDCConnectUsePKCE: "true",
service.SettingKeyOIDCConnectValidateIDToken: "true",
service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256",
service.SettingKeyOIDCConnectClockSkewSeconds: "120",
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
body := map[string]any{
"promo_code_enabled": true,
"oidc_connect_enabled": true,
"oidc_connect_use_pkce": false,
"oidc_connect_validate_id_token": false,
"oidc_connect_allowed_signing_algs": "",
}
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectUsePKCE])
require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectValidateIDToken])
var resp response.Response
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
data, ok := resp.Data.(map[string]any)
require.True(t, ok)
require.Equal(t, false, data["oidc_connect_use_pkce"])
require.Equal(t, false, data["oidc_connect_validate_id_token"])
}
func TestSettingHandler_UpdateSettings_DoesNotSolidifyImplicitOIDCSecurityDefaultsOnLegacyUpgrade(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &settingHandlerRepoStub{
values: map[string]string{
service.SettingKeyPromoCodeEnabled: "true",
service.SettingKeyOIDCConnectEnabled: "true",
service.SettingKeyOIDCConnectProviderName: "OIDC",
service.SettingKeyOIDCConnectClientID: "oidc-client",
service.SettingKeyOIDCConnectClientSecret: "oidc-secret",
service.SettingKeyOIDCConnectIssuerURL: "https://issuer.example.com",
service.SettingKeyOIDCConnectAuthorizeURL: "https://issuer.example.com/auth",
service.SettingKeyOIDCConnectTokenURL: "https://issuer.example.com/token",
service.SettingKeyOIDCConnectUserInfoURL: "https://issuer.example.com/userinfo",
service.SettingKeyOIDCConnectJWKSURL: "https://issuer.example.com/jwks",
service.SettingKeyOIDCConnectScopes: "openid email profile",
service.SettingKeyOIDCConnectRedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256",
service.SettingKeyOIDCConnectClockSkewSeconds: "120",
service.SettingKeyOIDCConnectRequireEmailVerified: "false",
service.SettingKeyOIDCConnectUserInfoEmailPath: "",
service.SettingKeyOIDCConnectUserInfoIDPath: "",
service.SettingKeyOIDCConnectUserInfoUsernamePath: "",
},
}
svc := service.NewSettingService(repo, &config.Config{
Default: config.DefaultConfig{UserConcurrency: 5},
OIDC: config.OIDCConnectConfig{
Enabled: true,
ProviderName: "OIDC",
ClientID: "oidc-client",
ClientSecret: "oidc-secret",
IssuerURL: "https://issuer.example.com",
AuthorizeURL: "https://issuer.example.com/auth",
TokenURL: "https://issuer.example.com/token",
UserInfoURL: "https://issuer.example.com/userinfo",
JWKSURL: "https://issuer.example.com/jwks",
Scopes: "openid email profile",
RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
FrontendRedirectURL: "/auth/oidc/callback",
TokenAuthMethod: "client_secret_post",
UsePKCE: true,
ValidateIDToken: true,
AllowedSigningAlgs: "RS256",
ClockSkewSeconds: 120,
},
})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
body := map[string]any{
"promo_code_enabled": true,
"oidc_connect_enabled": true,
}
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectUsePKCE])
require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectValidateIDToken])
}
func TestSettingHandler_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &settingHandlerRepoStub{
......
......@@ -37,6 +37,7 @@ func TestAuthHandlerGetCurrentUserReturnsProfileCompatibilityFields(t *testing.T
VerifiedAt: &verifiedAt,
Metadata: map[string]any{
"username": "linuxdo-handle",
"avatar_url": "https://cdn.example.com/linuxdo.png",
},
},
},
......
......@@ -78,9 +78,24 @@ type AuthResponse struct {
User *dto.User `json:"user"`
}
func ensureLoginUserActive(user *service.User) error {
if user == nil {
return infraerrors.Unauthorized("INVALID_USER", "user not found")
}
if !user.IsActive() {
return service.ErrUserNotActive
}
return nil
}
// respondWithTokenPair 生成 Token 对并返回认证响应
// 如果 Token 对生成失败,回退到只返回 Access Token(向后兼容)
func (h *AuthHandler) respondWithTokenPair(c *gin.Context, user *service.User) {
if err := ensureLoginUserActive(user); err != nil {
response.ErrorFrom(c, err)
return
}
tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "")
if err != nil {
slog.Error("failed to generate token pair", "error", err, "user_id", user.ID)
......@@ -293,6 +308,10 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
if err := ensureLoginUserActive(user); err != nil {
response.ErrorFrom(c, err)
return
}
if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
response.ErrorFrom(c, err)
......@@ -678,6 +697,8 @@ func (h *AuthHandler) Logout(c *gin.Context) {
// 不影响登出流程
}
}
h.consumePendingOAuthSessionOnLogout(c)
clearOAuthLogoutCookies(c)
response.Success(c, LogoutResponse{
Message: "Logged out successfully",
......@@ -698,7 +719,7 @@ func (h *AuthHandler) RevokeAllSessions(c *gin.Context) {
return
}
if err := h.authService.RevokeAllUserSessions(c.Request.Context(), subject.UserID); err != nil {
if err := h.authService.RevokeAllUserTokens(c.Request.Context(), subject.UserID); err != nil {
slog.Error("failed to revoke all sessions", "user_id", subject.UserID, "error", err)
response.InternalError(c, "Failed to revoke sessions")
return
......
......@@ -123,13 +123,16 @@ func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) {
clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie)
}
codeChallenge := ""
if cfg.UsePKCE {
verifier, err := oauth.GenerateCodeVerifier()
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err))
return
}
codeChallenge := oauth.GenerateCodeChallenge(verifier)
codeChallenge = oauth.GenerateCodeChallenge(verifier)
setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie)
}
redirectURI := strings.TrimSpace(cfg.RedirectURL)
if redirectURI == "" {
......@@ -200,11 +203,14 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
intent, _ := readCookieDecoded(c, linuxDoOAuthIntentCookieName)
intent = normalizeOAuthIntent(intent)
codeVerifier, _ := readCookieDecoded(c, linuxDoOAuthVerifierCookie)
codeVerifier := ""
if cfg.UsePKCE {
codeVerifier, _ = readCookieDecoded(c, linuxDoOAuthVerifierCookie)
if codeVerifier == "" {
redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
return
}
}
redirectURI := strings.TrimSpace(cfg.RedirectURL)
if redirectURI == "" {
......@@ -292,24 +298,15 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
return
}
if existingIdentityUser != nil {
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "")
if err != nil {
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
return
}
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
Intent: oauthIntentLogin,
Identity: identityKey,
TargetUserID: &user.ID,
TargetUserID: &existingIdentityUser.ID,
ResolvedEmail: existingIdentityUser.Email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"access_token": tokenPair.AccessToken,
"refresh_token": tokenPair.RefreshToken,
"expires_in": tokenPair.ExpiresIn,
"token_type": "Bearer",
"redirect": redirectTo,
},
}); err != nil {
......@@ -358,15 +355,20 @@ func (h *AuthHandler) findLinuxDoCompatEmailUser(ctx context.Context, email stri
}
userEntity, err := client.User.Query().
Where(dbuser.EmailEqualFold(email)).
Only(ctx)
Where(userNormalizedEmailPredicate(email)).
Order(dbent.Asc(dbuser.FieldID)).
All(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, nil
}
return nil, infraerrors.InternalServer("COMPAT_EMAIL_LOOKUP_FAILED", "failed to look up compat email user").WithCause(err)
}
return userEntity, nil
switch len(userEntity) {
case 0:
return nil, nil
case 1:
return userEntity[0], nil
default:
return nil, infraerrors.Conflict("USER_EMAIL_CONFLICT", "normalized email matched multiple users")
}
}
func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession(
......@@ -414,9 +416,15 @@ func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession(
completionResponse["choice_reason"] = "force_email_on_signup"
}
var targetUserID *int64
if compatEmailUser != nil && compatEmailUser.ID > 0 {
targetUserID = &compatEmailUser.ID
}
return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
Intent: oauthIntentLogin,
Identity: identity,
TargetUserID: targetUserID,
ResolvedEmail: resolvedChoiceEmail,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
......@@ -472,6 +480,15 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil {
response.ErrorFrom(c, err)
return
} else if handled {
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession))
return
} else {
session = updatedSession
}
if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
......@@ -484,12 +501,16 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
return
}
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
if err != nil {
response.ErrorFrom(c, err)
client := h.entClient()
if client == nil {
response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
return
}
decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
if err := ensurePendingOAuthRegistrationIdentityAvailable(c.Request.Context(), client, session); err != nil {
respondPendingOAuthBindingApplyError(c, err)
return
}
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
AdoptDisplayName: req.AdoptDisplayName,
AdoptAvatar: req.AdoptAvatar,
})
......@@ -497,17 +518,16 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
if err != nil {
response.ErrorFrom(c, err)
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
response.ErrorFrom(c, err)
if err := applyPendingOAuthAdoptionAndConsumeSession(c.Request.Context(), client, h.authService, h.userService, session, decision, user.ID); err != nil {
respondPendingOAuthBindingApplyError(c, err)
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
......@@ -546,7 +566,9 @@ func linuxDoExchangeCode(
form.Set("client_id", cfg.ClientID)
form.Set("code", code)
form.Set("redirect_uri", redirectURI)
if strings.TrimSpace(codeVerifier) != "" {
form.Set("code_verifier", codeVerifier)
}
r := client.R().
SetContext(ctx).
......@@ -699,8 +721,10 @@ func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, cod
q.Set("scope", cfg.Scopes)
}
q.Set("state", state)
if strings.TrimSpace(codeChallenge) != "" {
q.Set("code_challenge", codeChallenge)
q.Set("code_challenge_method", "S256")
}
u.RawQuery = q.Encode()
return u.String(), nil
......@@ -937,7 +961,19 @@ func clearOAuthBindAccessTokenCookie(c *gin.Context, secure bool) {
Value: "",
Path: oauthBindAccessTokenCookiePath,
MaxAge: -1,
HttpOnly: false,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}
func setOAuthBindAccessTokenCookie(c *gin.Context, token string, secure bool) {
http.SetCookie(c.Writer, &http.Cookie{
Name: oauthBindAccessTokenCookieName,
Value: url.QueryEscape(strings.TrimSpace(token)),
Path: oauthBindAccessTokenCookiePath,
MaxAge: linuxDoOAuthCookieMaxAgeSec,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
......@@ -1021,6 +1057,26 @@ func (h *AuthHandler) buildOAuthBindUserCookieFromContext(c *gin.Context) (strin
return buildOAuthBindUserCookieValue(*userID, h.oauthBindCookieSecret())
}
func (h *AuthHandler) PrepareOAuthBindAccessTokenCookie(c *gin.Context) {
const bearerPrefix = "Bearer "
authHeader := strings.TrimSpace(c.GetHeader("Authorization"))
if !strings.HasPrefix(strings.ToLower(authHeader), strings.ToLower(bearerPrefix)) {
response.ErrorFrom(c, infraerrors.Unauthorized("UNAUTHORIZED", "authentication required"))
return
}
token := strings.TrimSpace(authHeader[len(bearerPrefix):])
if token == "" {
response.ErrorFrom(c, infraerrors.Unauthorized("UNAUTHORIZED", "authentication required"))
return
}
setOAuthBindAccessTokenCookie(c, token, isRequestHTTPS(c))
c.Status(http.StatusNoContent)
c.Writer.WriteHeaderNow()
}
func (h *AuthHandler) resolveOAuthBindTargetUserID(c *gin.Context) (*int64, error) {
if subject, ok := servermiddleware.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 {
return &subject.UserID, nil
......
......@@ -5,6 +5,7 @@ import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
......@@ -170,6 +171,80 @@ func TestLinuxDoOAuthBindStartRedirectsAndSetsBindCookies(t *testing.T) {
require.Equal(t, int64(42), userID)
}
func TestLinuxDoOAuthStartOmitsPKCEWhenDisabled(t *testing.T) {
handler := newLinuxDoOAuthTestHandler(t, false, config.LinuxDoConnectConfig{
Enabled: true,
ClientID: "linuxdo-client",
ClientSecret: "linuxdo-secret",
AuthorizeURL: "https://connect.linux.do/oauth/authorize",
TokenURL: "https://connect.linux.do/oauth/token",
UserInfoURL: "https://connect.linux.do/api/user",
Scopes: "read",
RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
FrontendRedirectURL: "/auth/linuxdo/callback",
TokenAuthMethod: "client_secret_post",
UsePKCE: false,
})
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/start?redirect=/dashboard", nil)
handler.LinuxDoOAuthStart(c)
require.Equal(t, http.StatusFound, recorder.Code)
require.NotContains(t, recorder.Header().Get("Location"), "code_challenge=")
require.Nil(t, findCookie(recorder.Result().Cookies(), linuxDoOAuthVerifierCookie))
}
func TestLinuxDoOAuthCallbackAllowsMissingVerifierWhenPKCEDisabled(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/token":
require.NoError(t, r.ParseForm())
require.Empty(t, r.PostForm.Get("code_verifier"))
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
case "/userinfo":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"compat-subject","username":"linuxdo_user","name":"LinuxDo Display"}`))
default:
http.NotFound(w, r)
}
}))
defer upstream.Close()
handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
Enabled: true,
ClientID: "linuxdo-client",
ClientSecret: "linuxdo-secret",
AuthorizeURL: upstream.URL + "/authorize",
TokenURL: upstream.URL + "/token",
UserInfoURL: upstream.URL + "/userinfo",
Scopes: "read",
RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
FrontendRedirectURL: "/auth/linuxdo/callback",
TokenAuthMethod: "client_secret_post",
UsePKCE: false,
})
t.Cleanup(func() { _ = client.Close() })
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=linuxdo-code&state=state-123", nil)
req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-123"))
req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard"))
req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin))
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
c.Request = req
handler.LinuxDoOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location"))
require.NotNil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
}
func TestLinuxDoOAuthBindStartAcceptsAccessTokenCookie(t *testing.T) {
handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
Enabled: true,
......@@ -226,6 +301,27 @@ func TestLinuxDoOAuthBindStartAcceptsAccessTokenCookie(t *testing.T) {
require.Equal(t, -1, accessTokenCookie.MaxAge)
}
func TestPrepareOAuthBindAccessTokenCookieSetsHttpOnlyCookie(t *testing.T) {
handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{})
t.Cleanup(func() { _ = client.Close() })
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/bind-token", nil)
req.Header.Set("Authorization", "Bearer access-token-value")
c.Request = req
handler.PrepareOAuthBindAccessTokenCookie(c)
require.Equal(t, http.StatusNoContent, recorder.Code)
accessTokenCookie := findCookie(recorder.Result().Cookies(), oauthBindAccessTokenCookieName)
require.NotNil(t, accessTokenCookie)
require.Equal(t, oauthBindAccessTokenCookiePath, accessTokenCookie.Path)
require.Equal(t, linuxDoOAuthCookieMaxAgeSec, accessTokenCookie.MaxAge)
require.True(t, accessTokenCookie.HttpOnly)
require.Equal(t, url.QueryEscape("access-token-value"), accessTokenCookie.Value)
}
func TestLinuxDoOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
......@@ -305,10 +401,81 @@ func TestLinuxDoOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t
completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
require.True(t, ok)
require.Equal(t, "/dashboard", completion["redirect"])
require.NotEmpty(t, completion["access_token"])
_, hasAccessToken := completion["access_token"]
require.False(t, hasAccessToken)
_, hasRefreshToken := completion["refresh_token"]
require.False(t, hasRefreshToken)
require.Nil(t, completion["error"])
}
func TestLinuxDoOAuthCallbackRejectsDisabledExistingIdentityUser(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/token":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
case "/userinfo":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"654","username":"linuxdo_disabled","name":"LinuxDo Disabled"}`))
default:
http.NotFound(w, r)
}
}))
defer upstream.Close()
handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
Enabled: true,
ClientID: "linuxdo-client",
ClientSecret: "linuxdo-secret",
AuthorizeURL: upstream.URL + "/authorize",
TokenURL: upstream.URL + "/token",
UserInfoURL: upstream.URL + "/userinfo",
Scopes: "read",
RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
FrontendRedirectURL: "/auth/linuxdo/callback",
TokenAuthMethod: "client_secret_post",
UsePKCE: true,
})
t.Cleanup(func() { _ = client.Close() })
ctx := context.Background()
existingUser, err := client.User.Create().
SetEmail(linuxDoSyntheticEmail("654")).
SetUsername("disabled-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusDisabled).
Save(ctx)
require.NoError(t, err)
_, err = client.AuthIdentity.Create().
SetUserID(existingUser.ID).
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("654").
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-disabled&state=state-disabled", nil)
req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-disabled"))
req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard"))
req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-disabled"))
req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin))
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-disabled"))
c.Request = req
handler.LinuxDoOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
assertOAuthRedirectError(t, recorder.Header().Get("Location"), "session_error", "USER_NOT_ACTIVE")
count, err := client.PendingAuthSession.Query().Count(ctx)
require.NoError(t, err)
require.Zero(t, count)
}
func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
......@@ -341,7 +508,7 @@ func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *test
ctx := context.Background()
existingUser, err := client.User.Create().
SetEmail("legacy@example.com").
SetEmail(" Legacy@Example.com ").
SetUsername("legacy-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
......@@ -372,16 +539,17 @@ func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *test
Only(ctx)
require.NoError(t, err)
require.Equal(t, oauthIntentLogin, session.Intent)
require.Nil(t, session.TargetUserID)
require.Equal(t, existingUser.Email, session.ResolvedEmail)
require.NotNil(t, session.TargetUserID)
require.Equal(t, existingUser.ID, *session.TargetUserID)
require.Equal(t, strings.TrimSpace(existingUser.Email), session.ResolvedEmail)
require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"])
completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
require.True(t, ok)
require.Equal(t, "/dashboard", completion["redirect"])
require.Equal(t, oauthPendingChoiceStep, completion["step"])
require.Equal(t, existingUser.Email, completion["email"])
require.Equal(t, existingUser.Email, completion["existing_account_email"])
require.Equal(t, strings.TrimSpace(existingUser.Email), completion["email"])
require.Equal(t, strings.TrimSpace(existingUser.Email), completion["existing_account_email"])
require.Equal(t, true, completion["existing_account_bindable"])
require.Equal(t, "compat_email_match", completion["choice_reason"])
_, hasAccessToken := completion["access_token"]
......@@ -658,6 +826,186 @@ func TestCompleteLinuxDoOAuthRegistrationRejectsAdoptExistingUserSession(t *test
require.Nil(t, storedSession.ConsumedAt)
}
func TestCompleteLinuxDoOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("linuxdo-complete-choice-session").
SetIntent("login").
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("linuxdo-choice-subject-1").
SetResolvedEmail("linuxdo-choice-subject-1@linuxdo-connect.invalid").
SetBrowserSessionKey("linuxdo-choice-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "linuxdo_user",
}).
SetLocalFlowState(map[string]any{
oauthCompletionResponseKey: map[string]any{
"step": oauthPendingChoiceStep,
"redirect": "/dashboard",
"email": "fresh@example.com",
"resolved_email": "fresh@example.com",
"force_email_on_signup": true,
},
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-choice-browser")})
c.Request = req
handler.CompleteLinuxDoOAuthRegistration(c)
require.Equal(t, http.StatusOK, recorder.Code)
responseData := decodeJSONBody(t, recorder)
require.Equal(t, "pending_session", responseData["auth_result"])
require.Equal(t, oauthPendingChoiceStep, responseData["step"])
require.Equal(t, true, responseData["force_email_on_signup"])
require.Empty(t, responseData["access_token"])
userCount, err := client.User.Query().Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
func TestCompleteLinuxDoOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("linuxdo-complete-no-adoption-session").
SetIntent("login").
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("linuxdo-subject-no-adoption").
SetResolvedEmail("linuxdo-subject-no-adoption@linuxdo-connect.invalid").
SetBrowserSessionKey("linuxdo-browser-no-adoption").
SetUpstreamIdentityClaims(map[string]any{
"username": "linuxdo_user",
"suggested_display_name": "LinuxDo Legacy",
"suggested_avatar_url": "https://cdn.example/linuxdo-legacy.png",
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-browser-no-adoption")})
c.Request = req
handler.CompleteLinuxDoOAuthRegistration(c)
require.Equal(t, http.StatusOK, recorder.Code)
responseData := decodeJSONBody(t, recorder)
require.NotEmpty(t, responseData["access_token"])
require.NotEmpty(t, responseData["refresh_token"])
userEntity, err := client.User.Query().
Where(dbuser.EmailEQ(session.ResolvedEmail)).
Only(ctx)
require.NoError(t, err)
require.Equal(t, "linuxdo_user", userEntity.Username)
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("linuxdo"),
authidentity.ProviderKeyEQ("linuxdo"),
authidentity.ProviderSubjectEQ("linuxdo-subject-no-adoption"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, userEntity.ID, identity.UserID)
decision, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, decision.IdentityID)
require.Equal(t, identity.ID, *decision.IdentityID)
require.False(t, decision.AdoptDisplayName)
require.False(t, decision.AdoptAvatar)
}
func TestCompleteLinuxDoOAuthRegistrationRejectsIdentityOwnershipConflictBeforeUserCreation(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
existingOwner, err := client.User.Create().
SetEmail("owner@example.com").
SetUsername("owner-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
_, err = client.AuthIdentity.Create().
SetUserID(existingOwner.ID).
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("linuxdo-conflict-subject").
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("linuxdo-complete-conflict-session").
SetIntent("login").
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("linuxdo-conflict-subject").
SetResolvedEmail("linuxdo-conflict-subject@linuxdo-connect.invalid").
SetBrowserSessionKey("linuxdo-conflict-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "linuxdo_user",
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-conflict-browser")})
c.Request = req
handler.CompleteLinuxDoOAuthRegistration(c)
require.Equal(t, http.StatusConflict, recorder.Code)
payload := decodeJSONBody(t, recorder)
require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", payload["reason"])
userCount, err := client.User.Query().
Where(dbuser.EmailEQ("linuxdo-conflict-subject@linuxdo-connect.invalid")).
Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
func newLinuxDoOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) *AuthHandler {
t.Helper()
handler, _ := newLinuxDoOAuthHandlerAndClient(t, invitationEnabled, oauthCfg)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment