Commit 3a67002c authored by IanShaw027's avatar IanShaw027
Browse files

merge: 合并主分支改动并保留 ops 监控实现

合并 main 分支的最新改动到 ops 监控分支。
冲突解决策略:保留当前分支的 ops 相关改动,接受主分支的其他改动。

保留的 ops 改动:
- 运维监控配置和依赖注入
- 运维监控 API 处理器和中间件
- 运维监控服务层和数据访问层
- 运维监控前端界面和状态管理

接受的主分支改动:
- Linux DO OAuth 集成
- 账号过期功能
- IP 地址限制功能
- 用量统计优化
- 其他 bug 修复和功能改进
parents c48dc097 7d1fe818
...@@ -2,6 +2,7 @@ package admin ...@@ -2,6 +2,7 @@ package admin
import ( import (
"strconv" "strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
...@@ -63,10 +64,17 @@ type UpdateBalanceRequest struct { ...@@ -63,10 +64,17 @@ type UpdateBalanceRequest struct {
func (h *UserHandler) List(c *gin.Context) { func (h *UserHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c) page, pageSize := response.ParsePagination(c)
search := c.Query("search")
// 标准化和验证 search 参数
search = strings.TrimSpace(search)
if len(search) > 100 {
search = search[:100]
}
filters := service.UserListFilters{ filters := service.UserListFilters{
Status: c.Query("status"), Status: c.Query("status"),
Role: c.Query("role"), Role: c.Query("role"),
Search: c.Query("search"), Search: search,
Attributes: parseAttributeFilters(c), Attributes: parseAttributeFilters(c),
} }
......
...@@ -27,16 +27,20 @@ func NewAPIKeyHandler(apiKeyService *service.APIKeyService) *APIKeyHandler { ...@@ -27,16 +27,20 @@ func NewAPIKeyHandler(apiKeyService *service.APIKeyService) *APIKeyHandler {
// CreateAPIKeyRequest represents the create API key request payload // CreateAPIKeyRequest represents the create API key request payload
type CreateAPIKeyRequest struct { type CreateAPIKeyRequest struct {
Name string `json:"name" binding:"required"` Name string `json:"name" binding:"required"`
GroupID *int64 `json:"group_id"` // nullable GroupID *int64 `json:"group_id"` // nullable
CustomKey *string `json:"custom_key"` // 可选的自定义key CustomKey *string `json:"custom_key"` // 可选的自定义key
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
} }
// UpdateAPIKeyRequest represents the update API key request payload // UpdateAPIKeyRequest represents the update API key request payload
type UpdateAPIKeyRequest struct { type UpdateAPIKeyRequest struct {
Name string `json:"name"` Name string `json:"name"`
GroupID *int64 `json:"group_id"` GroupID *int64 `json:"group_id"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"` Status string `json:"status" binding:"omitempty,oneof=active inactive"`
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
} }
// List handles listing user's API keys with pagination // List handles listing user's API keys with pagination
...@@ -110,9 +114,11 @@ func (h *APIKeyHandler) Create(c *gin.Context) { ...@@ -110,9 +114,11 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
} }
svcReq := service.CreateAPIKeyRequest{ svcReq := service.CreateAPIKeyRequest{
Name: req.Name, Name: req.Name,
GroupID: req.GroupID, GroupID: req.GroupID,
CustomKey: req.CustomKey, CustomKey: req.CustomKey,
IPWhitelist: req.IPWhitelist,
IPBlacklist: req.IPBlacklist,
} }
key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq) key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq)
if err != nil { if err != nil {
...@@ -144,7 +150,10 @@ func (h *APIKeyHandler) Update(c *gin.Context) { ...@@ -144,7 +150,10 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
return return
} }
svcReq := service.UpdateAPIKeyRequest{} svcReq := service.UpdateAPIKeyRequest{
IPWhitelist: req.IPWhitelist,
IPBlacklist: req.IPBlacklist,
}
if req.Name != "" { if req.Name != "" {
svcReq.Name = &req.Name svcReq.Name = &req.Name
} }
......
...@@ -15,14 +15,16 @@ type AuthHandler struct { ...@@ -15,14 +15,16 @@ type AuthHandler struct {
cfg *config.Config cfg *config.Config
authService *service.AuthService authService *service.AuthService
userService *service.UserService userService *service.UserService
settingSvc *service.SettingService
} }
// NewAuthHandler creates a new AuthHandler // NewAuthHandler creates a new AuthHandler
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService) *AuthHandler { func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService) *AuthHandler {
return &AuthHandler{ return &AuthHandler{
cfg: cfg, cfg: cfg,
authService: authService, authService: authService,
userService: userService, userService: userService,
settingSvc: settingService,
} }
} }
......
package handler
import (
"context"
"encoding/base64"
"errors"
"fmt"
"log"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"unicode/utf8"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/imroc/req/v3"
"github.com/tidwall/gjson"
)
const (
linuxDoOAuthCookiePath = "/api/v1/auth/oauth/linuxdo"
linuxDoOAuthStateCookieName = "linuxdo_oauth_state"
linuxDoOAuthVerifierCookie = "linuxdo_oauth_verifier"
linuxDoOAuthRedirectCookie = "linuxdo_oauth_redirect"
linuxDoOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes
linuxDoOAuthDefaultRedirectTo = "/dashboard"
linuxDoOAuthDefaultFrontendCB = "/auth/linuxdo/callback"
linuxDoOAuthMaxRedirectLen = 2048
linuxDoOAuthMaxFragmentValueLen = 512
linuxDoOAuthMaxSubjectLen = 64 - len("linuxdo-")
)
type linuxDoTokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
Scope string `json:"scope,omitempty"`
}
type linuxDoTokenExchangeError struct {
StatusCode int
ProviderError string
ProviderDescription string
Body string
}
func (e *linuxDoTokenExchangeError) Error() string {
if e == nil {
return ""
}
parts := []string{fmt.Sprintf("token exchange status=%d", e.StatusCode)}
if strings.TrimSpace(e.ProviderError) != "" {
parts = append(parts, "error="+strings.TrimSpace(e.ProviderError))
}
if strings.TrimSpace(e.ProviderDescription) != "" {
parts = append(parts, "error_description="+strings.TrimSpace(e.ProviderDescription))
}
return strings.Join(parts, " ")
}
// LinuxDoOAuthStart 启动 LinuxDo Connect OAuth 登录流程。
// GET /api/v1/auth/oauth/linuxdo/start?redirect=/dashboard
func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) {
cfg, err := h.getLinuxDoOAuthConfig(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
state, err := oauth.GenerateState()
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err))
return
}
redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect"))
if redirectTo == "" {
redirectTo = linuxDoOAuthDefaultRedirectTo
}
secureCookie := isRequestHTTPS(c)
setCookie(c, linuxDoOAuthStateCookieName, encodeCookieValue(state), linuxDoOAuthCookieMaxAgeSec, secureCookie)
setCookie(c, linuxDoOAuthRedirectCookie, encodeCookieValue(redirectTo), linuxDoOAuthCookieMaxAgeSec, secureCookie)
codeChallenge := ""
if cfg.UsePKCE {
verifier, err := oauth.GenerateCodeVerifier()
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err))
return
}
codeChallenge = oauth.GenerateCodeChallenge(verifier)
setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie)
}
redirectURI := strings.TrimSpace(cfg.RedirectURL)
if redirectURI == "" {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured"))
return
}
authURL, err := buildLinuxDoAuthorizeURL(cfg, state, codeChallenge, redirectURI)
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err))
return
}
c.Redirect(http.StatusFound, authURL)
}
// LinuxDoOAuthCallback 处理 OAuth 回调:创建/登录用户,然后重定向到前端。
// GET /api/v1/auth/oauth/linuxdo/callback?code=...&state=...
func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
cfg, cfgErr := h.getLinuxDoOAuthConfig(c.Request.Context())
if cfgErr != nil {
response.ErrorFrom(c, cfgErr)
return
}
frontendCallback := strings.TrimSpace(cfg.FrontendRedirectURL)
if frontendCallback == "" {
frontendCallback = linuxDoOAuthDefaultFrontendCB
}
if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description"))
return
}
code := strings.TrimSpace(c.Query("code"))
state := strings.TrimSpace(c.Query("state"))
if code == "" || state == "" {
redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "")
return
}
secureCookie := isRequestHTTPS(c)
defer func() {
clearCookie(c, linuxDoOAuthStateCookieName, secureCookie)
clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie)
clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie)
}()
expectedState, err := readCookieDecoded(c, linuxDoOAuthStateCookieName)
if err != nil || expectedState == "" || state != expectedState {
redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "")
return
}
redirectTo, _ := readCookieDecoded(c, linuxDoOAuthRedirectCookie)
redirectTo = sanitizeFrontendRedirectPath(redirectTo)
if redirectTo == "" {
redirectTo = linuxDoOAuthDefaultRedirectTo
}
codeVerifier := ""
if cfg.UsePKCE {
codeVerifier, _ = readCookieDecoded(c, linuxDoOAuthVerifierCookie)
if codeVerifier == "" {
redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
return
}
}
redirectURI := strings.TrimSpace(cfg.RedirectURL)
if redirectURI == "" {
redirectOAuthError(c, frontendCallback, "config_error", "oauth redirect url not configured", "")
return
}
tokenResp, err := linuxDoExchangeCode(c.Request.Context(), cfg, code, redirectURI, codeVerifier)
if err != nil {
description := ""
var exchangeErr *linuxDoTokenExchangeError
if errors.As(err, &exchangeErr) && exchangeErr != nil {
log.Printf(
"[LinuxDo OAuth] token exchange failed: status=%d provider_error=%q provider_description=%q body=%s",
exchangeErr.StatusCode,
exchangeErr.ProviderError,
exchangeErr.ProviderDescription,
truncateLogValue(exchangeErr.Body, 2048),
)
description = exchangeErr.Error()
} else {
log.Printf("[LinuxDo OAuth] token exchange failed: %v", err)
description = err.Error()
}
redirectOAuthError(c, frontendCallback, "token_exchange_failed", "failed to exchange oauth code", singleLine(description))
return
}
email, username, subject, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp)
if err != nil {
log.Printf("[LinuxDo OAuth] userinfo fetch failed: %v", err)
redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "")
return
}
// 安全考虑:不要把第三方返回的 email 直接映射到本地账号(可能与本地邮箱用户冲突导致账号被接管)。
// 统一使用基于 subject 的稳定合成邮箱来做账号绑定。
if subject != "" {
email = linuxDoSyntheticEmail(subject)
}
jwtToken, _, err := h.authService.LoginOrRegisterOAuth(c.Request.Context(), email, username)
if err != nil {
// 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
return
}
fragment := url.Values{}
fragment.Set("access_token", jwtToken)
fragment.Set("token_type", "Bearer")
fragment.Set("redirect", redirectTo)
redirectWithFragment(c, frontendCallback, fragment)
}
func (h *AuthHandler) getLinuxDoOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) {
if h != nil && h.settingSvc != nil {
return h.settingSvc.GetLinuxDoConnectOAuthConfig(ctx)
}
if h == nil || h.cfg == nil {
return config.LinuxDoConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded")
}
if !h.cfg.LinuxDo.Enabled {
return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
}
return h.cfg.LinuxDo, nil
}
func linuxDoExchangeCode(
ctx context.Context,
cfg config.LinuxDoConnectConfig,
code string,
redirectURI string,
codeVerifier string,
) (*linuxDoTokenResponse, error) {
client := req.C().SetTimeout(30 * time.Second)
form := url.Values{}
form.Set("grant_type", "authorization_code")
form.Set("client_id", cfg.ClientID)
form.Set("code", code)
form.Set("redirect_uri", redirectURI)
if cfg.UsePKCE {
form.Set("code_verifier", codeVerifier)
}
r := client.R().
SetContext(ctx).
SetHeader("Accept", "application/json")
switch strings.ToLower(strings.TrimSpace(cfg.TokenAuthMethod)) {
case "", "client_secret_post":
form.Set("client_secret", cfg.ClientSecret)
case "client_secret_basic":
r.SetBasicAuth(cfg.ClientID, cfg.ClientSecret)
case "none":
default:
return nil, fmt.Errorf("unsupported token_auth_method: %s", cfg.TokenAuthMethod)
}
resp, err := r.SetFormDataFromValues(form).Post(cfg.TokenURL)
if err != nil {
return nil, fmt.Errorf("request token: %w", err)
}
body := strings.TrimSpace(resp.String())
if !resp.IsSuccessState() {
providerErr, providerDesc := parseOAuthProviderError(body)
return nil, &linuxDoTokenExchangeError{
StatusCode: resp.StatusCode,
ProviderError: providerErr,
ProviderDescription: providerDesc,
Body: body,
}
}
tokenResp, ok := parseLinuxDoTokenResponse(body)
if !ok || strings.TrimSpace(tokenResp.AccessToken) == "" {
return nil, &linuxDoTokenExchangeError{
StatusCode: resp.StatusCode,
Body: body,
}
}
if strings.TrimSpace(tokenResp.TokenType) == "" {
tokenResp.TokenType = "Bearer"
}
return tokenResp, nil
}
func linuxDoFetchUserInfo(
ctx context.Context,
cfg config.LinuxDoConnectConfig,
token *linuxDoTokenResponse,
) (email string, username string, subject string, err error) {
client := req.C().SetTimeout(30 * time.Second)
authorization, err := buildBearerAuthorization(token.TokenType, token.AccessToken)
if err != nil {
return "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err)
}
resp, err := client.R().
SetContext(ctx).
SetHeader("Accept", "application/json").
SetHeader("Authorization", authorization).
Get(cfg.UserInfoURL)
if err != nil {
return "", "", "", fmt.Errorf("request userinfo: %w", err)
}
if !resp.IsSuccessState() {
return "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode)
}
return linuxDoParseUserInfo(resp.String(), cfg)
}
func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, err error) {
email = firstNonEmpty(
getGJSON(body, cfg.UserInfoEmailPath),
getGJSON(body, "email"),
getGJSON(body, "user.email"),
getGJSON(body, "data.email"),
getGJSON(body, "attributes.email"),
)
username = firstNonEmpty(
getGJSON(body, cfg.UserInfoUsernamePath),
getGJSON(body, "username"),
getGJSON(body, "preferred_username"),
getGJSON(body, "name"),
getGJSON(body, "user.username"),
getGJSON(body, "user.name"),
)
subject = firstNonEmpty(
getGJSON(body, cfg.UserInfoIDPath),
getGJSON(body, "sub"),
getGJSON(body, "id"),
getGJSON(body, "user_id"),
getGJSON(body, "uid"),
getGJSON(body, "user.id"),
)
subject = strings.TrimSpace(subject)
if subject == "" {
return "", "", "", errors.New("userinfo missing id field")
}
if !isSafeLinuxDoSubject(subject) {
return "", "", "", errors.New("userinfo returned invalid id field")
}
email = strings.TrimSpace(email)
if email == "" {
// LinuxDo Connect 的 userinfo 可能不提供 email。为兼容现有用户模型(email 必填且唯一),使用稳定的合成邮箱。
email = linuxDoSyntheticEmail(subject)
}
username = strings.TrimSpace(username)
if username == "" {
username = "linuxdo_" + subject
}
return email, username, subject, nil
}
func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, codeChallenge string, redirectURI string) (string, error) {
u, err := url.Parse(cfg.AuthorizeURL)
if err != nil {
return "", fmt.Errorf("parse authorize_url: %w", err)
}
q := u.Query()
q.Set("response_type", "code")
q.Set("client_id", cfg.ClientID)
q.Set("redirect_uri", redirectURI)
if strings.TrimSpace(cfg.Scopes) != "" {
q.Set("scope", cfg.Scopes)
}
q.Set("state", state)
if cfg.UsePKCE {
q.Set("code_challenge", codeChallenge)
q.Set("code_challenge_method", "S256")
}
u.RawQuery = q.Encode()
return u.String(), nil
}
func redirectOAuthError(c *gin.Context, frontendCallback string, code string, message string, description string) {
fragment := url.Values{}
fragment.Set("error", truncateFragmentValue(code))
if strings.TrimSpace(message) != "" {
fragment.Set("error_message", truncateFragmentValue(message))
}
if strings.TrimSpace(description) != "" {
fragment.Set("error_description", truncateFragmentValue(description))
}
redirectWithFragment(c, frontendCallback, fragment)
}
func redirectWithFragment(c *gin.Context, frontendCallback string, fragment url.Values) {
u, err := url.Parse(frontendCallback)
if err != nil {
// 兜底:尽力跳转到默认页面,避免卡死在回调页。
c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo)
return
}
if u.Scheme != "" && !strings.EqualFold(u.Scheme, "http") && !strings.EqualFold(u.Scheme, "https") {
c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo)
return
}
u.Fragment = fragment.Encode()
c.Header("Cache-Control", "no-store")
c.Header("Pragma", "no-cache")
c.Redirect(http.StatusFound, u.String())
}
func firstNonEmpty(values ...string) string {
for _, v := range values {
v = strings.TrimSpace(v)
if v != "" {
return v
}
}
return ""
}
func parseOAuthProviderError(body string) (providerErr string, providerDesc string) {
body = strings.TrimSpace(body)
if body == "" {
return "", ""
}
providerErr = firstNonEmpty(
getGJSON(body, "error"),
getGJSON(body, "code"),
getGJSON(body, "error.code"),
)
providerDesc = firstNonEmpty(
getGJSON(body, "error_description"),
getGJSON(body, "error.message"),
getGJSON(body, "message"),
getGJSON(body, "detail"),
)
if providerErr != "" || providerDesc != "" {
return providerErr, providerDesc
}
values, err := url.ParseQuery(body)
if err != nil {
return "", ""
}
providerErr = firstNonEmpty(values.Get("error"), values.Get("code"))
providerDesc = firstNonEmpty(values.Get("error_description"), values.Get("error_message"), values.Get("message"))
return providerErr, providerDesc
}
func parseLinuxDoTokenResponse(body string) (*linuxDoTokenResponse, bool) {
body = strings.TrimSpace(body)
if body == "" {
return nil, false
}
accessToken := strings.TrimSpace(getGJSON(body, "access_token"))
if accessToken != "" {
tokenType := strings.TrimSpace(getGJSON(body, "token_type"))
refreshToken := strings.TrimSpace(getGJSON(body, "refresh_token"))
scope := strings.TrimSpace(getGJSON(body, "scope"))
expiresIn := gjson.Get(body, "expires_in").Int()
return &linuxDoTokenResponse{
AccessToken: accessToken,
TokenType: tokenType,
ExpiresIn: expiresIn,
RefreshToken: refreshToken,
Scope: scope,
}, true
}
values, err := url.ParseQuery(body)
if err != nil {
return nil, false
}
accessToken = strings.TrimSpace(values.Get("access_token"))
if accessToken == "" {
return nil, false
}
expiresIn := int64(0)
if raw := strings.TrimSpace(values.Get("expires_in")); raw != "" {
if v, err := strconv.ParseInt(raw, 10, 64); err == nil {
expiresIn = v
}
}
return &linuxDoTokenResponse{
AccessToken: accessToken,
TokenType: strings.TrimSpace(values.Get("token_type")),
ExpiresIn: expiresIn,
RefreshToken: strings.TrimSpace(values.Get("refresh_token")),
Scope: strings.TrimSpace(values.Get("scope")),
}, true
}
func getGJSON(body string, path string) string {
path = strings.TrimSpace(path)
if path == "" {
return ""
}
res := gjson.Get(body, path)
if !res.Exists() {
return ""
}
return res.String()
}
func truncateLogValue(value string, maxLen int) string {
value = strings.TrimSpace(value)
if value == "" || maxLen <= 0 {
return ""
}
if len(value) <= maxLen {
return value
}
value = value[:maxLen]
for !utf8.ValidString(value) {
value = value[:len(value)-1]
}
return value
}
func singleLine(value string) string {
value = strings.TrimSpace(value)
if value == "" {
return ""
}
return strings.Join(strings.Fields(value), " ")
}
func sanitizeFrontendRedirectPath(path string) string {
path = strings.TrimSpace(path)
if path == "" {
return ""
}
if len(path) > linuxDoOAuthMaxRedirectLen {
return ""
}
// 只允许同源相对路径(避免开放重定向)。
if !strings.HasPrefix(path, "/") {
return ""
}
if strings.HasPrefix(path, "//") {
return ""
}
if strings.Contains(path, "://") {
return ""
}
if strings.ContainsAny(path, "\r\n") {
return ""
}
return path
}
func isRequestHTTPS(c *gin.Context) bool {
if c.Request.TLS != nil {
return true
}
proto := strings.ToLower(strings.TrimSpace(c.GetHeader("X-Forwarded-Proto")))
return proto == "https"
}
func encodeCookieValue(value string) string {
return base64.RawURLEncoding.EncodeToString([]byte(value))
}
func decodeCookieValue(value string) (string, error) {
raw, err := base64.RawURLEncoding.DecodeString(value)
if err != nil {
return "", err
}
return string(raw), nil
}
func readCookieDecoded(c *gin.Context, name string) (string, error) {
ck, err := c.Request.Cookie(name)
if err != nil {
return "", err
}
return decodeCookieValue(ck.Value)
}
func setCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) {
http.SetCookie(c.Writer, &http.Cookie{
Name: name,
Value: value,
Path: linuxDoOAuthCookiePath,
MaxAge: maxAgeSec,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}
func clearCookie(c *gin.Context, name string, secure bool) {
http.SetCookie(c.Writer, &http.Cookie{
Name: name,
Value: "",
Path: linuxDoOAuthCookiePath,
MaxAge: -1,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}
func truncateFragmentValue(value string) string {
value = strings.TrimSpace(value)
if value == "" {
return ""
}
if len(value) > linuxDoOAuthMaxFragmentValueLen {
value = value[:linuxDoOAuthMaxFragmentValueLen]
for !utf8.ValidString(value) {
value = value[:len(value)-1]
}
}
return value
}
func buildBearerAuthorization(tokenType, accessToken string) (string, error) {
tokenType = strings.TrimSpace(tokenType)
if tokenType == "" {
tokenType = "Bearer"
}
if !strings.EqualFold(tokenType, "Bearer") {
return "", fmt.Errorf("unsupported token_type: %s", tokenType)
}
accessToken = strings.TrimSpace(accessToken)
if accessToken == "" {
return "", errors.New("missing access_token")
}
if strings.ContainsAny(accessToken, " \t\r\n") {
return "", errors.New("access_token contains whitespace")
}
return "Bearer " + accessToken, nil
}
func isSafeLinuxDoSubject(subject string) bool {
subject = strings.TrimSpace(subject)
if subject == "" || len(subject) > linuxDoOAuthMaxSubjectLen {
return false
}
for _, r := range subject {
switch {
case r >= '0' && r <= '9':
case r >= 'a' && r <= 'z':
case r >= 'A' && r <= 'Z':
case r == '_' || r == '-':
default:
return false
}
}
return true
}
func linuxDoSyntheticEmail(subject string) string {
subject = strings.TrimSpace(subject)
if subject == "" {
return ""
}
return "linuxdo-" + subject + service.LinuxDoConnectSyntheticEmailDomain
}
package handler
import (
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestSanitizeFrontendRedirectPath(t *testing.T) {
require.Equal(t, "/dashboard", sanitizeFrontendRedirectPath("/dashboard"))
require.Equal(t, "/dashboard", sanitizeFrontendRedirectPath(" /dashboard "))
require.Equal(t, "", sanitizeFrontendRedirectPath("dashboard"))
require.Equal(t, "", sanitizeFrontendRedirectPath("//evil.com"))
require.Equal(t, "", sanitizeFrontendRedirectPath("https://evil.com"))
require.Equal(t, "", sanitizeFrontendRedirectPath("/\nfoo"))
long := "/" + strings.Repeat("a", linuxDoOAuthMaxRedirectLen)
require.Equal(t, "", sanitizeFrontendRedirectPath(long))
}
func TestBuildBearerAuthorization(t *testing.T) {
auth, err := buildBearerAuthorization("", "token123")
require.NoError(t, err)
require.Equal(t, "Bearer token123", auth)
auth, err = buildBearerAuthorization("bearer", "token123")
require.NoError(t, err)
require.Equal(t, "Bearer token123", auth)
_, err = buildBearerAuthorization("MAC", "token123")
require.Error(t, err)
_, err = buildBearerAuthorization("Bearer", "token 123")
require.Error(t, err)
}
func TestLinuxDoParseUserInfoParsesIDAndUsername(t *testing.T) {
cfg := config.LinuxDoConnectConfig{
UserInfoURL: "https://connect.linux.do/api/user",
}
email, username, subject, err := linuxDoParseUserInfo(`{"id":123,"username":"alice"}`, cfg)
require.NoError(t, err)
require.Equal(t, "123", subject)
require.Equal(t, "alice", username)
require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email)
}
func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) {
cfg := config.LinuxDoConnectConfig{
UserInfoURL: "https://connect.linux.do/api/user",
}
email, username, subject, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg)
require.NoError(t, err)
require.Equal(t, "123", subject)
require.Equal(t, "linuxdo_123", username)
require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email)
}
func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) {
cfg := config.LinuxDoConnectConfig{
UserInfoURL: "https://connect.linux.do/api/user",
}
_, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg)
require.Error(t, err)
tooLong := strings.Repeat("a", linuxDoOAuthMaxSubjectLen+1)
_, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg)
require.Error(t, err)
}
func TestParseOAuthProviderErrorJSON(t *testing.T) {
code, desc := parseOAuthProviderError(`{"error":"invalid_client","error_description":"bad secret"}`)
require.Equal(t, "invalid_client", code)
require.Equal(t, "bad secret", desc)
}
func TestParseOAuthProviderErrorForm(t *testing.T) {
code, desc := parseOAuthProviderError("error=invalid_request&error_description=Missing+code_verifier")
require.Equal(t, "invalid_request", code)
require.Equal(t, "Missing code_verifier", desc)
}
func TestParseLinuxDoTokenResponseJSON(t *testing.T) {
token, ok := parseLinuxDoTokenResponse(`{"access_token":"t1","token_type":"Bearer","expires_in":3600,"scope":"user"}`)
require.True(t, ok)
require.Equal(t, "t1", token.AccessToken)
require.Equal(t, "Bearer", token.TokenType)
require.Equal(t, int64(3600), token.ExpiresIn)
require.Equal(t, "user", token.Scope)
}
func TestParseLinuxDoTokenResponseForm(t *testing.T) {
token, ok := parseLinuxDoTokenResponse("access_token=t2&token_type=bearer&expires_in=60")
require.True(t, ok)
require.Equal(t, "t2", token.AccessToken)
require.Equal(t, "bearer", token.TokenType)
require.Equal(t, int64(60), token.ExpiresIn)
}
func TestSingleLineStripsWhitespace(t *testing.T) {
require.Equal(t, "hello world", singleLine("hello\r\nworld"))
require.Equal(t, "", singleLine("\n\t\r"))
}
This diff is collapsed.
...@@ -17,6 +17,11 @@ type SystemSettings struct { ...@@ -17,6 +17,11 @@ type SystemSettings struct {
TurnstileSiteKey string `json:"turnstile_site_key"` TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKeyConfigured bool `json:"turnstile_secret_key_configured"` TurnstileSecretKeyConfigured bool `json:"turnstile_secret_key_configured"`
LinuxDoConnectEnabled bool `json:"linuxdo_connect_enabled"`
LinuxDoConnectClientID string `json:"linuxdo_connect_client_id"`
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
SiteName string `json:"site_name"` SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"` SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"` SiteSubtitle string `json:"site_subtitle"`
...@@ -56,5 +61,6 @@ type PublicSettings struct { ...@@ -56,5 +61,6 @@ type PublicSettings struct {
APIBaseURL string `json:"api_base_url"` APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"` ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"` DocURL string `json:"doc_url"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version"` Version string `json:"version"`
} }
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -42,6 +42,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { ...@@ -42,6 +42,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
APIBaseURL: settings.APIBaseURL, APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo, ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL, DocURL: settings.DocURL,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: h.version, Version: h.version,
}) })
} }
This diff is collapsed.
This diff is collapsed.
...@@ -27,10 +27,9 @@ const ( ...@@ -27,10 +27,9 @@ const (
// https://www.googleapis.com/auth/generative-language.retriever (often with cloud-platform). // https://www.googleapis.com/auth/generative-language.retriever (often with cloud-platform).
DefaultAIStudioScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever" DefaultAIStudioScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever"
// DefaultScopes for Google One (personal Google accounts with Gemini access) // DefaultGoogleOneScopes (DEPRECATED, no longer used)
// Only used when a custom OAuth client is configured. When using the built-in Gemini CLI client, // Google One now always uses the built-in Gemini CLI client with DefaultCodeAssistScopes.
// Google One uses DefaultCodeAssistScopes (same as code_assist) because the built-in client // This constant is kept for backward compatibility but is not actively used.
// cannot request restricted scopes like generative-language.retriever or drive.readonly.
DefaultGoogleOneScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile" DefaultGoogleOneScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile"
// GeminiCLIRedirectURI is the redirect URI used by Gemini CLI for Code Assist OAuth. // GeminiCLIRedirectURI is the redirect URI used by Gemini CLI for Code Assist OAuth.
......
...@@ -185,13 +185,9 @@ func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error ...@@ -185,13 +185,9 @@ func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error
effective.Scopes = DefaultAIStudioScopes effective.Scopes = DefaultAIStudioScopes
} }
case "google_one": case "google_one":
// Google One uses built-in Gemini CLI client (same as code_assist) // Google One always uses built-in Gemini CLI client (same as code_assist)
// Built-in client can't request restricted scopes like generative-language.retriever // Built-in client can't request restricted scopes like generative-language.retriever or drive.readonly
if isBuiltinClient { effective.Scopes = DefaultCodeAssistScopes
effective.Scopes = DefaultCodeAssistScopes
} else {
effective.Scopes = DefaultGoogleOneScopes
}
default: default:
// Default to Code Assist scopes // Default to Code Assist scopes
effective.Scopes = DefaultCodeAssistScopes effective.Scopes = DefaultCodeAssistScopes
......
...@@ -23,14 +23,14 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { ...@@ -23,14 +23,14 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
wantErr: false, wantErr: false,
}, },
{ {
name: "Google One with custom client", name: "Google One always uses built-in client (even if custom credentials passed)",
input: OAuthConfig{ input: OAuthConfig{
ClientID: "custom-client-id", ClientID: "custom-client-id",
ClientSecret: "custom-client-secret", ClientSecret: "custom-client-secret",
}, },
oauthType: "google_one", oauthType: "google_one",
wantClientID: "custom-client-id", wantClientID: "custom-client-id",
wantScopes: DefaultGoogleOneScopes, wantScopes: DefaultCodeAssistScopes, // Uses code assist scopes even with custom client
wantErr: false, wantErr: false,
}, },
{ {
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment