Commit 2b70d1d3 authored by IanShaw027's avatar IanShaw027
Browse files

merge upstream main into fix/bug-cleanup-main

parents b37afd68 00c08c57
...@@ -54,6 +54,8 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { ...@@ -54,6 +54,8 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems), CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
BackendModeEnabled: settings.BackendModeEnabled, BackendModeEnabled: settings.BackendModeEnabled,
Version: h.version, Version: h.version,
}) })
......
...@@ -28,7 +28,7 @@ type AnthropicRequest struct { ...@@ -28,7 +28,7 @@ type AnthropicRequest struct {
// AnthropicOutputConfig controls output generation parameters. // AnthropicOutputConfig controls output generation parameters.
type AnthropicOutputConfig struct { type AnthropicOutputConfig struct {
Effort string `json:"effort,omitempty"` // "low" | "medium" | "high" Effort string `json:"effort,omitempty"` // "low" | "medium" | "high" | "max"
} }
// AnthropicThinking configures extended thinking in the Anthropic API. // AnthropicThinking configures extended thinking in the Anthropic API.
...@@ -167,7 +167,7 @@ type ResponsesRequest struct { ...@@ -167,7 +167,7 @@ type ResponsesRequest struct {
// ResponsesReasoning configures reasoning effort in the Responses API. // ResponsesReasoning configures reasoning effort in the Responses API.
type ResponsesReasoning struct { type ResponsesReasoning struct {
Effort string `json:"effort"` // "low" | "medium" | "high" Effort string `json:"effort"` // "low" | "medium" | "high" | "xhigh"
Summary string `json:"summary,omitempty"` // "auto" | "concise" | "detailed" Summary string `json:"summary,omitempty"` // "auto" | "concise" | "detailed"
} }
...@@ -345,7 +345,7 @@ type ChatCompletionsRequest struct { ...@@ -345,7 +345,7 @@ type ChatCompletionsRequest struct {
StreamOptions *ChatStreamOptions `json:"stream_options,omitempty"` StreamOptions *ChatStreamOptions `json:"stream_options,omitempty"`
Tools []ChatTool `json:"tools,omitempty"` Tools []ChatTool `json:"tools,omitempty"`
ToolChoice json.RawMessage `json:"tool_choice,omitempty"` ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"` // "low" | "medium" | "high" ReasoningEffort string `json:"reasoning_effort,omitempty"` // "low" | "medium" | "high" | "xhigh"
ServiceTier string `json:"service_tier,omitempty"` ServiceTier string `json:"service_tier,omitempty"`
Stop json.RawMessage `json:"stop,omitempty"` // string or []string Stop json.RawMessage `json:"stop,omitempty"` // string or []string
......
...@@ -61,7 +61,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er ...@@ -61,7 +61,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch). SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
SetRequireOauthOnly(groupIn.RequireOAuthOnly). SetRequireOauthOnly(groupIn.RequireOAuthOnly).
SetRequirePrivacySet(groupIn.RequirePrivacySet). SetRequirePrivacySet(groupIn.RequirePrivacySet).
SetDefaultMappedModel(groupIn.DefaultMappedModel) SetDefaultMappedModel(groupIn.DefaultMappedModel).
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig)
// 设置模型路由配置 // 设置模型路由配置
if groupIn.ModelRouting != nil { if groupIn.ModelRouting != nil {
...@@ -127,7 +128,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er ...@@ -127,7 +128,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch). SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
SetRequireOauthOnly(groupIn.RequireOAuthOnly). SetRequireOauthOnly(groupIn.RequireOAuthOnly).
SetRequirePrivacySet(groupIn.RequirePrivacySet). SetRequirePrivacySet(groupIn.RequirePrivacySet).
SetDefaultMappedModel(groupIn.DefaultMappedModel) SetDefaultMappedModel(groupIn.DefaultMappedModel).
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig)
// 显式处理可空字段:nil 需要 clear,非 nil 需要 set。 // 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
if groupIn.DailyLimitUSD != nil { if groupIn.DailyLimitUSD != nil {
......
...@@ -378,6 +378,7 @@ func buildSchedulerMetadataAccount(account service.Account) service.Account { ...@@ -378,6 +378,7 @@ func buildSchedulerMetadataAccount(account service.Account) service.Account {
Platform: account.Platform, Platform: account.Platform,
Type: account.Type, Type: account.Type,
Concurrency: account.Concurrency, Concurrency: account.Concurrency,
LoadFactor: account.LoadFactor,
Priority: account.Priority, Priority: account.Priority,
RateMultiplier: account.RateMultiplier, RateMultiplier: account.RateMultiplier,
Status: account.Status, Status: account.Status,
......
...@@ -462,6 +462,28 @@ func TestAPIContracts(t *testing.T) { ...@@ -462,6 +462,28 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyTurnstileSiteKey: "site-key", service.SettingKeyTurnstileSiteKey: "site-key",
service.SettingKeyTurnstileSecretKey: "secret-key", service.SettingKeyTurnstileSecretKey: "secret-key",
service.SettingKeyOIDCConnectEnabled: "false",
service.SettingKeyOIDCConnectProviderName: "OIDC",
service.SettingKeyOIDCConnectClientID: "",
service.SettingKeyOIDCConnectIssuerURL: "",
service.SettingKeyOIDCConnectDiscoveryURL: "",
service.SettingKeyOIDCConnectAuthorizeURL: "",
service.SettingKeyOIDCConnectTokenURL: "",
service.SettingKeyOIDCConnectUserInfoURL: "",
service.SettingKeyOIDCConnectJWKSURL: "",
service.SettingKeyOIDCConnectScopes: "openid email profile",
service.SettingKeyOIDCConnectRedirectURL: "",
service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
service.SettingKeyOIDCConnectUsePKCE: "false",
service.SettingKeyOIDCConnectValidateIDToken: "true",
service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256",
service.SettingKeyOIDCConnectClockSkewSeconds: "120",
service.SettingKeyOIDCConnectRequireEmailVerified: "false",
service.SettingKeyOIDCConnectUserInfoEmailPath: "",
service.SettingKeyOIDCConnectUserInfoIDPath: "",
service.SettingKeyOIDCConnectUserInfoUsernamePath: "",
service.SettingKeySiteName: "Sub2API", service.SettingKeySiteName: "Sub2API",
service.SettingKeySiteLogo: "", service.SettingKeySiteLogo: "",
service.SettingKeySiteSubtitle: "Subtitle", service.SettingKeySiteSubtitle: "Subtitle",
...@@ -503,10 +525,32 @@ func TestAPIContracts(t *testing.T) { ...@@ -503,10 +525,32 @@ func TestAPIContracts(t *testing.T) {
"turnstile_enabled": true, "turnstile_enabled": true,
"turnstile_site_key": "site-key", "turnstile_site_key": "site-key",
"turnstile_secret_key_configured": true, "turnstile_secret_key_configured": true,
"linuxdo_connect_enabled": false, "linuxdo_connect_enabled": false,
"linuxdo_connect_client_id": "", "linuxdo_connect_client_id": "",
"linuxdo_connect_client_secret_configured": false, "linuxdo_connect_client_secret_configured": false,
"linuxdo_connect_redirect_url": "", "linuxdo_connect_redirect_url": "",
"oidc_connect_enabled": false,
"oidc_connect_provider_name": "OIDC",
"oidc_connect_client_id": "",
"oidc_connect_client_secret_configured": false,
"oidc_connect_issuer_url": "",
"oidc_connect_discovery_url": "",
"oidc_connect_authorize_url": "",
"oidc_connect_token_url": "",
"oidc_connect_userinfo_url": "",
"oidc_connect_jwks_url": "",
"oidc_connect_scopes": "openid email profile",
"oidc_connect_redirect_url": "",
"oidc_connect_frontend_redirect_url": "/auth/oidc/callback",
"oidc_connect_token_auth_method": "client_secret_post",
"oidc_connect_use_pkce": false,
"oidc_connect_validate_id_token": true,
"oidc_connect_allowed_signing_algs": "RS256,ES256,PS256",
"oidc_connect_clock_skew_seconds": 120,
"oidc_connect_require_email_verified": false,
"oidc_connect_userinfo_email_path": "",
"oidc_connect_userinfo_id_path": "",
"oidc_connect_userinfo_username_path": "",
"ops_monitoring_enabled": false, "ops_monitoring_enabled": false,
"ops_realtime_monitoring_enabled": true, "ops_realtime_monitoring_enabled": true,
"ops_query_mode_default": "auto", "ops_query_mode_default": "auto",
......
...@@ -70,6 +70,14 @@ func RegisterAuthRoutes( ...@@ -70,6 +70,14 @@ func RegisterAuthRoutes(
}), }),
h.Auth.CompleteLinuxDoOAuthRegistration, h.Auth.CompleteLinuxDoOAuthRegistration,
) )
auth.GET("/oauth/oidc/start", h.Auth.OIDCOAuthStart)
auth.GET("/oauth/oidc/callback", h.Auth.OIDCOAuthCallback)
auth.POST("/oauth/oidc/complete-registration",
rateLimiter.LimitWithOptions("oauth-oidc-complete", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CompleteOIDCOAuthRegistration,
)
} }
// 公开设置(无需认证) // 公开设置(无需认证)
......
...@@ -152,10 +152,11 @@ type CreateGroupInput struct { ...@@ -152,10 +152,11 @@ type CreateGroupInput struct {
// 支持的模型系列(仅 antigravity 平台使用) // 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string SupportedModelScopes []string
// OpenAI Messages 调度配置(仅 openai 平台使用) // OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch bool AllowMessagesDispatch bool
DefaultMappedModel string DefaultMappedModel string
RequireOAuthOnly bool RequireOAuthOnly bool
RequirePrivacySet bool RequirePrivacySet bool
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig
// 从指定分组复制账号(创建分组后在同一事务内绑定) // 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs []int64 CopyAccountsFromGroupIDs []int64
} }
...@@ -186,10 +187,11 @@ type UpdateGroupInput struct { ...@@ -186,10 +187,11 @@ type UpdateGroupInput struct {
// 支持的模型系列(仅 antigravity 平台使用) // 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string SupportedModelScopes *[]string
// OpenAI Messages 调度配置(仅 openai 平台使用) // OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch *bool AllowMessagesDispatch *bool
DefaultMappedModel *string DefaultMappedModel *string
RequireOAuthOnly *bool RequireOAuthOnly *bool
RequirePrivacySet *bool RequirePrivacySet *bool
MessagesDispatchModelConfig *OpenAIMessagesDispatchModelConfig
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64 CopyAccountsFromGroupIDs []int64
} }
...@@ -908,7 +910,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ...@@ -908,7 +910,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
RequireOAuthOnly: input.RequireOAuthOnly, RequireOAuthOnly: input.RequireOAuthOnly,
RequirePrivacySet: input.RequirePrivacySet, RequirePrivacySet: input.RequirePrivacySet,
DefaultMappedModel: input.DefaultMappedModel, DefaultMappedModel: input.DefaultMappedModel,
MessagesDispatchModelConfig: normalizeOpenAIMessagesDispatchModelConfig(input.MessagesDispatchModelConfig),
} }
sanitizeGroupMessagesDispatchFields(group)
if err := s.groupRepo.Create(ctx, group); err != nil { if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err return nil, err
} }
...@@ -1135,6 +1139,10 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd ...@@ -1135,6 +1139,10 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.DefaultMappedModel != nil { if input.DefaultMappedModel != nil {
group.DefaultMappedModel = *input.DefaultMappedModel group.DefaultMappedModel = *input.DefaultMappedModel
} }
if input.MessagesDispatchModelConfig != nil {
group.MessagesDispatchModelConfig = normalizeOpenAIMessagesDispatchModelConfig(*input.MessagesDispatchModelConfig)
}
sanitizeGroupMessagesDispatchFields(group)
if err := s.groupRepo.Update(ctx, group); err != nil { if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err return nil, err
......
...@@ -10,6 +10,11 @@ import ( ...@@ -10,6 +10,11 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func ptrString[T ~string](v T) *string {
s := string(v)
return &s
}
// groupRepoStubForAdmin 用于测试 AdminService 的 GroupRepository Stub // groupRepoStubForAdmin 用于测试 AdminService 的 GroupRepository Stub
type groupRepoStubForAdmin struct { type groupRepoStubForAdmin struct {
created *Group // 记录 Create 调用的参数 created *Group // 记录 Create 调用的参数
...@@ -261,6 +266,116 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) { ...@@ -261,6 +266,116 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
require.Nil(t, repo.updated.ImagePrice4K) require.Nil(t, repo.updated.ImagePrice4K)
} }
func TestAdminService_CreateGroup_NormalizesMessagesDispatchModelConfig(t *testing.T) {
repo := &groupRepoStubForAdmin{}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "dispatch-group",
Description: "dispatch config",
Platform: PlatformOpenAI,
RateMultiplier: 1.0,
MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{
OpusMappedModel: " gpt-5.4-high ",
SonnetMappedModel: " gpt-5.3-codex ",
HaikuMappedModel: " gpt-5.4-mini-medium ",
ExactModelMappings: map[string]string{
" claude-sonnet-4-5-20250929 ": " gpt-5.2-high ",
},
},
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.created)
require.Equal(t, OpenAIMessagesDispatchModelConfig{
OpusMappedModel: "gpt-5.4",
SonnetMappedModel: "gpt-5.3-codex",
HaikuMappedModel: "gpt-5.4-mini",
ExactModelMappings: map[string]string{
"claude-sonnet-4-5-20250929": "gpt-5.2",
},
}, repo.created.MessagesDispatchModelConfig)
}
func TestAdminService_UpdateGroup_NormalizesMessagesDispatchModelConfig(t *testing.T) {
existingGroup := &Group{
ID: 1,
Name: "existing-group",
Platform: PlatformOpenAI,
Status: StatusActive,
}
repo := &groupRepoStubForAdmin{getByID: existingGroup}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
MessagesDispatchModelConfig: &OpenAIMessagesDispatchModelConfig{
SonnetMappedModel: " gpt-5.4-medium ",
ExactModelMappings: map[string]string{
" claude-haiku-4-5-20251001 ": " gpt-5.4-mini-high ",
},
},
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.updated)
require.Equal(t, OpenAIMessagesDispatchModelConfig{
SonnetMappedModel: "gpt-5.4",
ExactModelMappings: map[string]string{
"claude-haiku-4-5-20251001": "gpt-5.4-mini",
},
}, repo.updated.MessagesDispatchModelConfig)
}
func TestAdminService_CreateGroup_ClearsMessagesDispatchFieldsForNonOpenAIPlatform(t *testing.T) {
repo := &groupRepoStubForAdmin{}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "anthropic-group",
Description: "non-openai",
Platform: PlatformAnthropic,
RateMultiplier: 1.0,
AllowMessagesDispatch: true,
DefaultMappedModel: "gpt-5.4",
MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{
OpusMappedModel: "gpt-5.4",
},
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.created)
require.False(t, repo.created.AllowMessagesDispatch)
require.Empty(t, repo.created.DefaultMappedModel)
require.Equal(t, OpenAIMessagesDispatchModelConfig{}, repo.created.MessagesDispatchModelConfig)
}
func TestAdminService_UpdateGroup_ClearsMessagesDispatchFieldsWhenPlatformChangesAwayFromOpenAI(t *testing.T) {
existingGroup := &Group{
ID: 1,
Name: "existing-openai-group",
Platform: PlatformOpenAI,
Status: StatusActive,
AllowMessagesDispatch: true,
DefaultMappedModel: "gpt-5.4",
MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{
SonnetMappedModel: "gpt-5.3-codex",
},
}
repo := &groupRepoStubForAdmin{getByID: existingGroup}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
Platform: PlatformAnthropic,
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.updated)
require.Equal(t, PlatformAnthropic, repo.updated.Platform)
require.False(t, repo.updated.AllowMessagesDispatch)
require.Empty(t, repo.updated.DefaultMappedModel)
require.Equal(t, OpenAIMessagesDispatchModelConfig{}, repo.updated.MessagesDispatchModelConfig)
}
func TestAdminService_ListGroups_WithSearch(t *testing.T) { func TestAdminService_ListGroups_WithSearch(t *testing.T) {
// 测试: // 测试:
// 1. search 参数正常传递到 repository 层 // 1. search 参数正常传递到 repository 层
......
...@@ -833,7 +833,8 @@ func randomHexString(byteLength int) (string, error) { ...@@ -833,7 +833,8 @@ func randomHexString(byteLength int) (string, error) {
func isReservedEmail(email string) bool { func isReservedEmail(email string) bool {
normalized := strings.ToLower(strings.TrimSpace(email)) normalized := strings.ToLower(strings.TrimSpace(email))
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) ||
strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain)
} }
// GenerateToken 生成JWT access token // GenerateToken 生成JWT access token
......
...@@ -71,6 +71,9 @@ const ( ...@@ -71,6 +71,9 @@ const (
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。 // LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。
const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid" const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// OIDCConnectSyntheticEmailDomain 是 OIDC 用户的合成邮箱后缀(RFC 保留域名)。
const OIDCConnectSyntheticEmailDomain = "@oidc-connect.invalid"
// Setting keys // Setting keys
const ( const (
// 注册设置 // 注册设置
...@@ -105,6 +108,30 @@ const ( ...@@ -105,6 +108,30 @@ const (
SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret" SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret"
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url" SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
// Generic OIDC OAuth 登录设置
SettingKeyOIDCConnectEnabled = "oidc_connect_enabled"
SettingKeyOIDCConnectProviderName = "oidc_connect_provider_name"
SettingKeyOIDCConnectClientID = "oidc_connect_client_id"
SettingKeyOIDCConnectClientSecret = "oidc_connect_client_secret"
SettingKeyOIDCConnectIssuerURL = "oidc_connect_issuer_url"
SettingKeyOIDCConnectDiscoveryURL = "oidc_connect_discovery_url"
SettingKeyOIDCConnectAuthorizeURL = "oidc_connect_authorize_url"
SettingKeyOIDCConnectTokenURL = "oidc_connect_token_url"
SettingKeyOIDCConnectUserInfoURL = "oidc_connect_userinfo_url"
SettingKeyOIDCConnectJWKSURL = "oidc_connect_jwks_url"
SettingKeyOIDCConnectScopes = "oidc_connect_scopes"
SettingKeyOIDCConnectRedirectURL = "oidc_connect_redirect_url"
SettingKeyOIDCConnectFrontendRedirectURL = "oidc_connect_frontend_redirect_url"
SettingKeyOIDCConnectTokenAuthMethod = "oidc_connect_token_auth_method"
SettingKeyOIDCConnectUsePKCE = "oidc_connect_use_pkce"
SettingKeyOIDCConnectValidateIDToken = "oidc_connect_validate_id_token"
SettingKeyOIDCConnectAllowedSigningAlgs = "oidc_connect_allowed_signing_algs"
SettingKeyOIDCConnectClockSkewSeconds = "oidc_connect_clock_skew_seconds"
SettingKeyOIDCConnectRequireEmailVerified = "oidc_connect_require_email_verified"
SettingKeyOIDCConnectUserInfoEmailPath = "oidc_connect_userinfo_email_path"
SettingKeyOIDCConnectUserInfoIDPath = "oidc_connect_userinfo_id_path"
SettingKeyOIDCConnectUserInfoUsernamePath = "oidc_connect_userinfo_username_path"
// OEM设置 // OEM设置
SettingKeySiteName = "site_name" // 网站名称 SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64) SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
......
...@@ -3,8 +3,12 @@ package service ...@@ -3,8 +3,12 @@ package service
import ( import (
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/domain"
) )
type OpenAIMessagesDispatchModelConfig = domain.OpenAIMessagesDispatchModelConfig
type Group struct { type Group struct {
ID int64 ID int64
Name string Name string
...@@ -49,10 +53,11 @@ type Group struct { ...@@ -49,10 +53,11 @@ type Group struct {
SortOrder int SortOrder int
// OpenAI Messages 调度配置(仅 openai 平台使用) // OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch bool AllowMessagesDispatch bool
RequireOAuthOnly bool // 仅允许非 apikey 类型账号关联(OpenAI/Antigravity/Anthropic/Gemini) RequireOAuthOnly bool // 仅允许非 apikey 类型账号关联(OpenAI/Antigravity/Anthropic/Gemini)
RequirePrivacySet bool // 调度时仅允许 privacy 已成功设置的账号(OpenAI/Antigravity/Anthropic/Gemini) RequirePrivacySet bool // 调度时仅允许 privacy 已成功设置的账号(OpenAI/Antigravity/Anthropic/Gemini)
DefaultMappedModel string DefaultMappedModel string
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
......
package service
import (
"bytes"
"fmt"
"strings"
"text/template"
)
type forcedCodexInstructionsTemplateData struct {
ExistingInstructions string
OriginalModel string
NormalizedModel string
BillingModel string
UpstreamModel string
}
func applyForcedCodexInstructionsTemplate(
reqBody map[string]any,
templateText string,
data forcedCodexInstructionsTemplateData,
) (bool, error) {
rendered, err := renderForcedCodexInstructionsTemplate(templateText, data)
if err != nil {
return false, err
}
if rendered == "" {
return false, nil
}
existing, _ := reqBody["instructions"].(string)
if strings.TrimSpace(existing) == rendered {
return false, nil
}
reqBody["instructions"] = rendered
return true, nil
}
func renderForcedCodexInstructionsTemplate(
templateText string,
data forcedCodexInstructionsTemplateData,
) (string, error) {
tmpl, err := template.New("forced_codex_instructions").Option("missingkey=zero").Parse(templateText)
if err != nil {
return "", fmt.Errorf("parse forced codex instructions template: %w", err)
}
var buf bytes.Buffer
if err := tmpl.Execute(&buf, data); err != nil {
return "", fmt.Errorf("render forced codex instructions template: %w", err)
}
return strings.TrimSpace(buf.String()), nil
}
...@@ -6,9 +6,12 @@ import ( ...@@ -6,9 +6,12 @@ import (
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os"
"path/filepath"
"strings" "strings"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
...@@ -127,3 +130,101 @@ func TestForwardAsAnthropic_NormalizesRoutingAndEffortForGpt54XHigh(t *testing.T ...@@ -127,3 +130,101 @@ func TestForwardAsAnthropic_NormalizesRoutingAndEffortForGpt54XHigh(t *testing.T
t.Logf("upstream body: %s", string(upstream.lastBody)) t.Logf("upstream body: %s", string(upstream.lastBody))
t.Logf("response body: %s", rec.Body.String()) t.Logf("response body: %s", rec.Body.String())
} }
func TestForwardAsAnthropic_ForcedCodexInstructionsTemplatePrependsRenderedInstructions(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
templateDir := t.TempDir()
templatePath := filepath.Join(templateDir, "codex-instructions.md.tmpl")
require.NoError(t, os.WriteFile(templatePath, []byte("server-prefix\n\n{{ .ExistingInstructions }}"), 0o644))
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"system":"client-system","messages":[{"role":"user","content":"hello"}],"stream":false}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstreamBody := strings.Join([]string{
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`,
"",
"data: [DONE]",
"",
}, "\n")
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_forced"}},
Body: io.NopCloser(strings.NewReader(upstreamBody)),
}}
svc := &OpenAIGatewayService{
cfg: &config.Config{Gateway: config.GatewayConfig{
ForcedCodexInstructionsTemplateFile: templatePath,
ForcedCodexInstructionsTemplate: "server-prefix\n\n{{ .ExistingInstructions }}",
}},
httpUpstream: upstream,
}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "server-prefix\n\nclient-system", gjson.GetBytes(upstream.lastBody, "instructions").String())
}
func TestForwardAsAnthropic_ForcedCodexInstructionsTemplateUsesCachedTemplateContent(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"system":"client-system","messages":[{"role":"user","content":"hello"}],"stream":false}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstreamBody := strings.Join([]string{
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`,
"",
"data: [DONE]",
"",
}, "\n")
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_forced_cached"}},
Body: io.NopCloser(strings.NewReader(upstreamBody)),
}}
svc := &OpenAIGatewayService{
cfg: &config.Config{Gateway: config.GatewayConfig{
ForcedCodexInstructionsTemplateFile: "/path/that/should/not/be/read.tmpl",
ForcedCodexInstructionsTemplate: "cached-prefix\n\n{{ .ExistingInstructions }}",
}},
httpUpstream: upstream,
}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "cached-prefix\n\nclient-system", gjson.GetBytes(upstream.lastBody, "instructions").String())
}
...@@ -86,6 +86,24 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( ...@@ -86,6 +86,24 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
return nil, fmt.Errorf("unmarshal for codex transform: %w", err) return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
} }
codexResult := applyCodexOAuthTransform(reqBody, false, false) codexResult := applyCodexOAuthTransform(reqBody, false, false)
forcedTemplateText := ""
if s.cfg != nil {
forcedTemplateText = s.cfg.Gateway.ForcedCodexInstructionsTemplate
}
templateUpstreamModel := upstreamModel
if codexResult.NormalizedModel != "" {
templateUpstreamModel = codexResult.NormalizedModel
}
existingInstructions, _ := reqBody["instructions"].(string)
if _, err := applyForcedCodexInstructionsTemplate(reqBody, forcedTemplateText, forcedCodexInstructionsTemplateData{
ExistingInstructions: strings.TrimSpace(existingInstructions),
OriginalModel: originalModel,
NormalizedModel: normalizedModel,
BillingModel: billingModel,
UpstreamModel: templateUpstreamModel,
}); err != nil {
return nil, err
}
if codexResult.NormalizedModel != "" { if codexResult.NormalizedModel != "" {
upstreamModel = codexResult.NormalizedModel upstreamModel = codexResult.NormalizedModel
} }
......
package service
import "strings"
const (
defaultOpenAIMessagesDispatchOpusMappedModel = "gpt-5.4"
defaultOpenAIMessagesDispatchSonnetMappedModel = "gpt-5.3-codex"
defaultOpenAIMessagesDispatchHaikuMappedModel = "gpt-5.4-mini"
)
func normalizeOpenAIMessagesDispatchMappedModel(model string) string {
model = NormalizeOpenAICompatRequestedModel(strings.TrimSpace(model))
return strings.TrimSpace(model)
}
func normalizeOpenAIMessagesDispatchModelConfig(cfg OpenAIMessagesDispatchModelConfig) OpenAIMessagesDispatchModelConfig {
out := OpenAIMessagesDispatchModelConfig{
OpusMappedModel: normalizeOpenAIMessagesDispatchMappedModel(cfg.OpusMappedModel),
SonnetMappedModel: normalizeOpenAIMessagesDispatchMappedModel(cfg.SonnetMappedModel),
HaikuMappedModel: normalizeOpenAIMessagesDispatchMappedModel(cfg.HaikuMappedModel),
}
if len(cfg.ExactModelMappings) > 0 {
out.ExactModelMappings = make(map[string]string, len(cfg.ExactModelMappings))
for requestedModel, mappedModel := range cfg.ExactModelMappings {
requestedModel = strings.TrimSpace(requestedModel)
mappedModel = normalizeOpenAIMessagesDispatchMappedModel(mappedModel)
if requestedModel == "" || mappedModel == "" {
continue
}
out.ExactModelMappings[requestedModel] = mappedModel
}
if len(out.ExactModelMappings) == 0 {
out.ExactModelMappings = nil
}
}
return out
}
func claudeMessagesDispatchFamily(model string) string {
normalized := strings.ToLower(strings.TrimSpace(model))
if !strings.HasPrefix(normalized, "claude") {
return ""
}
switch {
case strings.Contains(normalized, "opus"):
return "opus"
case strings.Contains(normalized, "sonnet"):
return "sonnet"
case strings.Contains(normalized, "haiku"):
return "haiku"
default:
return ""
}
}
func (g *Group) ResolveMessagesDispatchModel(requestedModel string) string {
if g == nil {
return ""
}
requestedModel = strings.TrimSpace(requestedModel)
if requestedModel == "" {
return ""
}
cfg := normalizeOpenAIMessagesDispatchModelConfig(g.MessagesDispatchModelConfig)
if mappedModel := strings.TrimSpace(cfg.ExactModelMappings[requestedModel]); mappedModel != "" {
return mappedModel
}
switch claudeMessagesDispatchFamily(requestedModel) {
case "opus":
if mappedModel := strings.TrimSpace(cfg.OpusMappedModel); mappedModel != "" {
return mappedModel
}
return defaultOpenAIMessagesDispatchOpusMappedModel
case "sonnet":
if mappedModel := strings.TrimSpace(cfg.SonnetMappedModel); mappedModel != "" {
return mappedModel
}
return defaultOpenAIMessagesDispatchSonnetMappedModel
case "haiku":
if mappedModel := strings.TrimSpace(cfg.HaikuMappedModel); mappedModel != "" {
return mappedModel
}
return defaultOpenAIMessagesDispatchHaikuMappedModel
default:
return ""
}
}
func sanitizeGroupMessagesDispatchFields(g *Group) {
if g == nil || g.Platform == PlatformOpenAI {
return
}
g.AllowMessagesDispatch = false
g.DefaultMappedModel = ""
g.MessagesDispatchModelConfig = OpenAIMessagesDispatchModelConfig{}
}
package service
import "testing"
import "github.com/stretchr/testify/require"
func TestNormalizeOpenAIMessagesDispatchModelConfig(t *testing.T) {
t.Parallel()
cfg := normalizeOpenAIMessagesDispatchModelConfig(OpenAIMessagesDispatchModelConfig{
OpusMappedModel: " gpt-5.4-high ",
SonnetMappedModel: "gpt-5.3-codex",
HaikuMappedModel: " gpt-5.4-mini-medium ",
ExactModelMappings: map[string]string{
" claude-sonnet-4-5-20250929 ": " gpt-5.2-high ",
"": "gpt-5.4",
"claude-opus-4-6": " ",
},
})
require.Equal(t, "gpt-5.4", cfg.OpusMappedModel)
require.Equal(t, "gpt-5.3-codex", cfg.SonnetMappedModel)
require.Equal(t, "gpt-5.4-mini", cfg.HaikuMappedModel)
require.Equal(t, map[string]string{
"claude-sonnet-4-5-20250929": "gpt-5.2",
}, cfg.ExactModelMappings)
}
...@@ -16,7 +16,7 @@ import ( ...@@ -16,7 +16,7 @@ import (
var ErrOpsDisabled = infraerrors.NotFound("OPS_DISABLED", "Ops monitoring is disabled") var ErrOpsDisabled = infraerrors.NotFound("OPS_DISABLED", "Ops monitoring is disabled")
const ( const (
opsMaxStoredRequestBodyBytes = 10 * 1024 opsMaxStoredRequestBodyBytes = 256 * 1024
opsMaxStoredErrorBodyBytes = 20 * 1024 opsMaxStoredErrorBodyBytes = 20 * 1024
) )
......
...@@ -17,6 +17,7 @@ import ( ...@@ -17,6 +17,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/imroc/req/v3"
"golang.org/x/sync/singleflight" "golang.org/x/sync/singleflight"
) )
...@@ -167,6 +168,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ...@@ -167,6 +168,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyCustomEndpoints, SettingKeyCustomEndpoints,
SettingKeyLinuxDoConnectEnabled, SettingKeyLinuxDoConnectEnabled,
SettingKeyBackendModeEnabled, SettingKeyBackendModeEnabled,
SettingKeyOIDCConnectEnabled,
SettingKeyOIDCConnectProviderName,
} }
settings, err := s.settingRepo.GetMultiple(ctx, keys) settings, err := s.settingRepo.GetMultiple(ctx, keys)
...@@ -180,6 +183,19 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ...@@ -180,6 +183,19 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
} else { } else {
linuxDoEnabled = s.cfg != nil && s.cfg.LinuxDo.Enabled linuxDoEnabled = s.cfg != nil && s.cfg.LinuxDo.Enabled
} }
oidcEnabled := false
if raw, ok := settings[SettingKeyOIDCConnectEnabled]; ok {
oidcEnabled = raw == "true"
} else {
oidcEnabled = s.cfg != nil && s.cfg.OIDC.Enabled
}
oidcProviderName := strings.TrimSpace(settings[SettingKeyOIDCConnectProviderName])
if oidcProviderName == "" && s.cfg != nil {
oidcProviderName = strings.TrimSpace(s.cfg.OIDC.ProviderName)
}
if oidcProviderName == "" {
oidcProviderName = "OIDC"
}
// Password reset requires email verification to be enabled // Password reset requires email verification to be enabled
emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true" emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
...@@ -218,6 +234,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ...@@ -218,6 +234,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
CustomEndpoints: settings[SettingKeyCustomEndpoints], CustomEndpoints: settings[SettingKeyCustomEndpoints],
LinuxDoOAuthEnabled: linuxDoEnabled, LinuxDoOAuthEnabled: linuxDoEnabled,
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
OIDCOAuthEnabled: oidcEnabled,
OIDCOAuthProviderName: oidcProviderName,
}, nil }, nil
} }
...@@ -267,6 +285,8 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any ...@@ -267,6 +285,8 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomEndpoints json.RawMessage `json:"custom_endpoints"` CustomEndpoints json.RawMessage `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"` BackendModeEnabled bool `json:"backend_mode_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
Version string `json:"version,omitempty"` Version string `json:"version,omitempty"`
}{ }{
RegistrationEnabled: settings.RegistrationEnabled, RegistrationEnabled: settings.RegistrationEnabled,
...@@ -294,6 +314,8 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any ...@@ -294,6 +314,8 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints), CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
BackendModeEnabled: settings.BackendModeEnabled, BackendModeEnabled: settings.BackendModeEnabled,
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
Version: s.version, Version: s.version,
}, nil }, nil
} }
...@@ -346,8 +368,8 @@ func safeRawJSONArray(raw string) json.RawMessage { ...@@ -346,8 +368,8 @@ func safeRawJSONArray(raw string) json.RawMessage {
return json.RawMessage("[]") return json.RawMessage("[]")
} }
// GetFrameSrcOrigins returns deduplicated http(s) origins from purchase_subscription_url // GetFrameSrcOrigins returns deduplicated http(s) origins from home_content URL,
// and all custom_menu_items URLs. Used by the router layer for CSP frame-src injection. // purchase_subscription_url, and all custom_menu_items URLs. Used by the router layer for CSP frame-src injection.
func (s *SettingService) GetFrameSrcOrigins(ctx context.Context) ([]string, error) { func (s *SettingService) GetFrameSrcOrigins(ctx context.Context) ([]string, error) {
settings, err := s.GetPublicSettings(ctx) settings, err := s.GetPublicSettings(ctx)
if err != nil { if err != nil {
...@@ -366,6 +388,9 @@ func (s *SettingService) GetFrameSrcOrigins(ctx context.Context) ([]string, erro ...@@ -366,6 +388,9 @@ func (s *SettingService) GetFrameSrcOrigins(ctx context.Context) ([]string, erro
} }
} }
// home content URL (when home_content is set to a URL for iframe embedding)
addOrigin(settings.HomeContent)
// purchase subscription URL // purchase subscription URL
if settings.PurchaseSubscriptionEnabled { if settings.PurchaseSubscriptionEnabled {
addOrigin(settings.PurchaseSubscriptionURL) addOrigin(settings.PurchaseSubscriptionURL)
...@@ -473,6 +498,32 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet ...@@ -473,6 +498,32 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyLinuxDoConnectClientSecret] = settings.LinuxDoConnectClientSecret updates[SettingKeyLinuxDoConnectClientSecret] = settings.LinuxDoConnectClientSecret
} }
// Generic OIDC OAuth 登录
updates[SettingKeyOIDCConnectEnabled] = strconv.FormatBool(settings.OIDCConnectEnabled)
updates[SettingKeyOIDCConnectProviderName] = settings.OIDCConnectProviderName
updates[SettingKeyOIDCConnectClientID] = settings.OIDCConnectClientID
updates[SettingKeyOIDCConnectIssuerURL] = settings.OIDCConnectIssuerURL
updates[SettingKeyOIDCConnectDiscoveryURL] = settings.OIDCConnectDiscoveryURL
updates[SettingKeyOIDCConnectAuthorizeURL] = settings.OIDCConnectAuthorizeURL
updates[SettingKeyOIDCConnectTokenURL] = settings.OIDCConnectTokenURL
updates[SettingKeyOIDCConnectUserInfoURL] = settings.OIDCConnectUserInfoURL
updates[SettingKeyOIDCConnectJWKSURL] = settings.OIDCConnectJWKSURL
updates[SettingKeyOIDCConnectScopes] = settings.OIDCConnectScopes
updates[SettingKeyOIDCConnectRedirectURL] = settings.OIDCConnectRedirectURL
updates[SettingKeyOIDCConnectFrontendRedirectURL] = settings.OIDCConnectFrontendRedirectURL
updates[SettingKeyOIDCConnectTokenAuthMethod] = settings.OIDCConnectTokenAuthMethod
updates[SettingKeyOIDCConnectUsePKCE] = strconv.FormatBool(settings.OIDCConnectUsePKCE)
updates[SettingKeyOIDCConnectValidateIDToken] = strconv.FormatBool(settings.OIDCConnectValidateIDToken)
updates[SettingKeyOIDCConnectAllowedSigningAlgs] = settings.OIDCConnectAllowedSigningAlgs
updates[SettingKeyOIDCConnectClockSkewSeconds] = strconv.Itoa(settings.OIDCConnectClockSkewSeconds)
updates[SettingKeyOIDCConnectRequireEmailVerified] = strconv.FormatBool(settings.OIDCConnectRequireEmailVerified)
updates[SettingKeyOIDCConnectUserInfoEmailPath] = settings.OIDCConnectUserInfoEmailPath
updates[SettingKeyOIDCConnectUserInfoIDPath] = settings.OIDCConnectUserInfoIDPath
updates[SettingKeyOIDCConnectUserInfoUsernamePath] = settings.OIDCConnectUserInfoUsernamePath
if settings.OIDCConnectClientSecret != "" {
updates[SettingKeyOIDCConnectClientSecret] = settings.OIDCConnectClientSecret
}
// OEM设置 // OEM设置
updates[SettingKeySiteName] = settings.SiteName updates[SettingKeySiteName] = settings.SiteName
updates[SettingKeySiteLogo] = settings.SiteLogo updates[SettingKeySiteLogo] = settings.SiteLogo
...@@ -851,6 +902,8 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { ...@@ -851,6 +902,8 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyTablePageSizeOptions: "[10,20,50,100]", SettingKeyTablePageSizeOptions: "[10,20,50,100]",
SettingKeyCustomMenuItems: "[]", SettingKeyCustomMenuItems: "[]",
SettingKeyCustomEndpoints: "[]", SettingKeyCustomEndpoints: "[]",
SettingKeyOIDCConnectEnabled: "false",
SettingKeyOIDCConnectProviderName: "OIDC",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeyDefaultSubscriptions: "[]", SettingKeyDefaultSubscriptions: "[]",
...@@ -980,6 +1033,138 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin ...@@ -980,6 +1033,138 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
} }
result.LinuxDoConnectClientSecretConfigured = result.LinuxDoConnectClientSecret != "" result.LinuxDoConnectClientSecretConfigured = result.LinuxDoConnectClientSecret != ""
// Generic OIDC 设置:
// - 兼容 config.yaml/env
// - 支持后台系统设置覆盖并持久化(存储于 DB)
oidcBase := config.OIDCConnectConfig{}
if s.cfg != nil {
oidcBase = s.cfg.OIDC
}
if raw, ok := settings[SettingKeyOIDCConnectEnabled]; ok {
result.OIDCConnectEnabled = raw == "true"
} else {
result.OIDCConnectEnabled = oidcBase.Enabled
}
if v, ok := settings[SettingKeyOIDCConnectProviderName]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectProviderName = strings.TrimSpace(v)
} else {
result.OIDCConnectProviderName = strings.TrimSpace(oidcBase.ProviderName)
}
if result.OIDCConnectProviderName == "" {
result.OIDCConnectProviderName = "OIDC"
}
if v, ok := settings[SettingKeyOIDCConnectClientID]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectClientID = strings.TrimSpace(v)
} else {
result.OIDCConnectClientID = strings.TrimSpace(oidcBase.ClientID)
}
if v, ok := settings[SettingKeyOIDCConnectIssuerURL]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectIssuerURL = strings.TrimSpace(v)
} else {
result.OIDCConnectIssuerURL = strings.TrimSpace(oidcBase.IssuerURL)
}
if v, ok := settings[SettingKeyOIDCConnectDiscoveryURL]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectDiscoveryURL = strings.TrimSpace(v)
} else {
result.OIDCConnectDiscoveryURL = strings.TrimSpace(oidcBase.DiscoveryURL)
}
if v, ok := settings[SettingKeyOIDCConnectAuthorizeURL]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectAuthorizeURL = strings.TrimSpace(v)
} else {
result.OIDCConnectAuthorizeURL = strings.TrimSpace(oidcBase.AuthorizeURL)
}
if v, ok := settings[SettingKeyOIDCConnectTokenURL]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectTokenURL = strings.TrimSpace(v)
} else {
result.OIDCConnectTokenURL = strings.TrimSpace(oidcBase.TokenURL)
}
if v, ok := settings[SettingKeyOIDCConnectUserInfoURL]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectUserInfoURL = strings.TrimSpace(v)
} else {
result.OIDCConnectUserInfoURL = strings.TrimSpace(oidcBase.UserInfoURL)
}
if v, ok := settings[SettingKeyOIDCConnectJWKSURL]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectJWKSURL = strings.TrimSpace(v)
} else {
result.OIDCConnectJWKSURL = strings.TrimSpace(oidcBase.JWKSURL)
}
if v, ok := settings[SettingKeyOIDCConnectScopes]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectScopes = strings.TrimSpace(v)
} else {
result.OIDCConnectScopes = strings.TrimSpace(oidcBase.Scopes)
}
if v, ok := settings[SettingKeyOIDCConnectRedirectURL]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectRedirectURL = strings.TrimSpace(v)
} else {
result.OIDCConnectRedirectURL = strings.TrimSpace(oidcBase.RedirectURL)
}
if v, ok := settings[SettingKeyOIDCConnectFrontendRedirectURL]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectFrontendRedirectURL = strings.TrimSpace(v)
} else {
result.OIDCConnectFrontendRedirectURL = strings.TrimSpace(oidcBase.FrontendRedirectURL)
}
if v, ok := settings[SettingKeyOIDCConnectTokenAuthMethod]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectTokenAuthMethod = strings.ToLower(strings.TrimSpace(v))
} else {
result.OIDCConnectTokenAuthMethod = strings.ToLower(strings.TrimSpace(oidcBase.TokenAuthMethod))
}
if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok {
result.OIDCConnectUsePKCE = raw == "true"
} else {
result.OIDCConnectUsePKCE = oidcBase.UsePKCE
}
if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok {
result.OIDCConnectValidateIDToken = raw == "true"
} else {
result.OIDCConnectValidateIDToken = oidcBase.ValidateIDToken
}
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(v)
} else {
result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(oidcBase.AllowedSigningAlgs)
}
clockSkewSet := false
if raw, ok := settings[SettingKeyOIDCConnectClockSkewSeconds]; ok && strings.TrimSpace(raw) != "" {
if parsed, err := strconv.Atoi(strings.TrimSpace(raw)); err == nil {
result.OIDCConnectClockSkewSeconds = parsed
clockSkewSet = true
}
}
if !clockSkewSet {
result.OIDCConnectClockSkewSeconds = oidcBase.ClockSkewSeconds
}
if !clockSkewSet && result.OIDCConnectClockSkewSeconds == 0 {
result.OIDCConnectClockSkewSeconds = 120
}
if raw, ok := settings[SettingKeyOIDCConnectRequireEmailVerified]; ok {
result.OIDCConnectRequireEmailVerified = raw == "true"
} else {
result.OIDCConnectRequireEmailVerified = oidcBase.RequireEmailVerified
}
if v, ok := settings[SettingKeyOIDCConnectUserInfoEmailPath]; ok {
result.OIDCConnectUserInfoEmailPath = strings.TrimSpace(v)
} else {
result.OIDCConnectUserInfoEmailPath = strings.TrimSpace(oidcBase.UserInfoEmailPath)
}
if v, ok := settings[SettingKeyOIDCConnectUserInfoIDPath]; ok {
result.OIDCConnectUserInfoIDPath = strings.TrimSpace(v)
} else {
result.OIDCConnectUserInfoIDPath = strings.TrimSpace(oidcBase.UserInfoIDPath)
}
if v, ok := settings[SettingKeyOIDCConnectUserInfoUsernamePath]; ok {
result.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(v)
} else {
result.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(oidcBase.UserInfoUsernamePath)
}
result.OIDCConnectClientSecret = strings.TrimSpace(settings[SettingKeyOIDCConnectClientSecret])
if result.OIDCConnectClientSecret == "" {
result.OIDCConnectClientSecret = strings.TrimSpace(oidcBase.ClientSecret)
}
result.OIDCConnectClientSecretConfigured = result.OIDCConnectClientSecret != ""
// Model fallback settings // Model fallback settings
result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true" result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true"
result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022") result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022")
...@@ -1396,6 +1581,282 @@ func (s *SettingService) SetOverloadCooldownSettings(ctx context.Context, settin ...@@ -1396,6 +1581,282 @@ func (s *SettingService) SetOverloadCooldownSettings(ctx context.Context, settin
return s.settingRepo.Set(ctx, SettingKeyOverloadCooldownSettings, string(data)) return s.settingRepo.Set(ctx, SettingKeyOverloadCooldownSettings, string(data))
} }
// GetOIDCConnectOAuthConfig 返回用于登录的“最终生效” OIDC 配置。
//
// 优先级:
// - 若对应系统设置键存在,则覆盖 config.yaml/env 的值
// - 否则回退到 config.yaml/env 的值
func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config.OIDCConnectConfig, error) {
if s == nil || s.cfg == nil {
return config.OIDCConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded")
}
effective := s.cfg.OIDC
keys := []string{
SettingKeyOIDCConnectEnabled,
SettingKeyOIDCConnectProviderName,
SettingKeyOIDCConnectClientID,
SettingKeyOIDCConnectClientSecret,
SettingKeyOIDCConnectIssuerURL,
SettingKeyOIDCConnectDiscoveryURL,
SettingKeyOIDCConnectAuthorizeURL,
SettingKeyOIDCConnectTokenURL,
SettingKeyOIDCConnectUserInfoURL,
SettingKeyOIDCConnectJWKSURL,
SettingKeyOIDCConnectScopes,
SettingKeyOIDCConnectRedirectURL,
SettingKeyOIDCConnectFrontendRedirectURL,
SettingKeyOIDCConnectTokenAuthMethod,
SettingKeyOIDCConnectUsePKCE,
SettingKeyOIDCConnectValidateIDToken,
SettingKeyOIDCConnectAllowedSigningAlgs,
SettingKeyOIDCConnectClockSkewSeconds,
SettingKeyOIDCConnectRequireEmailVerified,
SettingKeyOIDCConnectUserInfoEmailPath,
SettingKeyOIDCConnectUserInfoIDPath,
SettingKeyOIDCConnectUserInfoUsernamePath,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
return config.OIDCConnectConfig{}, fmt.Errorf("get oidc connect settings: %w", err)
}
if raw, ok := settings[SettingKeyOIDCConnectEnabled]; ok {
effective.Enabled = raw == "true"
}
if v, ok := settings[SettingKeyOIDCConnectProviderName]; ok && strings.TrimSpace(v) != "" {
effective.ProviderName = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectClientID]; ok && strings.TrimSpace(v) != "" {
effective.ClientID = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectClientSecret]; ok && strings.TrimSpace(v) != "" {
effective.ClientSecret = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectIssuerURL]; ok && strings.TrimSpace(v) != "" {
effective.IssuerURL = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectDiscoveryURL]; ok && strings.TrimSpace(v) != "" {
effective.DiscoveryURL = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectAuthorizeURL]; ok && strings.TrimSpace(v) != "" {
effective.AuthorizeURL = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectTokenURL]; ok && strings.TrimSpace(v) != "" {
effective.TokenURL = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectUserInfoURL]; ok && strings.TrimSpace(v) != "" {
effective.UserInfoURL = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectJWKSURL]; ok && strings.TrimSpace(v) != "" {
effective.JWKSURL = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectScopes]; ok && strings.TrimSpace(v) != "" {
effective.Scopes = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectRedirectURL]; ok && strings.TrimSpace(v) != "" {
effective.RedirectURL = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectFrontendRedirectURL]; ok && strings.TrimSpace(v) != "" {
effective.FrontendRedirectURL = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectTokenAuthMethod]; ok && strings.TrimSpace(v) != "" {
effective.TokenAuthMethod = strings.ToLower(strings.TrimSpace(v))
}
if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok {
effective.UsePKCE = raw == "true"
}
if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok {
effective.ValidateIDToken = raw == "true"
}
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
effective.AllowedSigningAlgs = strings.TrimSpace(v)
}
if raw, ok := settings[SettingKeyOIDCConnectClockSkewSeconds]; ok && strings.TrimSpace(raw) != "" {
if parsed, parseErr := strconv.Atoi(strings.TrimSpace(raw)); parseErr == nil {
effective.ClockSkewSeconds = parsed
}
}
if raw, ok := settings[SettingKeyOIDCConnectRequireEmailVerified]; ok {
effective.RequireEmailVerified = raw == "true"
}
if v, ok := settings[SettingKeyOIDCConnectUserInfoEmailPath]; ok {
effective.UserInfoEmailPath = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectUserInfoIDPath]; ok {
effective.UserInfoIDPath = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectUserInfoUsernamePath]; ok {
effective.UserInfoUsernamePath = strings.TrimSpace(v)
}
if !effective.Enabled {
return config.OIDCConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
}
if strings.TrimSpace(effective.ProviderName) == "" {
effective.ProviderName = "OIDC"
}
if strings.TrimSpace(effective.ClientID) == "" {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client id not configured")
}
if strings.TrimSpace(effective.IssuerURL) == "" {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth issuer url not configured")
}
if strings.TrimSpace(effective.RedirectURL) == "" {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured")
}
if strings.TrimSpace(effective.FrontendRedirectURL) == "" {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth frontend redirect url not configured")
}
if !scopesContainOpenID(effective.Scopes) {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth scopes must contain openid")
}
if effective.ClockSkewSeconds < 0 || effective.ClockSkewSeconds > 600 {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth clock skew must be between 0 and 600")
}
if err := config.ValidateAbsoluteHTTPURL(effective.IssuerURL); err != nil {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth issuer url invalid")
}
discoveryURL := strings.TrimSpace(effective.DiscoveryURL)
if discoveryURL == "" {
discoveryURL = oidcDefaultDiscoveryURL(effective.IssuerURL)
effective.DiscoveryURL = discoveryURL
}
if discoveryURL != "" {
if err := config.ValidateAbsoluteHTTPURL(discoveryURL); err != nil {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth discovery url invalid")
}
}
needsDiscovery := strings.TrimSpace(effective.AuthorizeURL) == "" ||
strings.TrimSpace(effective.TokenURL) == "" ||
(effective.ValidateIDToken && strings.TrimSpace(effective.JWKSURL) == "")
if needsDiscovery && discoveryURL != "" {
metadata, resolveErr := oidcResolveProviderMetadata(ctx, discoveryURL)
if resolveErr != nil {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth discovery resolve failed").WithCause(resolveErr)
}
if strings.TrimSpace(effective.AuthorizeURL) == "" {
effective.AuthorizeURL = strings.TrimSpace(metadata.AuthorizationEndpoint)
}
if strings.TrimSpace(effective.TokenURL) == "" {
effective.TokenURL = strings.TrimSpace(metadata.TokenEndpoint)
}
if strings.TrimSpace(effective.UserInfoURL) == "" {
effective.UserInfoURL = strings.TrimSpace(metadata.UserInfoEndpoint)
}
if strings.TrimSpace(effective.JWKSURL) == "" {
effective.JWKSURL = strings.TrimSpace(metadata.JWKSURI)
}
}
if strings.TrimSpace(effective.AuthorizeURL) == "" {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth authorize url not configured")
}
if strings.TrimSpace(effective.TokenURL) == "" {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token url not configured")
}
if err := config.ValidateAbsoluteHTTPURL(effective.AuthorizeURL); err != nil {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth authorize url invalid")
}
if err := config.ValidateAbsoluteHTTPURL(effective.TokenURL); err != nil {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token url invalid")
}
if v := strings.TrimSpace(effective.UserInfoURL); v != "" {
if err := config.ValidateAbsoluteHTTPURL(v); err != nil {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth userinfo url invalid")
}
}
if effective.ValidateIDToken {
if strings.TrimSpace(effective.JWKSURL) == "" {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth jwks url not configured")
}
if strings.TrimSpace(effective.AllowedSigningAlgs) == "" {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth signing algs not configured")
}
}
if v := strings.TrimSpace(effective.JWKSURL); v != "" {
if err := config.ValidateAbsoluteHTTPURL(v); err != nil {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth jwks url invalid")
}
}
if err := config.ValidateAbsoluteHTTPURL(effective.RedirectURL); err != nil {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url invalid")
}
if err := config.ValidateFrontendRedirectURL(effective.FrontendRedirectURL); err != nil {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth frontend redirect url invalid")
}
method := strings.ToLower(strings.TrimSpace(effective.TokenAuthMethod))
switch method {
case "", "client_secret_post", "client_secret_basic":
if strings.TrimSpace(effective.ClientSecret) == "" {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured")
}
case "none":
if !effective.UsePKCE {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none")
}
default:
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid")
}
return effective, nil
}
func scopesContainOpenID(scopes string) bool {
for _, scope := range strings.Fields(strings.ToLower(strings.TrimSpace(scopes))) {
if scope == "openid" {
return true
}
}
return false
}
type oidcProviderMetadata struct {
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
UserInfoEndpoint string `json:"userinfo_endpoint"`
JWKSURI string `json:"jwks_uri"`
}
func oidcDefaultDiscoveryURL(issuerURL string) string {
issuerURL = strings.TrimSpace(issuerURL)
if issuerURL == "" {
return ""
}
return strings.TrimRight(issuerURL, "/") + "/.well-known/openid-configuration"
}
func oidcResolveProviderMetadata(ctx context.Context, discoveryURL string) (*oidcProviderMetadata, error) {
discoveryURL = strings.TrimSpace(discoveryURL)
if discoveryURL == "" {
return nil, fmt.Errorf("discovery url is empty")
}
resp, err := req.C().
SetTimeout(15*time.Second).
R().
SetContext(ctx).
SetHeader("Accept", "application/json").
Get(discoveryURL)
if err != nil {
return nil, fmt.Errorf("request discovery document: %w", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("discovery request failed: status=%d", resp.StatusCode)
}
metadata := &oidcProviderMetadata{}
if err := json.Unmarshal(resp.Bytes(), metadata); err != nil {
return nil, fmt.Errorf("parse discovery document: %w", err)
}
return metadata, nil
}
// GetStreamTimeoutSettings 获取流超时处理配置 // GetStreamTimeoutSettings 获取流超时处理配置
func (s *SettingService) GetStreamTimeoutSettings(ctx context.Context) (*StreamTimeoutSettings, error) { func (s *SettingService) GetStreamTimeoutSettings(ctx context.Context) (*StreamTimeoutSettings, error) {
value, err := s.settingRepo.GetValue(ctx, SettingKeyStreamTimeoutSettings) value, err := s.settingRepo.GetValue(ctx, SettingKeyStreamTimeoutSettings)
......
//go:build unit
package service
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type settingOIDCRepoStub struct {
values map[string]string
}
func (s *settingOIDCRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
panic("unexpected Get call")
}
func (s *settingOIDCRepoStub) GetValue(ctx context.Context, key string) (string, error) {
panic("unexpected GetValue call")
}
func (s *settingOIDCRepoStub) Set(ctx context.Context, key, value string) error {
panic("unexpected Set call")
}
func (s *settingOIDCRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, key := range keys {
if value, ok := s.values[key]; ok {
out[key] = value
}
}
return out, nil
}
func (s *settingOIDCRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
panic("unexpected SetMultiple call")
}
func (s *settingOIDCRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *settingOIDCRepoStub) Delete(ctx context.Context, key string) error {
panic("unexpected Delete call")
}
func TestGetOIDCConnectOAuthConfig_ResolvesEndpointsFromIssuerDiscovery(t *testing.T) {
var discoveryHits int
var baseURL string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/issuer/.well-known/openid-configuration" {
http.NotFound(w, r)
return
}
discoveryHits++
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(fmt.Sprintf(`{
"authorization_endpoint":"%s/issuer/protocol/openid-connect/auth",
"token_endpoint":"%s/issuer/protocol/openid-connect/token",
"userinfo_endpoint":"%s/issuer/protocol/openid-connect/userinfo",
"jwks_uri":"%s/issuer/protocol/openid-connect/certs"
}`, baseURL, baseURL, baseURL, baseURL)))
}))
defer srv.Close()
baseURL = srv.URL
cfg := &config.Config{
OIDC: config.OIDCConnectConfig{
Enabled: true,
ProviderName: "OIDC",
ClientID: "oidc-client",
ClientSecret: "oidc-secret",
IssuerURL: srv.URL + "/issuer",
RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
FrontendRedirectURL: "/auth/oidc/callback",
Scopes: "openid email profile",
TokenAuthMethod: "client_secret_post",
ValidateIDToken: true,
AllowedSigningAlgs: "RS256",
ClockSkewSeconds: 120,
},
}
repo := &settingOIDCRepoStub{values: map[string]string{}}
svc := NewSettingService(repo, cfg)
got, err := svc.GetOIDCConnectOAuthConfig(context.Background())
require.NoError(t, err)
require.Equal(t, 1, discoveryHits)
require.Equal(t, srv.URL+"/issuer/.well-known/openid-configuration", got.DiscoveryURL)
require.Equal(t, srv.URL+"/issuer/protocol/openid-connect/auth", got.AuthorizeURL)
require.Equal(t, srv.URL+"/issuer/protocol/openid-connect/token", got.TokenURL)
require.Equal(t, srv.URL+"/issuer/protocol/openid-connect/userinfo", got.UserInfoURL)
require.Equal(t, srv.URL+"/issuer/protocol/openid-connect/certs", got.JWKSURL)
}
...@@ -31,6 +31,31 @@ type SystemSettings struct { ...@@ -31,6 +31,31 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured bool LinuxDoConnectClientSecretConfigured bool
LinuxDoConnectRedirectURL string LinuxDoConnectRedirectURL string
// Generic OIDC OAuth 登录
OIDCConnectEnabled bool
OIDCConnectProviderName string
OIDCConnectClientID string
OIDCConnectClientSecret string
OIDCConnectClientSecretConfigured bool
OIDCConnectIssuerURL string
OIDCConnectDiscoveryURL string
OIDCConnectAuthorizeURL string
OIDCConnectTokenURL string
OIDCConnectUserInfoURL string
OIDCConnectJWKSURL string
OIDCConnectScopes string
OIDCConnectRedirectURL string
OIDCConnectFrontendRedirectURL string
OIDCConnectTokenAuthMethod string
OIDCConnectUsePKCE bool
OIDCConnectValidateIDToken bool
OIDCConnectAllowedSigningAlgs string
OIDCConnectClockSkewSeconds int
OIDCConnectRequireEmailVerified bool
OIDCConnectUserInfoEmailPath string
OIDCConnectUserInfoIDPath string
OIDCConnectUserInfoUsernamePath string
SiteName string SiteName string
SiteLogo string SiteLogo string
SiteSubtitle string SiteSubtitle string
...@@ -114,9 +139,11 @@ type PublicSettings struct { ...@@ -114,9 +139,11 @@ type PublicSettings struct {
CustomMenuItems string // JSON array of custom menu items CustomMenuItems string // JSON array of custom menu items
CustomEndpoints string // JSON array of custom endpoints CustomEndpoints string // JSON array of custom endpoints
LinuxDoOAuthEnabled bool LinuxDoOAuthEnabled bool
BackendModeEnabled bool BackendModeEnabled bool
Version string OIDCOAuthEnabled bool
OIDCOAuthProviderName string
Version string
} }
// StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制) // StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment