Commit 3a67002c authored by IanShaw027's avatar IanShaw027
Browse files

merge: 合并主分支改动并保留 ops 监控实现

合并 main 分支的最新改动到 ops 监控分支。
冲突解决策略:保留当前分支的 ops 相关改动,接受主分支的其他改动。

保留的 ops 改动:
- 运维监控配置和依赖注入
- 运维监控 API 处理器和中间件
- 运维监控服务层和数据访问层
- 运维监控前端界面和状态管理

接受的主分支改动:
- Linux DO OAuth 集成
- 账号过期功能
- IP 地址限制功能
- 用量统计优化
- 其他 bug 修复和功能改进
parents c48dc097 7d1fe818
...@@ -2,6 +2,7 @@ package admin ...@@ -2,6 +2,7 @@ package admin
import ( import (
"strconv" "strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
...@@ -63,10 +64,17 @@ type UpdateBalanceRequest struct { ...@@ -63,10 +64,17 @@ type UpdateBalanceRequest struct {
func (h *UserHandler) List(c *gin.Context) { func (h *UserHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c) page, pageSize := response.ParsePagination(c)
search := c.Query("search")
// 标准化和验证 search 参数
search = strings.TrimSpace(search)
if len(search) > 100 {
search = search[:100]
}
filters := service.UserListFilters{ filters := service.UserListFilters{
Status: c.Query("status"), Status: c.Query("status"),
Role: c.Query("role"), Role: c.Query("role"),
Search: c.Query("search"), Search: search,
Attributes: parseAttributeFilters(c), Attributes: parseAttributeFilters(c),
} }
......
...@@ -30,6 +30,8 @@ type CreateAPIKeyRequest struct { ...@@ -30,6 +30,8 @@ type CreateAPIKeyRequest struct {
Name string `json:"name" binding:"required"` Name string `json:"name" binding:"required"`
GroupID *int64 `json:"group_id"` // nullable GroupID *int64 `json:"group_id"` // nullable
CustomKey *string `json:"custom_key"` // 可选的自定义key CustomKey *string `json:"custom_key"` // 可选的自定义key
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
} }
// UpdateAPIKeyRequest represents the update API key request payload // UpdateAPIKeyRequest represents the update API key request payload
...@@ -37,6 +39,8 @@ type UpdateAPIKeyRequest struct { ...@@ -37,6 +39,8 @@ type UpdateAPIKeyRequest struct {
Name string `json:"name"` Name string `json:"name"`
GroupID *int64 `json:"group_id"` GroupID *int64 `json:"group_id"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"` Status string `json:"status" binding:"omitempty,oneof=active inactive"`
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
} }
// List handles listing user's API keys with pagination // List handles listing user's API keys with pagination
...@@ -113,6 +117,8 @@ func (h *APIKeyHandler) Create(c *gin.Context) { ...@@ -113,6 +117,8 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
Name: req.Name, Name: req.Name,
GroupID: req.GroupID, GroupID: req.GroupID,
CustomKey: req.CustomKey, CustomKey: req.CustomKey,
IPWhitelist: req.IPWhitelist,
IPBlacklist: req.IPBlacklist,
} }
key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq) key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq)
if err != nil { if err != nil {
...@@ -144,7 +150,10 @@ func (h *APIKeyHandler) Update(c *gin.Context) { ...@@ -144,7 +150,10 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
return return
} }
svcReq := service.UpdateAPIKeyRequest{} svcReq := service.UpdateAPIKeyRequest{
IPWhitelist: req.IPWhitelist,
IPBlacklist: req.IPBlacklist,
}
if req.Name != "" { if req.Name != "" {
svcReq.Name = &req.Name svcReq.Name = &req.Name
} }
......
...@@ -15,14 +15,16 @@ type AuthHandler struct { ...@@ -15,14 +15,16 @@ type AuthHandler struct {
cfg *config.Config cfg *config.Config
authService *service.AuthService authService *service.AuthService
userService *service.UserService userService *service.UserService
settingSvc *service.SettingService
} }
// NewAuthHandler creates a new AuthHandler // NewAuthHandler creates a new AuthHandler
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService) *AuthHandler { func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService) *AuthHandler {
return &AuthHandler{ return &AuthHandler{
cfg: cfg, cfg: cfg,
authService: authService, authService: authService,
userService: userService, userService: userService,
settingSvc: settingService,
} }
} }
......
package handler
import (
"context"
"encoding/base64"
"errors"
"fmt"
"log"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"unicode/utf8"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/imroc/req/v3"
"github.com/tidwall/gjson"
)
const (
linuxDoOAuthCookiePath = "/api/v1/auth/oauth/linuxdo"
linuxDoOAuthStateCookieName = "linuxdo_oauth_state"
linuxDoOAuthVerifierCookie = "linuxdo_oauth_verifier"
linuxDoOAuthRedirectCookie = "linuxdo_oauth_redirect"
linuxDoOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes
linuxDoOAuthDefaultRedirectTo = "/dashboard"
linuxDoOAuthDefaultFrontendCB = "/auth/linuxdo/callback"
linuxDoOAuthMaxRedirectLen = 2048
linuxDoOAuthMaxFragmentValueLen = 512
linuxDoOAuthMaxSubjectLen = 64 - len("linuxdo-")
)
type linuxDoTokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
Scope string `json:"scope,omitempty"`
}
type linuxDoTokenExchangeError struct {
StatusCode int
ProviderError string
ProviderDescription string
Body string
}
func (e *linuxDoTokenExchangeError) Error() string {
if e == nil {
return ""
}
parts := []string{fmt.Sprintf("token exchange status=%d", e.StatusCode)}
if strings.TrimSpace(e.ProviderError) != "" {
parts = append(parts, "error="+strings.TrimSpace(e.ProviderError))
}
if strings.TrimSpace(e.ProviderDescription) != "" {
parts = append(parts, "error_description="+strings.TrimSpace(e.ProviderDescription))
}
return strings.Join(parts, " ")
}
// LinuxDoOAuthStart 启动 LinuxDo Connect OAuth 登录流程。
// GET /api/v1/auth/oauth/linuxdo/start?redirect=/dashboard
func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) {
cfg, err := h.getLinuxDoOAuthConfig(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
state, err := oauth.GenerateState()
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err))
return
}
redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect"))
if redirectTo == "" {
redirectTo = linuxDoOAuthDefaultRedirectTo
}
secureCookie := isRequestHTTPS(c)
setCookie(c, linuxDoOAuthStateCookieName, encodeCookieValue(state), linuxDoOAuthCookieMaxAgeSec, secureCookie)
setCookie(c, linuxDoOAuthRedirectCookie, encodeCookieValue(redirectTo), linuxDoOAuthCookieMaxAgeSec, secureCookie)
codeChallenge := ""
if cfg.UsePKCE {
verifier, err := oauth.GenerateCodeVerifier()
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err))
return
}
codeChallenge = oauth.GenerateCodeChallenge(verifier)
setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie)
}
redirectURI := strings.TrimSpace(cfg.RedirectURL)
if redirectURI == "" {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured"))
return
}
authURL, err := buildLinuxDoAuthorizeURL(cfg, state, codeChallenge, redirectURI)
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err))
return
}
c.Redirect(http.StatusFound, authURL)
}
// LinuxDoOAuthCallback 处理 OAuth 回调:创建/登录用户,然后重定向到前端。
// GET /api/v1/auth/oauth/linuxdo/callback?code=...&state=...
func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
cfg, cfgErr := h.getLinuxDoOAuthConfig(c.Request.Context())
if cfgErr != nil {
response.ErrorFrom(c, cfgErr)
return
}
frontendCallback := strings.TrimSpace(cfg.FrontendRedirectURL)
if frontendCallback == "" {
frontendCallback = linuxDoOAuthDefaultFrontendCB
}
if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description"))
return
}
code := strings.TrimSpace(c.Query("code"))
state := strings.TrimSpace(c.Query("state"))
if code == "" || state == "" {
redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "")
return
}
secureCookie := isRequestHTTPS(c)
defer func() {
clearCookie(c, linuxDoOAuthStateCookieName, secureCookie)
clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie)
clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie)
}()
expectedState, err := readCookieDecoded(c, linuxDoOAuthStateCookieName)
if err != nil || expectedState == "" || state != expectedState {
redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "")
return
}
redirectTo, _ := readCookieDecoded(c, linuxDoOAuthRedirectCookie)
redirectTo = sanitizeFrontendRedirectPath(redirectTo)
if redirectTo == "" {
redirectTo = linuxDoOAuthDefaultRedirectTo
}
codeVerifier := ""
if cfg.UsePKCE {
codeVerifier, _ = readCookieDecoded(c, linuxDoOAuthVerifierCookie)
if codeVerifier == "" {
redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
return
}
}
redirectURI := strings.TrimSpace(cfg.RedirectURL)
if redirectURI == "" {
redirectOAuthError(c, frontendCallback, "config_error", "oauth redirect url not configured", "")
return
}
tokenResp, err := linuxDoExchangeCode(c.Request.Context(), cfg, code, redirectURI, codeVerifier)
if err != nil {
description := ""
var exchangeErr *linuxDoTokenExchangeError
if errors.As(err, &exchangeErr) && exchangeErr != nil {
log.Printf(
"[LinuxDo OAuth] token exchange failed: status=%d provider_error=%q provider_description=%q body=%s",
exchangeErr.StatusCode,
exchangeErr.ProviderError,
exchangeErr.ProviderDescription,
truncateLogValue(exchangeErr.Body, 2048),
)
description = exchangeErr.Error()
} else {
log.Printf("[LinuxDo OAuth] token exchange failed: %v", err)
description = err.Error()
}
redirectOAuthError(c, frontendCallback, "token_exchange_failed", "failed to exchange oauth code", singleLine(description))
return
}
email, username, subject, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp)
if err != nil {
log.Printf("[LinuxDo OAuth] userinfo fetch failed: %v", err)
redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "")
return
}
// 安全考虑:不要把第三方返回的 email 直接映射到本地账号(可能与本地邮箱用户冲突导致账号被接管)。
// 统一使用基于 subject 的稳定合成邮箱来做账号绑定。
if subject != "" {
email = linuxDoSyntheticEmail(subject)
}
jwtToken, _, err := h.authService.LoginOrRegisterOAuth(c.Request.Context(), email, username)
if err != nil {
// 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
return
}
fragment := url.Values{}
fragment.Set("access_token", jwtToken)
fragment.Set("token_type", "Bearer")
fragment.Set("redirect", redirectTo)
redirectWithFragment(c, frontendCallback, fragment)
}
func (h *AuthHandler) getLinuxDoOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) {
if h != nil && h.settingSvc != nil {
return h.settingSvc.GetLinuxDoConnectOAuthConfig(ctx)
}
if h == nil || h.cfg == nil {
return config.LinuxDoConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded")
}
if !h.cfg.LinuxDo.Enabled {
return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
}
return h.cfg.LinuxDo, nil
}
func linuxDoExchangeCode(
ctx context.Context,
cfg config.LinuxDoConnectConfig,
code string,
redirectURI string,
codeVerifier string,
) (*linuxDoTokenResponse, error) {
client := req.C().SetTimeout(30 * time.Second)
form := url.Values{}
form.Set("grant_type", "authorization_code")
form.Set("client_id", cfg.ClientID)
form.Set("code", code)
form.Set("redirect_uri", redirectURI)
if cfg.UsePKCE {
form.Set("code_verifier", codeVerifier)
}
r := client.R().
SetContext(ctx).
SetHeader("Accept", "application/json")
switch strings.ToLower(strings.TrimSpace(cfg.TokenAuthMethod)) {
case "", "client_secret_post":
form.Set("client_secret", cfg.ClientSecret)
case "client_secret_basic":
r.SetBasicAuth(cfg.ClientID, cfg.ClientSecret)
case "none":
default:
return nil, fmt.Errorf("unsupported token_auth_method: %s", cfg.TokenAuthMethod)
}
resp, err := r.SetFormDataFromValues(form).Post(cfg.TokenURL)
if err != nil {
return nil, fmt.Errorf("request token: %w", err)
}
body := strings.TrimSpace(resp.String())
if !resp.IsSuccessState() {
providerErr, providerDesc := parseOAuthProviderError(body)
return nil, &linuxDoTokenExchangeError{
StatusCode: resp.StatusCode,
ProviderError: providerErr,
ProviderDescription: providerDesc,
Body: body,
}
}
tokenResp, ok := parseLinuxDoTokenResponse(body)
if !ok || strings.TrimSpace(tokenResp.AccessToken) == "" {
return nil, &linuxDoTokenExchangeError{
StatusCode: resp.StatusCode,
Body: body,
}
}
if strings.TrimSpace(tokenResp.TokenType) == "" {
tokenResp.TokenType = "Bearer"
}
return tokenResp, nil
}
func linuxDoFetchUserInfo(
ctx context.Context,
cfg config.LinuxDoConnectConfig,
token *linuxDoTokenResponse,
) (email string, username string, subject string, err error) {
client := req.C().SetTimeout(30 * time.Second)
authorization, err := buildBearerAuthorization(token.TokenType, token.AccessToken)
if err != nil {
return "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err)
}
resp, err := client.R().
SetContext(ctx).
SetHeader("Accept", "application/json").
SetHeader("Authorization", authorization).
Get(cfg.UserInfoURL)
if err != nil {
return "", "", "", fmt.Errorf("request userinfo: %w", err)
}
if !resp.IsSuccessState() {
return "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode)
}
return linuxDoParseUserInfo(resp.String(), cfg)
}
func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, err error) {
email = firstNonEmpty(
getGJSON(body, cfg.UserInfoEmailPath),
getGJSON(body, "email"),
getGJSON(body, "user.email"),
getGJSON(body, "data.email"),
getGJSON(body, "attributes.email"),
)
username = firstNonEmpty(
getGJSON(body, cfg.UserInfoUsernamePath),
getGJSON(body, "username"),
getGJSON(body, "preferred_username"),
getGJSON(body, "name"),
getGJSON(body, "user.username"),
getGJSON(body, "user.name"),
)
subject = firstNonEmpty(
getGJSON(body, cfg.UserInfoIDPath),
getGJSON(body, "sub"),
getGJSON(body, "id"),
getGJSON(body, "user_id"),
getGJSON(body, "uid"),
getGJSON(body, "user.id"),
)
subject = strings.TrimSpace(subject)
if subject == "" {
return "", "", "", errors.New("userinfo missing id field")
}
if !isSafeLinuxDoSubject(subject) {
return "", "", "", errors.New("userinfo returned invalid id field")
}
email = strings.TrimSpace(email)
if email == "" {
// LinuxDo Connect 的 userinfo 可能不提供 email。为兼容现有用户模型(email 必填且唯一),使用稳定的合成邮箱。
email = linuxDoSyntheticEmail(subject)
}
username = strings.TrimSpace(username)
if username == "" {
username = "linuxdo_" + subject
}
return email, username, subject, nil
}
func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, codeChallenge string, redirectURI string) (string, error) {
u, err := url.Parse(cfg.AuthorizeURL)
if err != nil {
return "", fmt.Errorf("parse authorize_url: %w", err)
}
q := u.Query()
q.Set("response_type", "code")
q.Set("client_id", cfg.ClientID)
q.Set("redirect_uri", redirectURI)
if strings.TrimSpace(cfg.Scopes) != "" {
q.Set("scope", cfg.Scopes)
}
q.Set("state", state)
if cfg.UsePKCE {
q.Set("code_challenge", codeChallenge)
q.Set("code_challenge_method", "S256")
}
u.RawQuery = q.Encode()
return u.String(), nil
}
func redirectOAuthError(c *gin.Context, frontendCallback string, code string, message string, description string) {
fragment := url.Values{}
fragment.Set("error", truncateFragmentValue(code))
if strings.TrimSpace(message) != "" {
fragment.Set("error_message", truncateFragmentValue(message))
}
if strings.TrimSpace(description) != "" {
fragment.Set("error_description", truncateFragmentValue(description))
}
redirectWithFragment(c, frontendCallback, fragment)
}
func redirectWithFragment(c *gin.Context, frontendCallback string, fragment url.Values) {
u, err := url.Parse(frontendCallback)
if err != nil {
// 兜底:尽力跳转到默认页面,避免卡死在回调页。
c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo)
return
}
if u.Scheme != "" && !strings.EqualFold(u.Scheme, "http") && !strings.EqualFold(u.Scheme, "https") {
c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo)
return
}
u.Fragment = fragment.Encode()
c.Header("Cache-Control", "no-store")
c.Header("Pragma", "no-cache")
c.Redirect(http.StatusFound, u.String())
}
func firstNonEmpty(values ...string) string {
for _, v := range values {
v = strings.TrimSpace(v)
if v != "" {
return v
}
}
return ""
}
func parseOAuthProviderError(body string) (providerErr string, providerDesc string) {
body = strings.TrimSpace(body)
if body == "" {
return "", ""
}
providerErr = firstNonEmpty(
getGJSON(body, "error"),
getGJSON(body, "code"),
getGJSON(body, "error.code"),
)
providerDesc = firstNonEmpty(
getGJSON(body, "error_description"),
getGJSON(body, "error.message"),
getGJSON(body, "message"),
getGJSON(body, "detail"),
)
if providerErr != "" || providerDesc != "" {
return providerErr, providerDesc
}
values, err := url.ParseQuery(body)
if err != nil {
return "", ""
}
providerErr = firstNonEmpty(values.Get("error"), values.Get("code"))
providerDesc = firstNonEmpty(values.Get("error_description"), values.Get("error_message"), values.Get("message"))
return providerErr, providerDesc
}
func parseLinuxDoTokenResponse(body string) (*linuxDoTokenResponse, bool) {
body = strings.TrimSpace(body)
if body == "" {
return nil, false
}
accessToken := strings.TrimSpace(getGJSON(body, "access_token"))
if accessToken != "" {
tokenType := strings.TrimSpace(getGJSON(body, "token_type"))
refreshToken := strings.TrimSpace(getGJSON(body, "refresh_token"))
scope := strings.TrimSpace(getGJSON(body, "scope"))
expiresIn := gjson.Get(body, "expires_in").Int()
return &linuxDoTokenResponse{
AccessToken: accessToken,
TokenType: tokenType,
ExpiresIn: expiresIn,
RefreshToken: refreshToken,
Scope: scope,
}, true
}
values, err := url.ParseQuery(body)
if err != nil {
return nil, false
}
accessToken = strings.TrimSpace(values.Get("access_token"))
if accessToken == "" {
return nil, false
}
expiresIn := int64(0)
if raw := strings.TrimSpace(values.Get("expires_in")); raw != "" {
if v, err := strconv.ParseInt(raw, 10, 64); err == nil {
expiresIn = v
}
}
return &linuxDoTokenResponse{
AccessToken: accessToken,
TokenType: strings.TrimSpace(values.Get("token_type")),
ExpiresIn: expiresIn,
RefreshToken: strings.TrimSpace(values.Get("refresh_token")),
Scope: strings.TrimSpace(values.Get("scope")),
}, true
}
func getGJSON(body string, path string) string {
path = strings.TrimSpace(path)
if path == "" {
return ""
}
res := gjson.Get(body, path)
if !res.Exists() {
return ""
}
return res.String()
}
func truncateLogValue(value string, maxLen int) string {
value = strings.TrimSpace(value)
if value == "" || maxLen <= 0 {
return ""
}
if len(value) <= maxLen {
return value
}
value = value[:maxLen]
for !utf8.ValidString(value) {
value = value[:len(value)-1]
}
return value
}
func singleLine(value string) string {
value = strings.TrimSpace(value)
if value == "" {
return ""
}
return strings.Join(strings.Fields(value), " ")
}
func sanitizeFrontendRedirectPath(path string) string {
path = strings.TrimSpace(path)
if path == "" {
return ""
}
if len(path) > linuxDoOAuthMaxRedirectLen {
return ""
}
// 只允许同源相对路径(避免开放重定向)。
if !strings.HasPrefix(path, "/") {
return ""
}
if strings.HasPrefix(path, "//") {
return ""
}
if strings.Contains(path, "://") {
return ""
}
if strings.ContainsAny(path, "\r\n") {
return ""
}
return path
}
func isRequestHTTPS(c *gin.Context) bool {
if c.Request.TLS != nil {
return true
}
proto := strings.ToLower(strings.TrimSpace(c.GetHeader("X-Forwarded-Proto")))
return proto == "https"
}
func encodeCookieValue(value string) string {
return base64.RawURLEncoding.EncodeToString([]byte(value))
}
func decodeCookieValue(value string) (string, error) {
raw, err := base64.RawURLEncoding.DecodeString(value)
if err != nil {
return "", err
}
return string(raw), nil
}
func readCookieDecoded(c *gin.Context, name string) (string, error) {
ck, err := c.Request.Cookie(name)
if err != nil {
return "", err
}
return decodeCookieValue(ck.Value)
}
func setCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) {
http.SetCookie(c.Writer, &http.Cookie{
Name: name,
Value: value,
Path: linuxDoOAuthCookiePath,
MaxAge: maxAgeSec,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}
func clearCookie(c *gin.Context, name string, secure bool) {
http.SetCookie(c.Writer, &http.Cookie{
Name: name,
Value: "",
Path: linuxDoOAuthCookiePath,
MaxAge: -1,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}
func truncateFragmentValue(value string) string {
value = strings.TrimSpace(value)
if value == "" {
return ""
}
if len(value) > linuxDoOAuthMaxFragmentValueLen {
value = value[:linuxDoOAuthMaxFragmentValueLen]
for !utf8.ValidString(value) {
value = value[:len(value)-1]
}
}
return value
}
func buildBearerAuthorization(tokenType, accessToken string) (string, error) {
tokenType = strings.TrimSpace(tokenType)
if tokenType == "" {
tokenType = "Bearer"
}
if !strings.EqualFold(tokenType, "Bearer") {
return "", fmt.Errorf("unsupported token_type: %s", tokenType)
}
accessToken = strings.TrimSpace(accessToken)
if accessToken == "" {
return "", errors.New("missing access_token")
}
if strings.ContainsAny(accessToken, " \t\r\n") {
return "", errors.New("access_token contains whitespace")
}
return "Bearer " + accessToken, nil
}
func isSafeLinuxDoSubject(subject string) bool {
subject = strings.TrimSpace(subject)
if subject == "" || len(subject) > linuxDoOAuthMaxSubjectLen {
return false
}
for _, r := range subject {
switch {
case r >= '0' && r <= '9':
case r >= 'a' && r <= 'z':
case r >= 'A' && r <= 'Z':
case r == '_' || r == '-':
default:
return false
}
}
return true
}
func linuxDoSyntheticEmail(subject string) string {
subject = strings.TrimSpace(subject)
if subject == "" {
return ""
}
return "linuxdo-" + subject + service.LinuxDoConnectSyntheticEmailDomain
}
package handler
import (
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestSanitizeFrontendRedirectPath(t *testing.T) {
require.Equal(t, "/dashboard", sanitizeFrontendRedirectPath("/dashboard"))
require.Equal(t, "/dashboard", sanitizeFrontendRedirectPath(" /dashboard "))
require.Equal(t, "", sanitizeFrontendRedirectPath("dashboard"))
require.Equal(t, "", sanitizeFrontendRedirectPath("//evil.com"))
require.Equal(t, "", sanitizeFrontendRedirectPath("https://evil.com"))
require.Equal(t, "", sanitizeFrontendRedirectPath("/\nfoo"))
long := "/" + strings.Repeat("a", linuxDoOAuthMaxRedirectLen)
require.Equal(t, "", sanitizeFrontendRedirectPath(long))
}
func TestBuildBearerAuthorization(t *testing.T) {
auth, err := buildBearerAuthorization("", "token123")
require.NoError(t, err)
require.Equal(t, "Bearer token123", auth)
auth, err = buildBearerAuthorization("bearer", "token123")
require.NoError(t, err)
require.Equal(t, "Bearer token123", auth)
_, err = buildBearerAuthorization("MAC", "token123")
require.Error(t, err)
_, err = buildBearerAuthorization("Bearer", "token 123")
require.Error(t, err)
}
func TestLinuxDoParseUserInfoParsesIDAndUsername(t *testing.T) {
cfg := config.LinuxDoConnectConfig{
UserInfoURL: "https://connect.linux.do/api/user",
}
email, username, subject, err := linuxDoParseUserInfo(`{"id":123,"username":"alice"}`, cfg)
require.NoError(t, err)
require.Equal(t, "123", subject)
require.Equal(t, "alice", username)
require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email)
}
func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) {
cfg := config.LinuxDoConnectConfig{
UserInfoURL: "https://connect.linux.do/api/user",
}
email, username, subject, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg)
require.NoError(t, err)
require.Equal(t, "123", subject)
require.Equal(t, "linuxdo_123", username)
require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email)
}
func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) {
cfg := config.LinuxDoConnectConfig{
UserInfoURL: "https://connect.linux.do/api/user",
}
_, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg)
require.Error(t, err)
tooLong := strings.Repeat("a", linuxDoOAuthMaxSubjectLen+1)
_, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg)
require.Error(t, err)
}
func TestParseOAuthProviderErrorJSON(t *testing.T) {
code, desc := parseOAuthProviderError(`{"error":"invalid_client","error_description":"bad secret"}`)
require.Equal(t, "invalid_client", code)
require.Equal(t, "bad secret", desc)
}
func TestParseOAuthProviderErrorForm(t *testing.T) {
code, desc := parseOAuthProviderError("error=invalid_request&error_description=Missing+code_verifier")
require.Equal(t, "invalid_request", code)
require.Equal(t, "Missing code_verifier", desc)
}
func TestParseLinuxDoTokenResponseJSON(t *testing.T) {
token, ok := parseLinuxDoTokenResponse(`{"access_token":"t1","token_type":"Bearer","expires_in":3600,"scope":"user"}`)
require.True(t, ok)
require.Equal(t, "t1", token.AccessToken)
require.Equal(t, "Bearer", token.TokenType)
require.Equal(t, int64(3600), token.ExpiresIn)
require.Equal(t, "user", token.Scope)
}
func TestParseLinuxDoTokenResponseForm(t *testing.T) {
token, ok := parseLinuxDoTokenResponse("access_token=t2&token_type=bearer&expires_in=60")
require.True(t, ok)
require.Equal(t, "t2", token.AccessToken)
require.Equal(t, "bearer", token.TokenType)
require.Equal(t, int64(60), token.ExpiresIn)
}
func TestSingleLineStripsWhitespace(t *testing.T) {
require.Equal(t, "hello world", singleLine("hello\r\nworld"))
require.Equal(t, "", singleLine("\n\t\r"))
}
// Package dto provides data transfer objects for HTTP handlers. // Package dto provides data transfer objects for HTTP handlers.
package dto package dto
import "github.com/Wei-Shaw/sub2api/internal/service" import (
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func UserFromServiceShallow(u *service.User) *User { func UserFromServiceShallow(u *service.User) *User {
if u == nil { if u == nil {
...@@ -55,6 +59,8 @@ func APIKeyFromService(k *service.APIKey) *APIKey { ...@@ -55,6 +59,8 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
Name: k.Name, Name: k.Name,
GroupID: k.GroupID, GroupID: k.GroupID,
Status: k.Status, Status: k.Status,
IPWhitelist: k.IPWhitelist,
IPBlacklist: k.IPBlacklist,
CreatedAt: k.CreatedAt, CreatedAt: k.CreatedAt,
UpdatedAt: k.UpdatedAt, UpdatedAt: k.UpdatedAt,
User: UserFromServiceShallow(k.User), User: UserFromServiceShallow(k.User),
...@@ -81,6 +87,8 @@ func GroupFromServiceShallow(g *service.Group) *Group { ...@@ -81,6 +87,8 @@ func GroupFromServiceShallow(g *service.Group) *Group {
ImagePrice1K: g.ImagePrice1K, ImagePrice1K: g.ImagePrice1K,
ImagePrice2K: g.ImagePrice2K, ImagePrice2K: g.ImagePrice2K,
ImagePrice4K: g.ImagePrice4K, ImagePrice4K: g.ImagePrice4K,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
CreatedAt: g.CreatedAt, CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt, UpdatedAt: g.UpdatedAt,
AccountCount: g.AccountCount, AccountCount: g.AccountCount,
...@@ -120,6 +128,8 @@ func AccountFromServiceShallow(a *service.Account) *Account { ...@@ -120,6 +128,8 @@ func AccountFromServiceShallow(a *service.Account) *Account {
Status: a.Status, Status: a.Status,
ErrorMessage: a.ErrorMessage, ErrorMessage: a.ErrorMessage,
LastUsedAt: a.LastUsedAt, LastUsedAt: a.LastUsedAt,
ExpiresAt: timeToUnixSeconds(a.ExpiresAt),
AutoPauseOnExpired: a.AutoPauseOnExpired,
CreatedAt: a.CreatedAt, CreatedAt: a.CreatedAt,
UpdatedAt: a.UpdatedAt, UpdatedAt: a.UpdatedAt,
Schedulable: a.Schedulable, Schedulable: a.Schedulable,
...@@ -157,6 +167,14 @@ func AccountFromService(a *service.Account) *Account { ...@@ -157,6 +167,14 @@ func AccountFromService(a *service.Account) *Account {
return out return out
} }
func timeToUnixSeconds(value *time.Time) *int64 {
if value == nil {
return nil
}
ts := value.Unix()
return &ts
}
func AccountGroupFromService(ag *service.AccountGroup) *AccountGroup { func AccountGroupFromService(ag *service.AccountGroup) *AccountGroup {
if ag == nil { if ag == nil {
return nil return nil
...@@ -220,11 +238,26 @@ func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode { ...@@ -220,11 +238,26 @@ func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode {
} }
} }
func UsageLogFromService(l *service.UsageLog) *UsageLog { // AccountSummaryFromService returns a minimal AccountSummary for usage log display.
// Only includes ID and Name - no sensitive fields like Credentials, Proxy, etc.
func AccountSummaryFromService(a *service.Account) *AccountSummary {
if a == nil {
return nil
}
return &AccountSummary{
ID: a.ID,
Name: a.Name,
}
}
// usageLogFromServiceBase is a helper that converts service UsageLog to DTO.
// The account parameter allows caller to control what Account info is included.
// The includeIPAddress parameter controls whether to include the IP address (admin-only).
func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, includeIPAddress bool) *UsageLog {
if l == nil { if l == nil {
return nil return nil
} }
return &UsageLog{ result := &UsageLog{
ID: l.ID, ID: l.ID,
UserID: l.UserID, UserID: l.UserID,
APIKeyID: l.APIKeyID, APIKeyID: l.APIKeyID,
...@@ -252,13 +285,34 @@ func UsageLogFromService(l *service.UsageLog) *UsageLog { ...@@ -252,13 +285,34 @@ func UsageLogFromService(l *service.UsageLog) *UsageLog {
FirstTokenMs: l.FirstTokenMs, FirstTokenMs: l.FirstTokenMs,
ImageCount: l.ImageCount, ImageCount: l.ImageCount,
ImageSize: l.ImageSize, ImageSize: l.ImageSize,
UserAgent: l.UserAgent,
CreatedAt: l.CreatedAt, CreatedAt: l.CreatedAt,
User: UserFromServiceShallow(l.User), User: UserFromServiceShallow(l.User),
APIKey: APIKeyFromService(l.APIKey), APIKey: APIKeyFromService(l.APIKey),
Account: AccountFromService(l.Account), Account: account,
Group: GroupFromServiceShallow(l.Group), Group: GroupFromServiceShallow(l.Group),
Subscription: UserSubscriptionFromService(l.Subscription), Subscription: UserSubscriptionFromService(l.Subscription),
} }
// IP 地址仅对管理员可见
if includeIPAddress {
result.IPAddress = l.IPAddress
}
return result
}
// UsageLogFromService converts a service UsageLog to DTO for regular users.
// It excludes Account details and IP address - users should not see these.
func UsageLogFromService(l *service.UsageLog) *UsageLog {
return usageLogFromServiceBase(l, nil, false)
}
// UsageLogFromServiceAdmin converts a service UsageLog to DTO for admin users.
// It includes minimal Account info (ID, Name only) and IP address.
func UsageLogFromServiceAdmin(l *service.UsageLog) *UsageLog {
if l == nil {
return nil
}
return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account), true)
} }
func SettingFromService(s *service.Setting) *Setting { func SettingFromService(s *service.Setting) *Setting {
......
...@@ -17,6 +17,11 @@ type SystemSettings struct { ...@@ -17,6 +17,11 @@ type SystemSettings struct {
TurnstileSiteKey string `json:"turnstile_site_key"` TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKeyConfigured bool `json:"turnstile_secret_key_configured"` TurnstileSecretKeyConfigured bool `json:"turnstile_secret_key_configured"`
LinuxDoConnectEnabled bool `json:"linuxdo_connect_enabled"`
LinuxDoConnectClientID string `json:"linuxdo_connect_client_id"`
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
SiteName string `json:"site_name"` SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"` SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"` SiteSubtitle string `json:"site_subtitle"`
...@@ -56,5 +61,6 @@ type PublicSettings struct { ...@@ -56,5 +61,6 @@ type PublicSettings struct {
APIBaseURL string `json:"api_base_url"` APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"` ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"` DocURL string `json:"doc_url"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version"` Version string `json:"version"`
} }
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -42,6 +42,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { ...@@ -42,6 +42,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
APIBaseURL: settings.APIBaseURL, APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo, ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL, DocURL: settings.DocURL,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: h.version, Version: h.version,
}) })
} }
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment