Unverified Commit a1dc0089 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #944 from miraserver/feat/backend-mode

feat: add Backend Mode toggle to disable user self-service
parents dfbcc363 6826149a
...@@ -125,6 +125,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { ...@@ -125,6 +125,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds, OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
MinClaudeCodeVersion: settings.MinClaudeCodeVersion, MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling, AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
BackendModeEnabled: settings.BackendModeEnabled,
}) })
} }
...@@ -199,6 +200,9 @@ type UpdateSettingsRequest struct { ...@@ -199,6 +200,9 @@ type UpdateSettingsRequest struct {
// 分组隔离 // 分组隔离
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"` AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
// Backend Mode
BackendModeEnabled bool `json:"backend_mode_enabled"`
} }
// UpdateSettings 更新系统设置 // UpdateSettings 更新系统设置
...@@ -473,6 +477,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -473,6 +477,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
IdentityPatchPrompt: req.IdentityPatchPrompt, IdentityPatchPrompt: req.IdentityPatchPrompt,
MinClaudeCodeVersion: req.MinClaudeCodeVersion, MinClaudeCodeVersion: req.MinClaudeCodeVersion,
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling, AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
BackendModeEnabled: req.BackendModeEnabled,
OpsMonitoringEnabled: func() bool { OpsMonitoringEnabled: func() bool {
if req.OpsMonitoringEnabled != nil { if req.OpsMonitoringEnabled != nil {
return *req.OpsMonitoringEnabled return *req.OpsMonitoringEnabled
...@@ -571,6 +576,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -571,6 +576,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds, OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion, MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling, AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
BackendModeEnabled: updatedSettings.BackendModeEnabled,
}) })
} }
...@@ -725,6 +731,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, ...@@ -725,6 +731,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling { if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling {
changed = append(changed, "allow_ungrouped_key_scheduling") changed = append(changed, "allow_ungrouped_key_scheduling")
} }
if before.BackendModeEnabled != after.BackendModeEnabled {
changed = append(changed, "backend_mode_enabled")
}
if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled { if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled {
changed = append(changed, "purchase_subscription_enabled") changed = append(changed, "purchase_subscription_enabled")
} }
......
...@@ -194,6 +194,12 @@ func (h *AuthHandler) Login(c *gin.Context) { ...@@ -194,6 +194,12 @@ func (h *AuthHandler) Login(c *gin.Context) {
return return
} }
// Backend mode: only admin can login
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
return
}
h.respondWithTokenPair(c, user) h.respondWithTokenPair(c, user)
} }
...@@ -250,16 +256,22 @@ func (h *AuthHandler) Login2FA(c *gin.Context) { ...@@ -250,16 +256,22 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
return return
} }
// Delete the login session // Get the user (before session deletion so we can check backend mode)
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
// Get the user
user, err := h.userService.GetByID(c.Request.Context(), session.UserID) user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
// Backend mode: only admin can login (check BEFORE deleting session)
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
return
}
// Delete the login session (only after all checks pass)
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
h.respondWithTokenPair(c, user) h.respondWithTokenPair(c, user)
} }
...@@ -522,16 +534,22 @@ func (h *AuthHandler) RefreshToken(c *gin.Context) { ...@@ -522,16 +534,22 @@ func (h *AuthHandler) RefreshToken(c *gin.Context) {
return return
} }
tokenPair, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken) result, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
// Backend mode: block non-admin token refresh
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && result.UserRole != "admin" {
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
return
}
response.Success(c, RefreshTokenResponse{ response.Success(c, RefreshTokenResponse{
AccessToken: tokenPair.AccessToken, AccessToken: result.AccessToken,
RefreshToken: tokenPair.RefreshToken, RefreshToken: result.RefreshToken,
ExpiresIn: tokenPair.ExpiresIn, ExpiresIn: result.ExpiresIn,
TokenType: "Bearer", TokenType: "Bearer",
}) })
} }
......
...@@ -81,6 +81,9 @@ type SystemSettings struct { ...@@ -81,6 +81,9 @@ type SystemSettings struct {
// 分组隔离 // 分组隔离
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"` AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
// Backend Mode
BackendModeEnabled bool `json:"backend_mode_enabled"`
} }
type DefaultSubscriptionSetting struct { type DefaultSubscriptionSetting struct {
...@@ -111,6 +114,7 @@ type PublicSettings struct { ...@@ -111,6 +114,7 @@ type PublicSettings struct {
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
SoraClientEnabled bool `json:"sora_client_enabled"` SoraClientEnabled bool `json:"sora_client_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"`
Version string `json:"version"` Version string `json:"version"`
} }
......
...@@ -54,6 +54,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { ...@@ -54,6 +54,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems), CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
SoraClientEnabled: settings.SoraClientEnabled, SoraClientEnabled: settings.SoraClientEnabled,
BackendModeEnabled: settings.BackendModeEnabled,
Version: h.version, Version: h.version,
}) })
} }
...@@ -537,6 +537,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -537,6 +537,7 @@ func TestAPIContracts(t *testing.T) {
"purchase_subscription_url": "", "purchase_subscription_url": "",
"min_claude_code_version": "", "min_claude_code_version": "",
"allow_ungrouped_key_scheduling": false, "allow_ungrouped_key_scheduling": false,
"backend_mode_enabled": false,
"custom_menu_items": [] "custom_menu_items": []
} }
}`, }`,
......
package middleware
import (
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// BackendModeUserGuard blocks non-admin users from accessing user routes when backend mode is enabled.
// Must be placed AFTER JWT auth middleware so that the user role is available in context.
func BackendModeUserGuard(settingService *service.SettingService) gin.HandlerFunc {
return func(c *gin.Context) {
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
c.Next()
return
}
role, _ := GetUserRoleFromContext(c)
if role == "admin" {
c.Next()
return
}
response.Forbidden(c, "Backend mode is active. User self-service is disabled.")
c.Abort()
}
}
// BackendModeAuthGuard selectively blocks auth endpoints when backend mode is enabled.
// Allows: login, login/2fa, logout, refresh (admin needs these).
// Blocks: register, forgot-password, reset-password, OAuth, etc.
func BackendModeAuthGuard(settingService *service.SettingService) gin.HandlerFunc {
return func(c *gin.Context) {
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
c.Next()
return
}
path := c.Request.URL.Path
// Allow login, 2FA, logout, refresh, public settings
allowedSuffixes := []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"}
for _, suffix := range allowedSuffixes {
if strings.HasSuffix(path, suffix) {
c.Next()
return
}
}
response.Forbidden(c, "Backend mode is active. Registration and self-service auth flows are disabled.")
c.Abort()
}
}
//go:build unit
package middleware
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type bmSettingRepo struct {
values map[string]string
}
func (r *bmSettingRepo) Get(_ context.Context, _ string) (*service.Setting, error) {
panic("unexpected Get call")
}
func (r *bmSettingRepo) GetValue(_ context.Context, key string) (string, error) {
v, ok := r.values[key]
if !ok {
return "", service.ErrSettingNotFound
}
return v, nil
}
func (r *bmSettingRepo) Set(_ context.Context, _, _ string) error {
panic("unexpected Set call")
}
func (r *bmSettingRepo) GetMultiple(_ context.Context, _ []string) (map[string]string, error) {
panic("unexpected GetMultiple call")
}
func (r *bmSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
if r.values == nil {
r.values = make(map[string]string, len(settings))
}
for key, value := range settings {
r.values[key] = value
}
return nil
}
func (r *bmSettingRepo) GetAll(_ context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (r *bmSettingRepo) Delete(_ context.Context, _ string) error {
panic("unexpected Delete call")
}
func newBackendModeSettingService(t *testing.T, enabled string) *service.SettingService {
t.Helper()
repo := &bmSettingRepo{
values: map[string]string{
service.SettingKeyBackendModeEnabled: enabled,
},
}
svc := service.NewSettingService(repo, &config.Config{})
require.NoError(t, svc.UpdateSettings(context.Background(), &service.SystemSettings{
BackendModeEnabled: enabled == "true",
}))
return svc
}
func stringPtr(v string) *string {
return &v
}
func TestBackendModeUserGuard(t *testing.T) {
tests := []struct {
name string
nilService bool
enabled string
role *string
wantStatus int
}{
{
name: "disabled_allows_all",
enabled: "false",
role: stringPtr("user"),
wantStatus: http.StatusOK,
},
{
name: "nil_service_allows_all",
nilService: true,
role: stringPtr("user"),
wantStatus: http.StatusOK,
},
{
name: "enabled_admin_allowed",
enabled: "true",
role: stringPtr("admin"),
wantStatus: http.StatusOK,
},
{
name: "enabled_user_blocked",
enabled: "true",
role: stringPtr("user"),
wantStatus: http.StatusForbidden,
},
{
name: "enabled_no_role_blocked",
enabled: "true",
wantStatus: http.StatusForbidden,
},
{
name: "enabled_empty_role_blocked",
enabled: "true",
role: stringPtr(""),
wantStatus: http.StatusForbidden,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
if tc.role != nil {
role := *tc.role
r.Use(func(c *gin.Context) {
c.Set(string(ContextKeyUserRole), role)
c.Next()
})
}
var svc *service.SettingService
if !tc.nilService {
svc = newBackendModeSettingService(t, tc.enabled)
}
r.Use(BackendModeUserGuard(svc))
r.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/test", nil)
r.ServeHTTP(w, req)
require.Equal(t, tc.wantStatus, w.Code)
})
}
}
func TestBackendModeAuthGuard(t *testing.T) {
tests := []struct {
name string
nilService bool
enabled string
path string
wantStatus int
}{
{
name: "disabled_allows_all",
enabled: "false",
path: "/api/v1/auth/register",
wantStatus: http.StatusOK,
},
{
name: "nil_service_allows_all",
nilService: true,
path: "/api/v1/auth/register",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_login",
enabled: "true",
path: "/api/v1/auth/login",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_login_2fa",
enabled: "true",
path: "/api/v1/auth/login/2fa",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_logout",
enabled: "true",
path: "/api/v1/auth/logout",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_refresh",
enabled: "true",
path: "/api/v1/auth/refresh",
wantStatus: http.StatusOK,
},
{
name: "enabled_blocks_register",
enabled: "true",
path: "/api/v1/auth/register",
wantStatus: http.StatusForbidden,
},
{
name: "enabled_blocks_forgot_password",
enabled: "true",
path: "/api/v1/auth/forgot-password",
wantStatus: http.StatusForbidden,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
var svc *service.SettingService
if !tc.nilService {
svc = newBackendModeSettingService(t, tc.enabled)
}
r.Use(BackendModeAuthGuard(svc))
r.Any("/*path", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, tc.path, nil)
r.ServeHTTP(w, req)
require.Equal(t, tc.wantStatus, w.Code)
})
}
}
...@@ -107,9 +107,9 @@ func registerRoutes( ...@@ -107,9 +107,9 @@ func registerRoutes(
v1 := r.Group("/api/v1") v1 := r.Group("/api/v1")
// 注册各模块路由 // 注册各模块路由
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient) routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient, settingService)
routes.RegisterUserRoutes(v1, h, jwtAuth) routes.RegisterUserRoutes(v1, h, jwtAuth, settingService)
routes.RegisterSoraClientRoutes(v1, h, jwtAuth) routes.RegisterSoraClientRoutes(v1, h, jwtAuth, settingService)
routes.RegisterAdminRoutes(v1, h, adminAuth) routes.RegisterAdminRoutes(v1, h, adminAuth)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg) routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
} }
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/middleware" "github.com/Wei-Shaw/sub2api/internal/middleware"
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
...@@ -17,12 +18,14 @@ func RegisterAuthRoutes( ...@@ -17,12 +18,14 @@ func RegisterAuthRoutes(
h *handler.Handlers, h *handler.Handlers,
jwtAuth servermiddleware.JWTAuthMiddleware, jwtAuth servermiddleware.JWTAuthMiddleware,
redisClient *redis.Client, redisClient *redis.Client,
settingService *service.SettingService,
) { ) {
// 创建速率限制器 // 创建速率限制器
rateLimiter := middleware.NewRateLimiter(redisClient) rateLimiter := middleware.NewRateLimiter(redisClient)
// 公开接口 // 公开接口
auth := v1.Group("/auth") auth := v1.Group("/auth")
auth.Use(servermiddleware.BackendModeAuthGuard(settingService))
{ {
// 注册/登录/2FA/验证码发送均属于高风险入口,增加服务端兜底限流(Redis 故障时 fail-close) // 注册/登录/2FA/验证码发送均属于高风险入口,增加服务端兜底限流(Redis 故障时 fail-close)
auth.POST("/register", rateLimiter.LimitWithOptions("auth-register", 5, time.Minute, middleware.RateLimitOptions{ auth.POST("/register", rateLimiter.LimitWithOptions("auth-register", 5, time.Minute, middleware.RateLimitOptions{
...@@ -78,6 +81,7 @@ func RegisterAuthRoutes( ...@@ -78,6 +81,7 @@ func RegisterAuthRoutes(
// 需要认证的当前用户信息 // 需要认证的当前用户信息
authenticated := v1.Group("") authenticated := v1.Group("")
authenticated.Use(gin.HandlerFunc(jwtAuth)) authenticated.Use(gin.HandlerFunc(jwtAuth))
authenticated.Use(servermiddleware.BackendModeUserGuard(settingService))
{ {
authenticated.GET("/auth/me", h.Auth.GetCurrentUser) authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
// 撤销所有会话(需要认证) // 撤销所有会话(需要认证)
......
...@@ -29,6 +29,7 @@ func newAuthRoutesTestRouter(redisClient *redis.Client) *gin.Engine { ...@@ -29,6 +29,7 @@ func newAuthRoutesTestRouter(redisClient *redis.Client) *gin.Engine {
c.Next() c.Next()
}), }),
redisClient, redisClient,
nil,
) )
return router return router
......
...@@ -3,6 +3,7 @@ package routes ...@@ -3,6 +3,7 @@ package routes
import ( import (
"github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
...@@ -12,6 +13,7 @@ func RegisterSoraClientRoutes( ...@@ -12,6 +13,7 @@ func RegisterSoraClientRoutes(
v1 *gin.RouterGroup, v1 *gin.RouterGroup,
h *handler.Handlers, h *handler.Handlers,
jwtAuth middleware.JWTAuthMiddleware, jwtAuth middleware.JWTAuthMiddleware,
settingService *service.SettingService,
) { ) {
if h.SoraClient == nil { if h.SoraClient == nil {
return return
...@@ -19,6 +21,7 @@ func RegisterSoraClientRoutes( ...@@ -19,6 +21,7 @@ func RegisterSoraClientRoutes(
authenticated := v1.Group("/sora") authenticated := v1.Group("/sora")
authenticated.Use(gin.HandlerFunc(jwtAuth)) authenticated.Use(gin.HandlerFunc(jwtAuth))
authenticated.Use(middleware.BackendModeUserGuard(settingService))
{ {
authenticated.POST("/generate", h.SoraClient.Generate) authenticated.POST("/generate", h.SoraClient.Generate)
authenticated.GET("/generations", h.SoraClient.ListGenerations) authenticated.GET("/generations", h.SoraClient.ListGenerations)
......
...@@ -3,6 +3,7 @@ package routes ...@@ -3,6 +3,7 @@ package routes
import ( import (
"github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
...@@ -12,9 +13,11 @@ func RegisterUserRoutes( ...@@ -12,9 +13,11 @@ func RegisterUserRoutes(
v1 *gin.RouterGroup, v1 *gin.RouterGroup,
h *handler.Handlers, h *handler.Handlers,
jwtAuth middleware.JWTAuthMiddleware, jwtAuth middleware.JWTAuthMiddleware,
settingService *service.SettingService,
) { ) {
authenticated := v1.Group("") authenticated := v1.Group("")
authenticated.Use(gin.HandlerFunc(jwtAuth)) authenticated.Use(gin.HandlerFunc(jwtAuth))
authenticated.Use(middleware.BackendModeUserGuard(settingService))
{ {
// 用户接口 // 用户接口
user := authenticated.Group("/user") user := authenticated.Group("/user")
......
...@@ -1087,6 +1087,12 @@ type TokenPair struct { ...@@ -1087,6 +1087,12 @@ type TokenPair struct {
ExpiresIn int `json:"expires_in"` // Access Token有效期(秒) ExpiresIn int `json:"expires_in"` // Access Token有效期(秒)
} }
// TokenPairWithUser extends TokenPair with user role for backend mode checks
type TokenPairWithUser struct {
TokenPair
UserRole string
}
// GenerateTokenPair 生成Access Token和Refresh Token对 // GenerateTokenPair 生成Access Token和Refresh Token对
// familyID: 可选的Token家族ID,用于Token轮转时保持家族关系 // familyID: 可选的Token家族ID,用于Token轮转时保持家族关系
func (s *AuthService) GenerateTokenPair(ctx context.Context, user *User, familyID string) (*TokenPair, error) { func (s *AuthService) GenerateTokenPair(ctx context.Context, user *User, familyID string) (*TokenPair, error) {
...@@ -1168,7 +1174,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami ...@@ -1168,7 +1174,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami
// RefreshTokenPair 使用Refresh Token刷新Token对 // RefreshTokenPair 使用Refresh Token刷新Token对
// 实现Token轮转:每次刷新都会生成新的Refresh Token,旧Token立即失效 // 实现Token轮转:每次刷新都会生成新的Refresh Token,旧Token立即失效
func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPair, error) { func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPairWithUser, error) {
// 检查 refreshTokenCache 是否可用 // 检查 refreshTokenCache 是否可用
if s.refreshTokenCache == nil { if s.refreshTokenCache == nil {
return nil, ErrRefreshTokenInvalid return nil, ErrRefreshTokenInvalid
...@@ -1233,7 +1239,14 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) ...@@ -1233,7 +1239,14 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
} }
// 生成新的Token对,保持同一个家族ID // 生成新的Token对,保持同一个家族ID
return s.GenerateTokenPair(ctx, user, data.FamilyID) pair, err := s.GenerateTokenPair(ctx, user, data.FamilyID)
if err != nil {
return nil, err
}
return &TokenPairWithUser{
TokenPair: *pair,
UserRole: user.Role,
}, nil
} }
// RevokeRefreshToken 撤销单个Refresh Token // RevokeRefreshToken 撤销单个Refresh Token
......
...@@ -220,6 +220,9 @@ const ( ...@@ -220,6 +220,9 @@ const (
// SettingKeyAllowUngroupedKeyScheduling 允许未分组 API Key 调度(默认 false:未分组 Key 返回 403) // SettingKeyAllowUngroupedKeyScheduling 允许未分组 API Key 调度(默认 false:未分组 Key 返回 403)
SettingKeyAllowUngroupedKeyScheduling = "allow_ungrouped_key_scheduling" SettingKeyAllowUngroupedKeyScheduling = "allow_ungrouped_key_scheduling"
// SettingKeyBackendModeEnabled Backend 模式:禁用用户注册和自助服务,仅管理员可登录
SettingKeyBackendModeEnabled = "backend_mode_enabled"
) )
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
......
...@@ -65,6 +65,19 @@ const minVersionErrorTTL = 5 * time.Second ...@@ -65,6 +65,19 @@ const minVersionErrorTTL = 5 * time.Second
// minVersionDBTimeout singleflight 内 DB 查询超时,独立于请求 context // minVersionDBTimeout singleflight 内 DB 查询超时,独立于请求 context
const minVersionDBTimeout = 5 * time.Second const minVersionDBTimeout = 5 * time.Second
// cachedBackendMode Backend Mode cache (in-process, 60s TTL)
type cachedBackendMode struct {
value bool
expiresAt int64 // unix nano
}
var backendModeCache atomic.Value // *cachedBackendMode
var backendModeSF singleflight.Group
const backendModeCacheTTL = 60 * time.Second
const backendModeErrorTTL = 5 * time.Second
const backendModeDBTimeout = 5 * time.Second
// DefaultSubscriptionGroupReader validates group references used by default subscriptions. // DefaultSubscriptionGroupReader validates group references used by default subscriptions.
type DefaultSubscriptionGroupReader interface { type DefaultSubscriptionGroupReader interface {
GetByID(ctx context.Context, id int64) (*Group, error) GetByID(ctx context.Context, id int64) (*Group, error)
...@@ -128,6 +141,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ...@@ -128,6 +141,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeySoraClientEnabled, SettingKeySoraClientEnabled,
SettingKeyCustomMenuItems, SettingKeyCustomMenuItems,
SettingKeyLinuxDoConnectEnabled, SettingKeyLinuxDoConnectEnabled,
SettingKeyBackendModeEnabled,
} }
settings, err := s.settingRepo.GetMultiple(ctx, keys) settings, err := s.settingRepo.GetMultiple(ctx, keys)
...@@ -172,6 +186,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ...@@ -172,6 +186,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
CustomMenuItems: settings[SettingKeyCustomMenuItems], CustomMenuItems: settings[SettingKeyCustomMenuItems],
LinuxDoOAuthEnabled: linuxDoEnabled, LinuxDoOAuthEnabled: linuxDoEnabled,
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
}, nil }, nil
} }
...@@ -223,6 +238,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any ...@@ -223,6 +238,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
SoraClientEnabled bool `json:"sora_client_enabled"` SoraClientEnabled bool `json:"sora_client_enabled"`
CustomMenuItems json.RawMessage `json:"custom_menu_items"` CustomMenuItems json.RawMessage `json:"custom_menu_items"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"`
Version string `json:"version,omitempty"` Version string `json:"version,omitempty"`
}{ }{
RegistrationEnabled: settings.RegistrationEnabled, RegistrationEnabled: settings.RegistrationEnabled,
...@@ -247,6 +263,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any ...@@ -247,6 +263,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
SoraClientEnabled: settings.SoraClientEnabled, SoraClientEnabled: settings.SoraClientEnabled,
CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems), CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
BackendModeEnabled: settings.BackendModeEnabled,
Version: s.version, Version: s.version,
}, nil }, nil
} }
...@@ -461,6 +478,9 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet ...@@ -461,6 +478,9 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
// 分组隔离 // 分组隔离
updates[SettingKeyAllowUngroupedKeyScheduling] = strconv.FormatBool(settings.AllowUngroupedKeyScheduling) updates[SettingKeyAllowUngroupedKeyScheduling] = strconv.FormatBool(settings.AllowUngroupedKeyScheduling)
// Backend Mode
updates[SettingKeyBackendModeEnabled] = strconv.FormatBool(settings.BackendModeEnabled)
err = s.settingRepo.SetMultiple(ctx, updates) err = s.settingRepo.SetMultiple(ctx, updates)
if err == nil { if err == nil {
// 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口 // 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口
...@@ -469,6 +489,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet ...@@ -469,6 +489,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
value: settings.MinClaudeCodeVersion, value: settings.MinClaudeCodeVersion,
expiresAt: time.Now().Add(minVersionCacheTTL).UnixNano(), expiresAt: time.Now().Add(minVersionCacheTTL).UnixNano(),
}) })
backendModeSF.Forget("backend_mode")
backendModeCache.Store(&cachedBackendMode{
value: settings.BackendModeEnabled,
expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
})
if s.onUpdate != nil { if s.onUpdate != nil {
s.onUpdate() // Invalidate cache after settings update s.onUpdate() // Invalidate cache after settings update
} }
...@@ -525,6 +550,52 @@ func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool { ...@@ -525,6 +550,52 @@ func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
return value == "true" return value == "true"
} }
// IsBackendModeEnabled checks if backend mode is enabled
// Uses in-process atomic.Value cache with 60s TTL, zero-lock hot path
func (s *SettingService) IsBackendModeEnabled(ctx context.Context) bool {
if cached, ok := backendModeCache.Load().(*cachedBackendMode); ok && cached != nil {
if time.Now().UnixNano() < cached.expiresAt {
return cached.value
}
}
result, _, _ := backendModeSF.Do("backend_mode", func() (any, error) {
if cached, ok := backendModeCache.Load().(*cachedBackendMode); ok && cached != nil {
if time.Now().UnixNano() < cached.expiresAt {
return cached.value, nil
}
}
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), backendModeDBTimeout)
defer cancel()
value, err := s.settingRepo.GetValue(dbCtx, SettingKeyBackendModeEnabled)
if err != nil {
if errors.Is(err, ErrSettingNotFound) {
// Setting not yet created (fresh install) - default to disabled with full TTL
backendModeCache.Store(&cachedBackendMode{
value: false,
expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
})
return false, nil
}
slog.Warn("failed to get backend_mode_enabled setting", "error", err)
backendModeCache.Store(&cachedBackendMode{
value: false,
expiresAt: time.Now().Add(backendModeErrorTTL).UnixNano(),
})
return false, nil
}
enabled := value == "true"
backendModeCache.Store(&cachedBackendMode{
value: enabled,
expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
})
return enabled, nil
})
if val, ok := result.(bool); ok {
return val
}
return false
}
// IsEmailVerifyEnabled 检查是否开启邮件验证 // IsEmailVerifyEnabled 检查是否开启邮件验证
func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool { func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, SettingKeyEmailVerifyEnabled) value, err := s.settingRepo.GetValue(ctx, SettingKeyEmailVerifyEnabled)
...@@ -719,6 +790,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin ...@@ -719,6 +790,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
CustomMenuItems: settings[SettingKeyCustomMenuItems], CustomMenuItems: settings[SettingKeyCustomMenuItems],
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
} }
// 解析整数类型 // 解析整数类型
......
//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type bmRepoStub struct {
getValueFn func(ctx context.Context, key string) (string, error)
calls int
}
func (s *bmRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
panic("unexpected Get call")
}
func (s *bmRepoStub) GetValue(ctx context.Context, key string) (string, error) {
s.calls++
if s.getValueFn == nil {
panic("unexpected GetValue call")
}
return s.getValueFn(ctx, key)
}
func (s *bmRepoStub) Set(ctx context.Context, key, value string) error {
panic("unexpected Set call")
}
func (s *bmRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
panic("unexpected GetMultiple call")
}
func (s *bmRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
panic("unexpected SetMultiple call")
}
func (s *bmRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *bmRepoStub) Delete(ctx context.Context, key string) error {
panic("unexpected Delete call")
}
type bmUpdateRepoStub struct {
updates map[string]string
getValueFn func(ctx context.Context, key string) (string, error)
}
func (s *bmUpdateRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
panic("unexpected Get call")
}
func (s *bmUpdateRepoStub) GetValue(ctx context.Context, key string) (string, error) {
if s.getValueFn == nil {
panic("unexpected GetValue call")
}
return s.getValueFn(ctx, key)
}
func (s *bmUpdateRepoStub) Set(ctx context.Context, key, value string) error {
panic("unexpected Set call")
}
func (s *bmUpdateRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
panic("unexpected GetMultiple call")
}
func (s *bmUpdateRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
s.updates = make(map[string]string, len(settings))
for k, v := range settings {
s.updates[k] = v
}
return nil
}
func (s *bmUpdateRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *bmUpdateRepoStub) Delete(ctx context.Context, key string) error {
panic("unexpected Delete call")
}
func resetBackendModeTestCache(t *testing.T) {
t.Helper()
backendModeCache.Store((*cachedBackendMode)(nil))
t.Cleanup(func() {
backendModeCache.Store((*cachedBackendMode)(nil))
})
}
func TestIsBackendModeEnabled_ReturnsTrue(t *testing.T) {
resetBackendModeTestCache(t)
repo := &bmRepoStub{
getValueFn: func(ctx context.Context, key string) (string, error) {
require.Equal(t, SettingKeyBackendModeEnabled, key)
return "true", nil
},
}
svc := NewSettingService(repo, &config.Config{})
require.True(t, svc.IsBackendModeEnabled(context.Background()))
require.Equal(t, 1, repo.calls)
}
func TestIsBackendModeEnabled_ReturnsFalse(t *testing.T) {
resetBackendModeTestCache(t)
repo := &bmRepoStub{
getValueFn: func(ctx context.Context, key string) (string, error) {
require.Equal(t, SettingKeyBackendModeEnabled, key)
return "false", nil
},
}
svc := NewSettingService(repo, &config.Config{})
require.False(t, svc.IsBackendModeEnabled(context.Background()))
require.Equal(t, 1, repo.calls)
}
func TestIsBackendModeEnabled_ReturnsFalseOnNotFound(t *testing.T) {
resetBackendModeTestCache(t)
repo := &bmRepoStub{
getValueFn: func(ctx context.Context, key string) (string, error) {
require.Equal(t, SettingKeyBackendModeEnabled, key)
return "", ErrSettingNotFound
},
}
svc := NewSettingService(repo, &config.Config{})
require.False(t, svc.IsBackendModeEnabled(context.Background()))
require.Equal(t, 1, repo.calls)
}
func TestIsBackendModeEnabled_ReturnsFalseOnDBError(t *testing.T) {
resetBackendModeTestCache(t)
repo := &bmRepoStub{
getValueFn: func(ctx context.Context, key string) (string, error) {
require.Equal(t, SettingKeyBackendModeEnabled, key)
return "", errors.New("db down")
},
}
svc := NewSettingService(repo, &config.Config{})
require.False(t, svc.IsBackendModeEnabled(context.Background()))
require.Equal(t, 1, repo.calls)
}
func TestIsBackendModeEnabled_CachesResult(t *testing.T) {
resetBackendModeTestCache(t)
repo := &bmRepoStub{
getValueFn: func(ctx context.Context, key string) (string, error) {
require.Equal(t, SettingKeyBackendModeEnabled, key)
return "true", nil
},
}
svc := NewSettingService(repo, &config.Config{})
require.True(t, svc.IsBackendModeEnabled(context.Background()))
require.True(t, svc.IsBackendModeEnabled(context.Background()))
require.Equal(t, 1, repo.calls)
}
func TestUpdateSettings_InvalidatesBackendModeCache(t *testing.T) {
resetBackendModeTestCache(t)
backendModeCache.Store(&cachedBackendMode{
value: true,
expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(),
})
repo := &bmUpdateRepoStub{
getValueFn: func(ctx context.Context, key string) (string, error) {
require.Equal(t, SettingKeyBackendModeEnabled, key)
return "true", nil
},
}
svc := NewSettingService(repo, &config.Config{})
err := svc.UpdateSettings(context.Background(), &SystemSettings{
BackendModeEnabled: false,
})
require.NoError(t, err)
require.Equal(t, "false", repo.updates[SettingKeyBackendModeEnabled])
require.False(t, svc.IsBackendModeEnabled(context.Background()))
}
...@@ -69,6 +69,9 @@ type SystemSettings struct { ...@@ -69,6 +69,9 @@ type SystemSettings struct {
// 分组隔离:允许未分组 Key 调度(默认 false → 403) // 分组隔离:允许未分组 Key 调度(默认 false → 403)
AllowUngroupedKeyScheduling bool AllowUngroupedKeyScheduling bool
// Backend 模式:禁用用户注册和自助服务,仅管理员可登录
BackendModeEnabled bool
} }
type DefaultSubscriptionSetting struct { type DefaultSubscriptionSetting struct {
...@@ -101,6 +104,7 @@ type PublicSettings struct { ...@@ -101,6 +104,7 @@ type PublicSettings struct {
CustomMenuItems string // JSON array of custom menu items CustomMenuItems string // JSON array of custom menu items
LinuxDoOAuthEnabled bool LinuxDoOAuthEnabled bool
BackendModeEnabled bool
Version string Version string
} }
......
...@@ -40,6 +40,7 @@ export interface SystemSettings { ...@@ -40,6 +40,7 @@ export interface SystemSettings {
purchase_subscription_enabled: boolean purchase_subscription_enabled: boolean
purchase_subscription_url: string purchase_subscription_url: string
sora_client_enabled: boolean sora_client_enabled: boolean
backend_mode_enabled: boolean
custom_menu_items: CustomMenuItem[] custom_menu_items: CustomMenuItem[]
// SMTP settings // SMTP settings
smtp_host: string smtp_host: string
...@@ -106,6 +107,7 @@ export interface UpdateSettingsRequest { ...@@ -106,6 +107,7 @@ export interface UpdateSettingsRequest {
purchase_subscription_enabled?: boolean purchase_subscription_enabled?: boolean
purchase_subscription_url?: string purchase_subscription_url?: string
sora_client_enabled?: boolean sora_client_enabled?: boolean
backend_mode_enabled?: boolean
custom_menu_items?: CustomMenuItem[] custom_menu_items?: CustomMenuItem[]
smtp_host?: string smtp_host?: string
smtp_port?: number smtp_port?: number
......
...@@ -82,7 +82,7 @@ ...@@ -82,7 +82,7 @@
</template> </template>
<!-- Regular User View --> <!-- Regular User View -->
<template v-else> <template v-else-if="!appStore.backendModeEnabled">
<div class="sidebar-section"> <div class="sidebar-section">
<router-link <router-link
v-for="item in userNavItems" v-for="item in userNavItems"
......
...@@ -3922,6 +3922,9 @@ export default { ...@@ -3922,6 +3922,9 @@ export default {
site: { site: {
title: 'Site Settings', title: 'Site Settings',
description: 'Customize site branding', description: 'Customize site branding',
backendMode: 'Backend Mode',
backendModeDescription:
'Disables user registration, public site, and self-service features. Only admin can log in and manage the platform.',
siteName: 'Site Name', siteName: 'Site Name',
siteNamePlaceholder: 'Sub2API', siteNamePlaceholder: 'Sub2API',
siteNameHint: 'Displayed in emails and page titles', siteNameHint: 'Displayed in emails and page titles',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment