Commit 2b70d1d3 authored by IanShaw027's avatar IanShaw027
Browse files

merge upstream main into fix/bug-cleanup-main

parents b37afd68 00c08c57
......@@ -54,6 +54,8 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
BackendModeEnabled: settings.BackendModeEnabled,
Version: h.version,
})
......
......@@ -28,7 +28,7 @@ type AnthropicRequest struct {
// AnthropicOutputConfig controls output generation parameters.
type AnthropicOutputConfig struct {
Effort string `json:"effort,omitempty"` // "low" | "medium" | "high"
Effort string `json:"effort,omitempty"` // "low" | "medium" | "high" | "max"
}
// AnthropicThinking configures extended thinking in the Anthropic API.
......@@ -167,7 +167,7 @@ type ResponsesRequest struct {
// ResponsesReasoning configures reasoning effort in the Responses API.
type ResponsesReasoning struct {
Effort string `json:"effort"` // "low" | "medium" | "high"
Effort string `json:"effort"` // "low" | "medium" | "high" | "xhigh"
Summary string `json:"summary,omitempty"` // "auto" | "concise" | "detailed"
}
......@@ -345,7 +345,7 @@ type ChatCompletionsRequest struct {
StreamOptions *ChatStreamOptions `json:"stream_options,omitempty"`
Tools []ChatTool `json:"tools,omitempty"`
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"` // "low" | "medium" | "high"
ReasoningEffort string `json:"reasoning_effort,omitempty"` // "low" | "medium" | "high" | "xhigh"
ServiceTier string `json:"service_tier,omitempty"`
Stop json.RawMessage `json:"stop,omitempty"` // string or []string
......
......@@ -61,7 +61,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
SetRequirePrivacySet(groupIn.RequirePrivacySet).
SetDefaultMappedModel(groupIn.DefaultMappedModel)
SetDefaultMappedModel(groupIn.DefaultMappedModel).
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig)
// 设置模型路由配置
if groupIn.ModelRouting != nil {
......@@ -127,7 +128,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
SetRequirePrivacySet(groupIn.RequirePrivacySet).
SetDefaultMappedModel(groupIn.DefaultMappedModel)
SetDefaultMappedModel(groupIn.DefaultMappedModel).
SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig)
// 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
if groupIn.DailyLimitUSD != nil {
......
......@@ -378,6 +378,7 @@ func buildSchedulerMetadataAccount(account service.Account) service.Account {
Platform: account.Platform,
Type: account.Type,
Concurrency: account.Concurrency,
LoadFactor: account.LoadFactor,
Priority: account.Priority,
RateMultiplier: account.RateMultiplier,
Status: account.Status,
......
......@@ -462,6 +462,28 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyTurnstileSiteKey: "site-key",
service.SettingKeyTurnstileSecretKey: "secret-key",
service.SettingKeyOIDCConnectEnabled: "false",
service.SettingKeyOIDCConnectProviderName: "OIDC",
service.SettingKeyOIDCConnectClientID: "",
service.SettingKeyOIDCConnectIssuerURL: "",
service.SettingKeyOIDCConnectDiscoveryURL: "",
service.SettingKeyOIDCConnectAuthorizeURL: "",
service.SettingKeyOIDCConnectTokenURL: "",
service.SettingKeyOIDCConnectUserInfoURL: "",
service.SettingKeyOIDCConnectJWKSURL: "",
service.SettingKeyOIDCConnectScopes: "openid email profile",
service.SettingKeyOIDCConnectRedirectURL: "",
service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
service.SettingKeyOIDCConnectUsePKCE: "false",
service.SettingKeyOIDCConnectValidateIDToken: "true",
service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256",
service.SettingKeyOIDCConnectClockSkewSeconds: "120",
service.SettingKeyOIDCConnectRequireEmailVerified: "false",
service.SettingKeyOIDCConnectUserInfoEmailPath: "",
service.SettingKeyOIDCConnectUserInfoIDPath: "",
service.SettingKeyOIDCConnectUserInfoUsernamePath: "",
service.SettingKeySiteName: "Sub2API",
service.SettingKeySiteLogo: "",
service.SettingKeySiteSubtitle: "Subtitle",
......@@ -503,10 +525,32 @@ func TestAPIContracts(t *testing.T) {
"turnstile_enabled": true,
"turnstile_site_key": "site-key",
"turnstile_secret_key_configured": true,
"linuxdo_connect_enabled": false,
"linuxdo_connect_enabled": false,
"linuxdo_connect_client_id": "",
"linuxdo_connect_client_secret_configured": false,
"linuxdo_connect_redirect_url": "",
"oidc_connect_enabled": false,
"oidc_connect_provider_name": "OIDC",
"oidc_connect_client_id": "",
"oidc_connect_client_secret_configured": false,
"oidc_connect_issuer_url": "",
"oidc_connect_discovery_url": "",
"oidc_connect_authorize_url": "",
"oidc_connect_token_url": "",
"oidc_connect_userinfo_url": "",
"oidc_connect_jwks_url": "",
"oidc_connect_scopes": "openid email profile",
"oidc_connect_redirect_url": "",
"oidc_connect_frontend_redirect_url": "/auth/oidc/callback",
"oidc_connect_token_auth_method": "client_secret_post",
"oidc_connect_use_pkce": false,
"oidc_connect_validate_id_token": true,
"oidc_connect_allowed_signing_algs": "RS256,ES256,PS256",
"oidc_connect_clock_skew_seconds": 120,
"oidc_connect_require_email_verified": false,
"oidc_connect_userinfo_email_path": "",
"oidc_connect_userinfo_id_path": "",
"oidc_connect_userinfo_username_path": "",
"ops_monitoring_enabled": false,
"ops_realtime_monitoring_enabled": true,
"ops_query_mode_default": "auto",
......
......@@ -70,6 +70,14 @@ func RegisterAuthRoutes(
}),
h.Auth.CompleteLinuxDoOAuthRegistration,
)
auth.GET("/oauth/oidc/start", h.Auth.OIDCOAuthStart)
auth.GET("/oauth/oidc/callback", h.Auth.OIDCOAuthCallback)
auth.POST("/oauth/oidc/complete-registration",
rateLimiter.LimitWithOptions("oauth-oidc-complete", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CompleteOIDCOAuthRegistration,
)
}
// 公开设置(无需认证)
......
......@@ -152,10 +152,11 @@ type CreateGroupInput struct {
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch bool
DefaultMappedModel string
RequireOAuthOnly bool
RequirePrivacySet bool
AllowMessagesDispatch bool
DefaultMappedModel string
RequireOAuthOnly bool
RequirePrivacySet bool
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig
// 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs []int64
}
......@@ -186,10 +187,11 @@ type UpdateGroupInput struct {
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch *bool
DefaultMappedModel *string
RequireOAuthOnly *bool
RequirePrivacySet *bool
AllowMessagesDispatch *bool
DefaultMappedModel *string
RequireOAuthOnly *bool
RequirePrivacySet *bool
MessagesDispatchModelConfig *OpenAIMessagesDispatchModelConfig
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64
}
......@@ -908,7 +910,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
RequireOAuthOnly: input.RequireOAuthOnly,
RequirePrivacySet: input.RequirePrivacySet,
DefaultMappedModel: input.DefaultMappedModel,
MessagesDispatchModelConfig: normalizeOpenAIMessagesDispatchModelConfig(input.MessagesDispatchModelConfig),
}
sanitizeGroupMessagesDispatchFields(group)
if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err
}
......@@ -1135,6 +1139,10 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.DefaultMappedModel != nil {
group.DefaultMappedModel = *input.DefaultMappedModel
}
if input.MessagesDispatchModelConfig != nil {
group.MessagesDispatchModelConfig = normalizeOpenAIMessagesDispatchModelConfig(*input.MessagesDispatchModelConfig)
}
sanitizeGroupMessagesDispatchFields(group)
if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err
......
......@@ -10,6 +10,11 @@ import (
"github.com/stretchr/testify/require"
)
func ptrString[T ~string](v T) *string {
s := string(v)
return &s
}
// groupRepoStubForAdmin 用于测试 AdminService 的 GroupRepository Stub
type groupRepoStubForAdmin struct {
created *Group // 记录 Create 调用的参数
......@@ -261,6 +266,116 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
require.Nil(t, repo.updated.ImagePrice4K)
}
func TestAdminService_CreateGroup_NormalizesMessagesDispatchModelConfig(t *testing.T) {
repo := &groupRepoStubForAdmin{}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "dispatch-group",
Description: "dispatch config",
Platform: PlatformOpenAI,
RateMultiplier: 1.0,
MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{
OpusMappedModel: " gpt-5.4-high ",
SonnetMappedModel: " gpt-5.3-codex ",
HaikuMappedModel: " gpt-5.4-mini-medium ",
ExactModelMappings: map[string]string{
" claude-sonnet-4-5-20250929 ": " gpt-5.2-high ",
},
},
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.created)
require.Equal(t, OpenAIMessagesDispatchModelConfig{
OpusMappedModel: "gpt-5.4",
SonnetMappedModel: "gpt-5.3-codex",
HaikuMappedModel: "gpt-5.4-mini",
ExactModelMappings: map[string]string{
"claude-sonnet-4-5-20250929": "gpt-5.2",
},
}, repo.created.MessagesDispatchModelConfig)
}
func TestAdminService_UpdateGroup_NormalizesMessagesDispatchModelConfig(t *testing.T) {
existingGroup := &Group{
ID: 1,
Name: "existing-group",
Platform: PlatformOpenAI,
Status: StatusActive,
}
repo := &groupRepoStubForAdmin{getByID: existingGroup}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
MessagesDispatchModelConfig: &OpenAIMessagesDispatchModelConfig{
SonnetMappedModel: " gpt-5.4-medium ",
ExactModelMappings: map[string]string{
" claude-haiku-4-5-20251001 ": " gpt-5.4-mini-high ",
},
},
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.updated)
require.Equal(t, OpenAIMessagesDispatchModelConfig{
SonnetMappedModel: "gpt-5.4",
ExactModelMappings: map[string]string{
"claude-haiku-4-5-20251001": "gpt-5.4-mini",
},
}, repo.updated.MessagesDispatchModelConfig)
}
func TestAdminService_CreateGroup_ClearsMessagesDispatchFieldsForNonOpenAIPlatform(t *testing.T) {
repo := &groupRepoStubForAdmin{}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "anthropic-group",
Description: "non-openai",
Platform: PlatformAnthropic,
RateMultiplier: 1.0,
AllowMessagesDispatch: true,
DefaultMappedModel: "gpt-5.4",
MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{
OpusMappedModel: "gpt-5.4",
},
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.created)
require.False(t, repo.created.AllowMessagesDispatch)
require.Empty(t, repo.created.DefaultMappedModel)
require.Equal(t, OpenAIMessagesDispatchModelConfig{}, repo.created.MessagesDispatchModelConfig)
}
func TestAdminService_UpdateGroup_ClearsMessagesDispatchFieldsWhenPlatformChangesAwayFromOpenAI(t *testing.T) {
existingGroup := &Group{
ID: 1,
Name: "existing-openai-group",
Platform: PlatformOpenAI,
Status: StatusActive,
AllowMessagesDispatch: true,
DefaultMappedModel: "gpt-5.4",
MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{
SonnetMappedModel: "gpt-5.3-codex",
},
}
repo := &groupRepoStubForAdmin{getByID: existingGroup}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
Platform: PlatformAnthropic,
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.updated)
require.Equal(t, PlatformAnthropic, repo.updated.Platform)
require.False(t, repo.updated.AllowMessagesDispatch)
require.Empty(t, repo.updated.DefaultMappedModel)
require.Equal(t, OpenAIMessagesDispatchModelConfig{}, repo.updated.MessagesDispatchModelConfig)
}
func TestAdminService_ListGroups_WithSearch(t *testing.T) {
// 测试:
// 1. search 参数正常传递到 repository 层
......
......@@ -833,7 +833,8 @@ func randomHexString(byteLength int) (string, error) {
func isReservedEmail(email string) bool {
normalized := strings.ToLower(strings.TrimSpace(email))
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain)
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) ||
strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain)
}
// GenerateToken 生成JWT access token
......
......@@ -71,6 +71,9 @@ const (
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。
const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// OIDCConnectSyntheticEmailDomain 是 OIDC 用户的合成邮箱后缀(RFC 保留域名)。
const OIDCConnectSyntheticEmailDomain = "@oidc-connect.invalid"
// Setting keys
const (
// 注册设置
......@@ -105,6 +108,30 @@ const (
SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret"
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
// Generic OIDC OAuth 登录设置
SettingKeyOIDCConnectEnabled = "oidc_connect_enabled"
SettingKeyOIDCConnectProviderName = "oidc_connect_provider_name"
SettingKeyOIDCConnectClientID = "oidc_connect_client_id"
SettingKeyOIDCConnectClientSecret = "oidc_connect_client_secret"
SettingKeyOIDCConnectIssuerURL = "oidc_connect_issuer_url"
SettingKeyOIDCConnectDiscoveryURL = "oidc_connect_discovery_url"
SettingKeyOIDCConnectAuthorizeURL = "oidc_connect_authorize_url"
SettingKeyOIDCConnectTokenURL = "oidc_connect_token_url"
SettingKeyOIDCConnectUserInfoURL = "oidc_connect_userinfo_url"
SettingKeyOIDCConnectJWKSURL = "oidc_connect_jwks_url"
SettingKeyOIDCConnectScopes = "oidc_connect_scopes"
SettingKeyOIDCConnectRedirectURL = "oidc_connect_redirect_url"
SettingKeyOIDCConnectFrontendRedirectURL = "oidc_connect_frontend_redirect_url"
SettingKeyOIDCConnectTokenAuthMethod = "oidc_connect_token_auth_method"
SettingKeyOIDCConnectUsePKCE = "oidc_connect_use_pkce"
SettingKeyOIDCConnectValidateIDToken = "oidc_connect_validate_id_token"
SettingKeyOIDCConnectAllowedSigningAlgs = "oidc_connect_allowed_signing_algs"
SettingKeyOIDCConnectClockSkewSeconds = "oidc_connect_clock_skew_seconds"
SettingKeyOIDCConnectRequireEmailVerified = "oidc_connect_require_email_verified"
SettingKeyOIDCConnectUserInfoEmailPath = "oidc_connect_userinfo_email_path"
SettingKeyOIDCConnectUserInfoIDPath = "oidc_connect_userinfo_id_path"
SettingKeyOIDCConnectUserInfoUsernamePath = "oidc_connect_userinfo_username_path"
// OEM设置
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
......
......@@ -3,8 +3,12 @@ package service
import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
)
type OpenAIMessagesDispatchModelConfig = domain.OpenAIMessagesDispatchModelConfig
type Group struct {
ID int64
Name string
......@@ -49,10 +53,11 @@ type Group struct {
SortOrder int
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch bool
RequireOAuthOnly bool // 仅允许非 apikey 类型账号关联(OpenAI/Antigravity/Anthropic/Gemini)
RequirePrivacySet bool // 调度时仅允许 privacy 已成功设置的账号(OpenAI/Antigravity/Anthropic/Gemini)
DefaultMappedModel string
AllowMessagesDispatch bool
RequireOAuthOnly bool // 仅允许非 apikey 类型账号关联(OpenAI/Antigravity/Anthropic/Gemini)
RequirePrivacySet bool // 调度时仅允许 privacy 已成功设置的账号(OpenAI/Antigravity/Anthropic/Gemini)
DefaultMappedModel string
MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig
CreatedAt time.Time
UpdatedAt time.Time
......
package service
import (
"bytes"
"fmt"
"strings"
"text/template"
)
type forcedCodexInstructionsTemplateData struct {
ExistingInstructions string
OriginalModel string
NormalizedModel string
BillingModel string
UpstreamModel string
}
func applyForcedCodexInstructionsTemplate(
reqBody map[string]any,
templateText string,
data forcedCodexInstructionsTemplateData,
) (bool, error) {
rendered, err := renderForcedCodexInstructionsTemplate(templateText, data)
if err != nil {
return false, err
}
if rendered == "" {
return false, nil
}
existing, _ := reqBody["instructions"].(string)
if strings.TrimSpace(existing) == rendered {
return false, nil
}
reqBody["instructions"] = rendered
return true, nil
}
func renderForcedCodexInstructionsTemplate(
templateText string,
data forcedCodexInstructionsTemplateData,
) (string, error) {
tmpl, err := template.New("forced_codex_instructions").Option("missingkey=zero").Parse(templateText)
if err != nil {
return "", fmt.Errorf("parse forced codex instructions template: %w", err)
}
var buf bytes.Buffer
if err := tmpl.Execute(&buf, data); err != nil {
return "", fmt.Errorf("render forced codex instructions template: %w", err)
}
return strings.TrimSpace(buf.String()), nil
}
......@@ -6,9 +6,12 @@ import (
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
......@@ -127,3 +130,101 @@ func TestForwardAsAnthropic_NormalizesRoutingAndEffortForGpt54XHigh(t *testing.T
t.Logf("upstream body: %s", string(upstream.lastBody))
t.Logf("response body: %s", rec.Body.String())
}
func TestForwardAsAnthropic_ForcedCodexInstructionsTemplatePrependsRenderedInstructions(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
templateDir := t.TempDir()
templatePath := filepath.Join(templateDir, "codex-instructions.md.tmpl")
require.NoError(t, os.WriteFile(templatePath, []byte("server-prefix\n\n{{ .ExistingInstructions }}"), 0o644))
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"system":"client-system","messages":[{"role":"user","content":"hello"}],"stream":false}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstreamBody := strings.Join([]string{
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`,
"",
"data: [DONE]",
"",
}, "\n")
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_forced"}},
Body: io.NopCloser(strings.NewReader(upstreamBody)),
}}
svc := &OpenAIGatewayService{
cfg: &config.Config{Gateway: config.GatewayConfig{
ForcedCodexInstructionsTemplateFile: templatePath,
ForcedCodexInstructionsTemplate: "server-prefix\n\n{{ .ExistingInstructions }}",
}},
httpUpstream: upstream,
}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "server-prefix\n\nclient-system", gjson.GetBytes(upstream.lastBody, "instructions").String())
}
func TestForwardAsAnthropic_ForcedCodexInstructionsTemplateUsesCachedTemplateContent(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
body := []byte(`{"model":"gpt-5.4","max_tokens":16,"system":"client-system","messages":[{"role":"user","content":"hello"}],"stream":false}`)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
upstreamBody := strings.Join([]string{
`data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`,
"",
"data: [DONE]",
"",
}, "\n")
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_forced_cached"}},
Body: io.NopCloser(strings.NewReader(upstreamBody)),
}}
svc := &OpenAIGatewayService{
cfg: &config.Config{Gateway: config.GatewayConfig{
ForcedCodexInstructionsTemplateFile: "/path/that/should/not/be/read.tmpl",
ForcedCodexInstructionsTemplate: "cached-prefix\n\n{{ .ExistingInstructions }}",
}},
httpUpstream: upstream,
}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "oauth-token",
"chatgpt_account_id": "chatgpt-acc",
},
}
result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "cached-prefix\n\nclient-system", gjson.GetBytes(upstream.lastBody, "instructions").String())
}
......@@ -86,6 +86,24 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
return nil, fmt.Errorf("unmarshal for codex transform: %w", err)
}
codexResult := applyCodexOAuthTransform(reqBody, false, false)
forcedTemplateText := ""
if s.cfg != nil {
forcedTemplateText = s.cfg.Gateway.ForcedCodexInstructionsTemplate
}
templateUpstreamModel := upstreamModel
if codexResult.NormalizedModel != "" {
templateUpstreamModel = codexResult.NormalizedModel
}
existingInstructions, _ := reqBody["instructions"].(string)
if _, err := applyForcedCodexInstructionsTemplate(reqBody, forcedTemplateText, forcedCodexInstructionsTemplateData{
ExistingInstructions: strings.TrimSpace(existingInstructions),
OriginalModel: originalModel,
NormalizedModel: normalizedModel,
BillingModel: billingModel,
UpstreamModel: templateUpstreamModel,
}); err != nil {
return nil, err
}
if codexResult.NormalizedModel != "" {
upstreamModel = codexResult.NormalizedModel
}
......
package service
import "strings"
const (
defaultOpenAIMessagesDispatchOpusMappedModel = "gpt-5.4"
defaultOpenAIMessagesDispatchSonnetMappedModel = "gpt-5.3-codex"
defaultOpenAIMessagesDispatchHaikuMappedModel = "gpt-5.4-mini"
)
func normalizeOpenAIMessagesDispatchMappedModel(model string) string {
model = NormalizeOpenAICompatRequestedModel(strings.TrimSpace(model))
return strings.TrimSpace(model)
}
func normalizeOpenAIMessagesDispatchModelConfig(cfg OpenAIMessagesDispatchModelConfig) OpenAIMessagesDispatchModelConfig {
out := OpenAIMessagesDispatchModelConfig{
OpusMappedModel: normalizeOpenAIMessagesDispatchMappedModel(cfg.OpusMappedModel),
SonnetMappedModel: normalizeOpenAIMessagesDispatchMappedModel(cfg.SonnetMappedModel),
HaikuMappedModel: normalizeOpenAIMessagesDispatchMappedModel(cfg.HaikuMappedModel),
}
if len(cfg.ExactModelMappings) > 0 {
out.ExactModelMappings = make(map[string]string, len(cfg.ExactModelMappings))
for requestedModel, mappedModel := range cfg.ExactModelMappings {
requestedModel = strings.TrimSpace(requestedModel)
mappedModel = normalizeOpenAIMessagesDispatchMappedModel(mappedModel)
if requestedModel == "" || mappedModel == "" {
continue
}
out.ExactModelMappings[requestedModel] = mappedModel
}
if len(out.ExactModelMappings) == 0 {
out.ExactModelMappings = nil
}
}
return out
}
func claudeMessagesDispatchFamily(model string) string {
normalized := strings.ToLower(strings.TrimSpace(model))
if !strings.HasPrefix(normalized, "claude") {
return ""
}
switch {
case strings.Contains(normalized, "opus"):
return "opus"
case strings.Contains(normalized, "sonnet"):
return "sonnet"
case strings.Contains(normalized, "haiku"):
return "haiku"
default:
return ""
}
}
func (g *Group) ResolveMessagesDispatchModel(requestedModel string) string {
if g == nil {
return ""
}
requestedModel = strings.TrimSpace(requestedModel)
if requestedModel == "" {
return ""
}
cfg := normalizeOpenAIMessagesDispatchModelConfig(g.MessagesDispatchModelConfig)
if mappedModel := strings.TrimSpace(cfg.ExactModelMappings[requestedModel]); mappedModel != "" {
return mappedModel
}
switch claudeMessagesDispatchFamily(requestedModel) {
case "opus":
if mappedModel := strings.TrimSpace(cfg.OpusMappedModel); mappedModel != "" {
return mappedModel
}
return defaultOpenAIMessagesDispatchOpusMappedModel
case "sonnet":
if mappedModel := strings.TrimSpace(cfg.SonnetMappedModel); mappedModel != "" {
return mappedModel
}
return defaultOpenAIMessagesDispatchSonnetMappedModel
case "haiku":
if mappedModel := strings.TrimSpace(cfg.HaikuMappedModel); mappedModel != "" {
return mappedModel
}
return defaultOpenAIMessagesDispatchHaikuMappedModel
default:
return ""
}
}
func sanitizeGroupMessagesDispatchFields(g *Group) {
if g == nil || g.Platform == PlatformOpenAI {
return
}
g.AllowMessagesDispatch = false
g.DefaultMappedModel = ""
g.MessagesDispatchModelConfig = OpenAIMessagesDispatchModelConfig{}
}
package service
import "testing"
import "github.com/stretchr/testify/require"
func TestNormalizeOpenAIMessagesDispatchModelConfig(t *testing.T) {
t.Parallel()
cfg := normalizeOpenAIMessagesDispatchModelConfig(OpenAIMessagesDispatchModelConfig{
OpusMappedModel: " gpt-5.4-high ",
SonnetMappedModel: "gpt-5.3-codex",
HaikuMappedModel: " gpt-5.4-mini-medium ",
ExactModelMappings: map[string]string{
" claude-sonnet-4-5-20250929 ": " gpt-5.2-high ",
"": "gpt-5.4",
"claude-opus-4-6": " ",
},
})
require.Equal(t, "gpt-5.4", cfg.OpusMappedModel)
require.Equal(t, "gpt-5.3-codex", cfg.SonnetMappedModel)
require.Equal(t, "gpt-5.4-mini", cfg.HaikuMappedModel)
require.Equal(t, map[string]string{
"claude-sonnet-4-5-20250929": "gpt-5.2",
}, cfg.ExactModelMappings)
}
......@@ -16,7 +16,7 @@ import (
var ErrOpsDisabled = infraerrors.NotFound("OPS_DISABLED", "Ops monitoring is disabled")
const (
opsMaxStoredRequestBodyBytes = 10 * 1024
opsMaxStoredRequestBodyBytes = 256 * 1024
opsMaxStoredErrorBodyBytes = 20 * 1024
)
......
......@@ -17,6 +17,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/imroc/req/v3"
"golang.org/x/sync/singleflight"
)
......@@ -167,6 +168,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyCustomEndpoints,
SettingKeyLinuxDoConnectEnabled,
SettingKeyBackendModeEnabled,
SettingKeyOIDCConnectEnabled,
SettingKeyOIDCConnectProviderName,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
......@@ -180,6 +183,19 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
} else {
linuxDoEnabled = s.cfg != nil && s.cfg.LinuxDo.Enabled
}
oidcEnabled := false
if raw, ok := settings[SettingKeyOIDCConnectEnabled]; ok {
oidcEnabled = raw == "true"
} else {
oidcEnabled = s.cfg != nil && s.cfg.OIDC.Enabled
}
oidcProviderName := strings.TrimSpace(settings[SettingKeyOIDCConnectProviderName])
if oidcProviderName == "" && s.cfg != nil {
oidcProviderName = strings.TrimSpace(s.cfg.OIDC.ProviderName)
}
if oidcProviderName == "" {
oidcProviderName = "OIDC"
}
// Password reset requires email verification to be enabled
emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
......@@ -218,6 +234,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
CustomEndpoints: settings[SettingKeyCustomEndpoints],
LinuxDoOAuthEnabled: linuxDoEnabled,
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
OIDCOAuthEnabled: oidcEnabled,
OIDCOAuthProviderName: oidcProviderName,
}, nil
}
......@@ -267,6 +285,8 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
Version string `json:"version,omitempty"`
}{
RegistrationEnabled: settings.RegistrationEnabled,
......@@ -294,6 +314,8 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
BackendModeEnabled: settings.BackendModeEnabled,
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
Version: s.version,
}, nil
}
......@@ -346,8 +368,8 @@ func safeRawJSONArray(raw string) json.RawMessage {
return json.RawMessage("[]")
}
// GetFrameSrcOrigins returns deduplicated http(s) origins from purchase_subscription_url
// and all custom_menu_items URLs. Used by the router layer for CSP frame-src injection.
// GetFrameSrcOrigins returns deduplicated http(s) origins from home_content URL,
// purchase_subscription_url, and all custom_menu_items URLs. Used by the router layer for CSP frame-src injection.
func (s *SettingService) GetFrameSrcOrigins(ctx context.Context) ([]string, error) {
settings, err := s.GetPublicSettings(ctx)
if err != nil {
......@@ -366,6 +388,9 @@ func (s *SettingService) GetFrameSrcOrigins(ctx context.Context) ([]string, erro
}
}
// home content URL (when home_content is set to a URL for iframe embedding)
addOrigin(settings.HomeContent)
// purchase subscription URL
if settings.PurchaseSubscriptionEnabled {
addOrigin(settings.PurchaseSubscriptionURL)
......@@ -473,6 +498,32 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyLinuxDoConnectClientSecret] = settings.LinuxDoConnectClientSecret
}
// Generic OIDC OAuth 登录
updates[SettingKeyOIDCConnectEnabled] = strconv.FormatBool(settings.OIDCConnectEnabled)
updates[SettingKeyOIDCConnectProviderName] = settings.OIDCConnectProviderName
updates[SettingKeyOIDCConnectClientID] = settings.OIDCConnectClientID
updates[SettingKeyOIDCConnectIssuerURL] = settings.OIDCConnectIssuerURL
updates[SettingKeyOIDCConnectDiscoveryURL] = settings.OIDCConnectDiscoveryURL
updates[SettingKeyOIDCConnectAuthorizeURL] = settings.OIDCConnectAuthorizeURL
updates[SettingKeyOIDCConnectTokenURL] = settings.OIDCConnectTokenURL
updates[SettingKeyOIDCConnectUserInfoURL] = settings.OIDCConnectUserInfoURL
updates[SettingKeyOIDCConnectJWKSURL] = settings.OIDCConnectJWKSURL
updates[SettingKeyOIDCConnectScopes] = settings.OIDCConnectScopes
updates[SettingKeyOIDCConnectRedirectURL] = settings.OIDCConnectRedirectURL
updates[SettingKeyOIDCConnectFrontendRedirectURL] = settings.OIDCConnectFrontendRedirectURL
updates[SettingKeyOIDCConnectTokenAuthMethod] = settings.OIDCConnectTokenAuthMethod
updates[SettingKeyOIDCConnectUsePKCE] = strconv.FormatBool(settings.OIDCConnectUsePKCE)
updates[SettingKeyOIDCConnectValidateIDToken] = strconv.FormatBool(settings.OIDCConnectValidateIDToken)
updates[SettingKeyOIDCConnectAllowedSigningAlgs] = settings.OIDCConnectAllowedSigningAlgs
updates[SettingKeyOIDCConnectClockSkewSeconds] = strconv.Itoa(settings.OIDCConnectClockSkewSeconds)
updates[SettingKeyOIDCConnectRequireEmailVerified] = strconv.FormatBool(settings.OIDCConnectRequireEmailVerified)
updates[SettingKeyOIDCConnectUserInfoEmailPath] = settings.OIDCConnectUserInfoEmailPath
updates[SettingKeyOIDCConnectUserInfoIDPath] = settings.OIDCConnectUserInfoIDPath
updates[SettingKeyOIDCConnectUserInfoUsernamePath] = settings.OIDCConnectUserInfoUsernamePath
if settings.OIDCConnectClientSecret != "" {
updates[SettingKeyOIDCConnectClientSecret] = settings.OIDCConnectClientSecret
}
// OEM设置
updates[SettingKeySiteName] = settings.SiteName
updates[SettingKeySiteLogo] = settings.SiteLogo
......@@ -851,6 +902,8 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyTablePageSizeOptions: "[10,20,50,100]",
SettingKeyCustomMenuItems: "[]",
SettingKeyCustomEndpoints: "[]",
SettingKeyOIDCConnectEnabled: "false",
SettingKeyOIDCConnectProviderName: "OIDC",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeyDefaultSubscriptions: "[]",
......@@ -980,6 +1033,138 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
}
result.LinuxDoConnectClientSecretConfigured = result.LinuxDoConnectClientSecret != ""
// Generic OIDC 设置:
// - 兼容 config.yaml/env
// - 支持后台系统设置覆盖并持久化(存储于 DB)
oidcBase := config.OIDCConnectConfig{}
if s.cfg != nil {
oidcBase = s.cfg.OIDC
}
if raw, ok := settings[SettingKeyOIDCConnectEnabled]; ok {
result.OIDCConnectEnabled = raw == "true"
} else {
result.OIDCConnectEnabled = oidcBase.Enabled
}
if v, ok := settings[SettingKeyOIDCConnectProviderName]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectProviderName = strings.TrimSpace(v)
} else {
result.OIDCConnectProviderName = strings.TrimSpace(oidcBase.ProviderName)
}
if result.OIDCConnectProviderName == "" {
result.OIDCConnectProviderName = "OIDC"
}
if v, ok := settings[SettingKeyOIDCConnectClientID]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectClientID = strings.TrimSpace(v)
} else {
result.OIDCConnectClientID = strings.TrimSpace(oidcBase.ClientID)
}
if v, ok := settings[SettingKeyOIDCConnectIssuerURL]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectIssuerURL = strings.TrimSpace(v)
} else {
result.OIDCConnectIssuerURL = strings.TrimSpace(oidcBase.IssuerURL)
}
if v, ok := settings[SettingKeyOIDCConnectDiscoveryURL]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectDiscoveryURL = strings.TrimSpace(v)
} else {
result.OIDCConnectDiscoveryURL = strings.TrimSpace(oidcBase.DiscoveryURL)
}
if v, ok := settings[SettingKeyOIDCConnectAuthorizeURL]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectAuthorizeURL = strings.TrimSpace(v)
} else {
result.OIDCConnectAuthorizeURL = strings.TrimSpace(oidcBase.AuthorizeURL)
}
if v, ok := settings[SettingKeyOIDCConnectTokenURL]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectTokenURL = strings.TrimSpace(v)
} else {
result.OIDCConnectTokenURL = strings.TrimSpace(oidcBase.TokenURL)
}
if v, ok := settings[SettingKeyOIDCConnectUserInfoURL]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectUserInfoURL = strings.TrimSpace(v)
} else {
result.OIDCConnectUserInfoURL = strings.TrimSpace(oidcBase.UserInfoURL)
}
if v, ok := settings[SettingKeyOIDCConnectJWKSURL]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectJWKSURL = strings.TrimSpace(v)
} else {
result.OIDCConnectJWKSURL = strings.TrimSpace(oidcBase.JWKSURL)
}
if v, ok := settings[SettingKeyOIDCConnectScopes]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectScopes = strings.TrimSpace(v)
} else {
result.OIDCConnectScopes = strings.TrimSpace(oidcBase.Scopes)
}
if v, ok := settings[SettingKeyOIDCConnectRedirectURL]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectRedirectURL = strings.TrimSpace(v)
} else {
result.OIDCConnectRedirectURL = strings.TrimSpace(oidcBase.RedirectURL)
}
if v, ok := settings[SettingKeyOIDCConnectFrontendRedirectURL]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectFrontendRedirectURL = strings.TrimSpace(v)
} else {
result.OIDCConnectFrontendRedirectURL = strings.TrimSpace(oidcBase.FrontendRedirectURL)
}
if v, ok := settings[SettingKeyOIDCConnectTokenAuthMethod]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectTokenAuthMethod = strings.ToLower(strings.TrimSpace(v))
} else {
result.OIDCConnectTokenAuthMethod = strings.ToLower(strings.TrimSpace(oidcBase.TokenAuthMethod))
}
if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok {
result.OIDCConnectUsePKCE = raw == "true"
} else {
result.OIDCConnectUsePKCE = oidcBase.UsePKCE
}
if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok {
result.OIDCConnectValidateIDToken = raw == "true"
} else {
result.OIDCConnectValidateIDToken = oidcBase.ValidateIDToken
}
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(v)
} else {
result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(oidcBase.AllowedSigningAlgs)
}
clockSkewSet := false
if raw, ok := settings[SettingKeyOIDCConnectClockSkewSeconds]; ok && strings.TrimSpace(raw) != "" {
if parsed, err := strconv.Atoi(strings.TrimSpace(raw)); err == nil {
result.OIDCConnectClockSkewSeconds = parsed
clockSkewSet = true
}
}
if !clockSkewSet {
result.OIDCConnectClockSkewSeconds = oidcBase.ClockSkewSeconds
}
if !clockSkewSet && result.OIDCConnectClockSkewSeconds == 0 {
result.OIDCConnectClockSkewSeconds = 120
}
if raw, ok := settings[SettingKeyOIDCConnectRequireEmailVerified]; ok {
result.OIDCConnectRequireEmailVerified = raw == "true"
} else {
result.OIDCConnectRequireEmailVerified = oidcBase.RequireEmailVerified
}
if v, ok := settings[SettingKeyOIDCConnectUserInfoEmailPath]; ok {
result.OIDCConnectUserInfoEmailPath = strings.TrimSpace(v)
} else {
result.OIDCConnectUserInfoEmailPath = strings.TrimSpace(oidcBase.UserInfoEmailPath)
}
if v, ok := settings[SettingKeyOIDCConnectUserInfoIDPath]; ok {
result.OIDCConnectUserInfoIDPath = strings.TrimSpace(v)
} else {
result.OIDCConnectUserInfoIDPath = strings.TrimSpace(oidcBase.UserInfoIDPath)
}
if v, ok := settings[SettingKeyOIDCConnectUserInfoUsernamePath]; ok {
result.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(v)
} else {
result.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(oidcBase.UserInfoUsernamePath)
}
result.OIDCConnectClientSecret = strings.TrimSpace(settings[SettingKeyOIDCConnectClientSecret])
if result.OIDCConnectClientSecret == "" {
result.OIDCConnectClientSecret = strings.TrimSpace(oidcBase.ClientSecret)
}
result.OIDCConnectClientSecretConfigured = result.OIDCConnectClientSecret != ""
// Model fallback settings
result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true"
result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022")
......@@ -1396,6 +1581,282 @@ func (s *SettingService) SetOverloadCooldownSettings(ctx context.Context, settin
return s.settingRepo.Set(ctx, SettingKeyOverloadCooldownSettings, string(data))
}
// GetOIDCConnectOAuthConfig 返回用于登录的“最终生效” OIDC 配置。
//
// 优先级:
// - 若对应系统设置键存在,则覆盖 config.yaml/env 的值
// - 否则回退到 config.yaml/env 的值
func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config.OIDCConnectConfig, error) {
if s == nil || s.cfg == nil {
return config.OIDCConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded")
}
effective := s.cfg.OIDC
keys := []string{
SettingKeyOIDCConnectEnabled,
SettingKeyOIDCConnectProviderName,
SettingKeyOIDCConnectClientID,
SettingKeyOIDCConnectClientSecret,
SettingKeyOIDCConnectIssuerURL,
SettingKeyOIDCConnectDiscoveryURL,
SettingKeyOIDCConnectAuthorizeURL,
SettingKeyOIDCConnectTokenURL,
SettingKeyOIDCConnectUserInfoURL,
SettingKeyOIDCConnectJWKSURL,
SettingKeyOIDCConnectScopes,
SettingKeyOIDCConnectRedirectURL,
SettingKeyOIDCConnectFrontendRedirectURL,
SettingKeyOIDCConnectTokenAuthMethod,
SettingKeyOIDCConnectUsePKCE,
SettingKeyOIDCConnectValidateIDToken,
SettingKeyOIDCConnectAllowedSigningAlgs,
SettingKeyOIDCConnectClockSkewSeconds,
SettingKeyOIDCConnectRequireEmailVerified,
SettingKeyOIDCConnectUserInfoEmailPath,
SettingKeyOIDCConnectUserInfoIDPath,
SettingKeyOIDCConnectUserInfoUsernamePath,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
return config.OIDCConnectConfig{}, fmt.Errorf("get oidc connect settings: %w", err)
}
if raw, ok := settings[SettingKeyOIDCConnectEnabled]; ok {
effective.Enabled = raw == "true"
}
if v, ok := settings[SettingKeyOIDCConnectProviderName]; ok && strings.TrimSpace(v) != "" {
effective.ProviderName = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectClientID]; ok && strings.TrimSpace(v) != "" {
effective.ClientID = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectClientSecret]; ok && strings.TrimSpace(v) != "" {
effective.ClientSecret = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectIssuerURL]; ok && strings.TrimSpace(v) != "" {
effective.IssuerURL = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectDiscoveryURL]; ok && strings.TrimSpace(v) != "" {
effective.DiscoveryURL = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectAuthorizeURL]; ok && strings.TrimSpace(v) != "" {
effective.AuthorizeURL = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectTokenURL]; ok && strings.TrimSpace(v) != "" {
effective.TokenURL = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectUserInfoURL]; ok && strings.TrimSpace(v) != "" {
effective.UserInfoURL = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectJWKSURL]; ok && strings.TrimSpace(v) != "" {
effective.JWKSURL = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectScopes]; ok && strings.TrimSpace(v) != "" {
effective.Scopes = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectRedirectURL]; ok && strings.TrimSpace(v) != "" {
effective.RedirectURL = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectFrontendRedirectURL]; ok && strings.TrimSpace(v) != "" {
effective.FrontendRedirectURL = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectTokenAuthMethod]; ok && strings.TrimSpace(v) != "" {
effective.TokenAuthMethod = strings.ToLower(strings.TrimSpace(v))
}
if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok {
effective.UsePKCE = raw == "true"
}
if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok {
effective.ValidateIDToken = raw == "true"
}
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
effective.AllowedSigningAlgs = strings.TrimSpace(v)
}
if raw, ok := settings[SettingKeyOIDCConnectClockSkewSeconds]; ok && strings.TrimSpace(raw) != "" {
if parsed, parseErr := strconv.Atoi(strings.TrimSpace(raw)); parseErr == nil {
effective.ClockSkewSeconds = parsed
}
}
if raw, ok := settings[SettingKeyOIDCConnectRequireEmailVerified]; ok {
effective.RequireEmailVerified = raw == "true"
}
if v, ok := settings[SettingKeyOIDCConnectUserInfoEmailPath]; ok {
effective.UserInfoEmailPath = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectUserInfoIDPath]; ok {
effective.UserInfoIDPath = strings.TrimSpace(v)
}
if v, ok := settings[SettingKeyOIDCConnectUserInfoUsernamePath]; ok {
effective.UserInfoUsernamePath = strings.TrimSpace(v)
}
if !effective.Enabled {
return config.OIDCConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
}
if strings.TrimSpace(effective.ProviderName) == "" {
effective.ProviderName = "OIDC"
}
if strings.TrimSpace(effective.ClientID) == "" {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client id not configured")
}
if strings.TrimSpace(effective.IssuerURL) == "" {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth issuer url not configured")
}
if strings.TrimSpace(effective.RedirectURL) == "" {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured")
}
if strings.TrimSpace(effective.FrontendRedirectURL) == "" {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth frontend redirect url not configured")
}
if !scopesContainOpenID(effective.Scopes) {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth scopes must contain openid")
}
if effective.ClockSkewSeconds < 0 || effective.ClockSkewSeconds > 600 {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth clock skew must be between 0 and 600")
}
if err := config.ValidateAbsoluteHTTPURL(effective.IssuerURL); err != nil {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth issuer url invalid")
}
discoveryURL := strings.TrimSpace(effective.DiscoveryURL)
if discoveryURL == "" {
discoveryURL = oidcDefaultDiscoveryURL(effective.IssuerURL)
effective.DiscoveryURL = discoveryURL
}
if discoveryURL != "" {
if err := config.ValidateAbsoluteHTTPURL(discoveryURL); err != nil {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth discovery url invalid")
}
}
needsDiscovery := strings.TrimSpace(effective.AuthorizeURL) == "" ||
strings.TrimSpace(effective.TokenURL) == "" ||
(effective.ValidateIDToken && strings.TrimSpace(effective.JWKSURL) == "")
if needsDiscovery && discoveryURL != "" {
metadata, resolveErr := oidcResolveProviderMetadata(ctx, discoveryURL)
if resolveErr != nil {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth discovery resolve failed").WithCause(resolveErr)
}
if strings.TrimSpace(effective.AuthorizeURL) == "" {
effective.AuthorizeURL = strings.TrimSpace(metadata.AuthorizationEndpoint)
}
if strings.TrimSpace(effective.TokenURL) == "" {
effective.TokenURL = strings.TrimSpace(metadata.TokenEndpoint)
}
if strings.TrimSpace(effective.UserInfoURL) == "" {
effective.UserInfoURL = strings.TrimSpace(metadata.UserInfoEndpoint)
}
if strings.TrimSpace(effective.JWKSURL) == "" {
effective.JWKSURL = strings.TrimSpace(metadata.JWKSURI)
}
}
if strings.TrimSpace(effective.AuthorizeURL) == "" {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth authorize url not configured")
}
if strings.TrimSpace(effective.TokenURL) == "" {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token url not configured")
}
if err := config.ValidateAbsoluteHTTPURL(effective.AuthorizeURL); err != nil {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth authorize url invalid")
}
if err := config.ValidateAbsoluteHTTPURL(effective.TokenURL); err != nil {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token url invalid")
}
if v := strings.TrimSpace(effective.UserInfoURL); v != "" {
if err := config.ValidateAbsoluteHTTPURL(v); err != nil {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth userinfo url invalid")
}
}
if effective.ValidateIDToken {
if strings.TrimSpace(effective.JWKSURL) == "" {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth jwks url not configured")
}
if strings.TrimSpace(effective.AllowedSigningAlgs) == "" {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth signing algs not configured")
}
}
if v := strings.TrimSpace(effective.JWKSURL); v != "" {
if err := config.ValidateAbsoluteHTTPURL(v); err != nil {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth jwks url invalid")
}
}
if err := config.ValidateAbsoluteHTTPURL(effective.RedirectURL); err != nil {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url invalid")
}
if err := config.ValidateFrontendRedirectURL(effective.FrontendRedirectURL); err != nil {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth frontend redirect url invalid")
}
method := strings.ToLower(strings.TrimSpace(effective.TokenAuthMethod))
switch method {
case "", "client_secret_post", "client_secret_basic":
if strings.TrimSpace(effective.ClientSecret) == "" {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured")
}
case "none":
if !effective.UsePKCE {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none")
}
default:
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid")
}
return effective, nil
}
func scopesContainOpenID(scopes string) bool {
for _, scope := range strings.Fields(strings.ToLower(strings.TrimSpace(scopes))) {
if scope == "openid" {
return true
}
}
return false
}
type oidcProviderMetadata struct {
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
UserInfoEndpoint string `json:"userinfo_endpoint"`
JWKSURI string `json:"jwks_uri"`
}
func oidcDefaultDiscoveryURL(issuerURL string) string {
issuerURL = strings.TrimSpace(issuerURL)
if issuerURL == "" {
return ""
}
return strings.TrimRight(issuerURL, "/") + "/.well-known/openid-configuration"
}
func oidcResolveProviderMetadata(ctx context.Context, discoveryURL string) (*oidcProviderMetadata, error) {
discoveryURL = strings.TrimSpace(discoveryURL)
if discoveryURL == "" {
return nil, fmt.Errorf("discovery url is empty")
}
resp, err := req.C().
SetTimeout(15*time.Second).
R().
SetContext(ctx).
SetHeader("Accept", "application/json").
Get(discoveryURL)
if err != nil {
return nil, fmt.Errorf("request discovery document: %w", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("discovery request failed: status=%d", resp.StatusCode)
}
metadata := &oidcProviderMetadata{}
if err := json.Unmarshal(resp.Bytes(), metadata); err != nil {
return nil, fmt.Errorf("parse discovery document: %w", err)
}
return metadata, nil
}
// GetStreamTimeoutSettings 获取流超时处理配置
func (s *SettingService) GetStreamTimeoutSettings(ctx context.Context) (*StreamTimeoutSettings, error) {
value, err := s.settingRepo.GetValue(ctx, SettingKeyStreamTimeoutSettings)
......
//go:build unit
package service
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type settingOIDCRepoStub struct {
values map[string]string
}
func (s *settingOIDCRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
panic("unexpected Get call")
}
func (s *settingOIDCRepoStub) GetValue(ctx context.Context, key string) (string, error) {
panic("unexpected GetValue call")
}
func (s *settingOIDCRepoStub) Set(ctx context.Context, key, value string) error {
panic("unexpected Set call")
}
func (s *settingOIDCRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, key := range keys {
if value, ok := s.values[key]; ok {
out[key] = value
}
}
return out, nil
}
func (s *settingOIDCRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
panic("unexpected SetMultiple call")
}
func (s *settingOIDCRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *settingOIDCRepoStub) Delete(ctx context.Context, key string) error {
panic("unexpected Delete call")
}
func TestGetOIDCConnectOAuthConfig_ResolvesEndpointsFromIssuerDiscovery(t *testing.T) {
var discoveryHits int
var baseURL string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/issuer/.well-known/openid-configuration" {
http.NotFound(w, r)
return
}
discoveryHits++
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(fmt.Sprintf(`{
"authorization_endpoint":"%s/issuer/protocol/openid-connect/auth",
"token_endpoint":"%s/issuer/protocol/openid-connect/token",
"userinfo_endpoint":"%s/issuer/protocol/openid-connect/userinfo",
"jwks_uri":"%s/issuer/protocol/openid-connect/certs"
}`, baseURL, baseURL, baseURL, baseURL)))
}))
defer srv.Close()
baseURL = srv.URL
cfg := &config.Config{
OIDC: config.OIDCConnectConfig{
Enabled: true,
ProviderName: "OIDC",
ClientID: "oidc-client",
ClientSecret: "oidc-secret",
IssuerURL: srv.URL + "/issuer",
RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
FrontendRedirectURL: "/auth/oidc/callback",
Scopes: "openid email profile",
TokenAuthMethod: "client_secret_post",
ValidateIDToken: true,
AllowedSigningAlgs: "RS256",
ClockSkewSeconds: 120,
},
}
repo := &settingOIDCRepoStub{values: map[string]string{}}
svc := NewSettingService(repo, cfg)
got, err := svc.GetOIDCConnectOAuthConfig(context.Background())
require.NoError(t, err)
require.Equal(t, 1, discoveryHits)
require.Equal(t, srv.URL+"/issuer/.well-known/openid-configuration", got.DiscoveryURL)
require.Equal(t, srv.URL+"/issuer/protocol/openid-connect/auth", got.AuthorizeURL)
require.Equal(t, srv.URL+"/issuer/protocol/openid-connect/token", got.TokenURL)
require.Equal(t, srv.URL+"/issuer/protocol/openid-connect/userinfo", got.UserInfoURL)
require.Equal(t, srv.URL+"/issuer/protocol/openid-connect/certs", got.JWKSURL)
}
......@@ -31,6 +31,31 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured bool
LinuxDoConnectRedirectURL string
// Generic OIDC OAuth 登录
OIDCConnectEnabled bool
OIDCConnectProviderName string
OIDCConnectClientID string
OIDCConnectClientSecret string
OIDCConnectClientSecretConfigured bool
OIDCConnectIssuerURL string
OIDCConnectDiscoveryURL string
OIDCConnectAuthorizeURL string
OIDCConnectTokenURL string
OIDCConnectUserInfoURL string
OIDCConnectJWKSURL string
OIDCConnectScopes string
OIDCConnectRedirectURL string
OIDCConnectFrontendRedirectURL string
OIDCConnectTokenAuthMethod string
OIDCConnectUsePKCE bool
OIDCConnectValidateIDToken bool
OIDCConnectAllowedSigningAlgs string
OIDCConnectClockSkewSeconds int
OIDCConnectRequireEmailVerified bool
OIDCConnectUserInfoEmailPath string
OIDCConnectUserInfoIDPath string
OIDCConnectUserInfoUsernamePath string
SiteName string
SiteLogo string
SiteSubtitle string
......@@ -114,9 +139,11 @@ type PublicSettings struct {
CustomMenuItems string // JSON array of custom menu items
CustomEndpoints string // JSON array of custom endpoints
LinuxDoOAuthEnabled bool
BackendModeEnabled bool
Version string
LinuxDoOAuthEnabled bool
BackendModeEnabled bool
OIDCOAuthEnabled bool
OIDCOAuthProviderName string
Version string
}
// StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment