Commit 1a641392 authored by cyhhao's avatar cyhhao
Browse files

Merge up/main

parents 36b817d0 24d19a5f
......@@ -8,6 +8,7 @@ import (
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
......@@ -26,6 +27,7 @@ type UserAttributeValueQuery struct {
predicates []predicate.UserAttributeValue
withUser *UserQuery
withDefinition *UserAttributeDefinitionQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
......@@ -420,6 +422,9 @@ func (_q *UserAttributeValueQuery) sqlAll(ctx context.Context, hooks ...queryHoo
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
......@@ -505,6 +510,9 @@ func (_q *UserAttributeValueQuery) loadDefinition(ctx context.Context, query *Us
func (_q *UserAttributeValueQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
......@@ -573,6 +581,9 @@ func (_q *UserAttributeValueQuery) sqlQuery(ctx context.Context) *sql.Selector {
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
......@@ -590,6 +601,32 @@ func (_q *UserAttributeValueQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *UserAttributeValueQuery) ForUpdate(opts ...sql.LockOption) *UserAttributeValueQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *UserAttributeValueQuery) ForShare(opts ...sql.LockOption) *UserAttributeValueQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// UserAttributeValueGroupBy is the group-by builder for UserAttributeValue entities.
type UserAttributeValueGroupBy struct {
selector
......
......@@ -9,6 +9,7 @@ import (
"math"
"entgo.io/ent"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
......@@ -30,6 +31,7 @@ type UserSubscriptionQuery struct {
withGroup *GroupQuery
withAssignedByUser *UserQuery
withUsageLogs *UsageLogQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
......@@ -494,6 +496,9 @@ func (_q *UserSubscriptionQuery) sqlAll(ctx context.Context, hooks ...queryHook)
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
for i := range hooks {
hooks[i](ctx, _spec)
}
......@@ -657,6 +662,9 @@ func (_q *UserSubscriptionQuery) loadUsageLogs(ctx context.Context, query *Usage
func (_q *UserSubscriptionQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
if len(_q.modifiers) > 0 {
_spec.Modifiers = _q.modifiers
}
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
......@@ -728,6 +736,9 @@ func (_q *UserSubscriptionQuery) sqlQuery(ctx context.Context) *sql.Selector {
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, m := range _q.modifiers {
m(selector)
}
for _, p := range _q.predicates {
p(selector)
}
......@@ -745,6 +756,32 @@ func (_q *UserSubscriptionQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector
}
// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
// updated, deleted or "selected ... for update" by other sessions, until the transaction is
// either committed or rolled-back.
func (_q *UserSubscriptionQuery) ForUpdate(opts ...sql.LockOption) *UserSubscriptionQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForUpdate(opts...)
})
return _q
}
// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
// on any rows that are read. Other sessions can read the rows, but cannot modify them
// until your transaction commits.
func (_q *UserSubscriptionQuery) ForShare(opts ...sql.LockOption) *UserSubscriptionQuery {
if _q.driver.Dialect() == dialect.Postgres {
_q.Unique(false)
}
_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
s.ForShare(opts...)
})
return _q
}
// UserSubscriptionGroupBy is the group-by builder for UserSubscription entities.
type UserSubscriptionGroupBy struct {
selector
......
......@@ -6,6 +6,7 @@ import (
"encoding/hex"
"fmt"
"log"
"net/url"
"os"
"strings"
"time"
......@@ -43,6 +44,7 @@ type Config struct {
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
JWT JWTConfig `mapstructure:"jwt"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"`
......@@ -322,6 +324,30 @@ type TurnstileConfig struct {
Required bool `mapstructure:"required"`
}
// LinuxDoConnectConfig 用于 LinuxDo Connect OAuth 登录(终端用户 SSO)。
//
// 注意:这与上游账号的 OAuth(例如 OpenAI/Gemini 账号接入)不是一回事。
// 这里是用于登录 Sub2API 本身的用户体系。
type LinuxDoConnectConfig struct {
Enabled bool `mapstructure:"enabled"`
ClientID string `mapstructure:"client_id"`
ClientSecret string `mapstructure:"client_secret"`
AuthorizeURL string `mapstructure:"authorize_url"`
TokenURL string `mapstructure:"token_url"`
UserInfoURL string `mapstructure:"userinfo_url"`
Scopes string `mapstructure:"scopes"`
RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记)
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/linuxdo/callback)
TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none
UsePKCE bool `mapstructure:"use_pkce"`
// 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。
// 为空时,服务端会尝试一组常见字段名。
UserInfoEmailPath string `mapstructure:"userinfo_email_path"`
UserInfoIDPath string `mapstructure:"userinfo_id_path"`
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
}
type DefaultConfig struct {
AdminEmail string `mapstructure:"admin_email"`
AdminPassword string `mapstructure:"admin_password"`
......@@ -388,6 +414,18 @@ func Load() (*Config, error) {
cfg.Server.Mode = "debug"
}
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
cfg.LinuxDo.ClientID = strings.TrimSpace(cfg.LinuxDo.ClientID)
cfg.LinuxDo.ClientSecret = strings.TrimSpace(cfg.LinuxDo.ClientSecret)
cfg.LinuxDo.AuthorizeURL = strings.TrimSpace(cfg.LinuxDo.AuthorizeURL)
cfg.LinuxDo.TokenURL = strings.TrimSpace(cfg.LinuxDo.TokenURL)
cfg.LinuxDo.UserInfoURL = strings.TrimSpace(cfg.LinuxDo.UserInfoURL)
cfg.LinuxDo.Scopes = strings.TrimSpace(cfg.LinuxDo.Scopes)
cfg.LinuxDo.RedirectURL = strings.TrimSpace(cfg.LinuxDo.RedirectURL)
cfg.LinuxDo.FrontendRedirectURL = strings.TrimSpace(cfg.LinuxDo.FrontendRedirectURL)
cfg.LinuxDo.TokenAuthMethod = strings.ToLower(strings.TrimSpace(cfg.LinuxDo.TokenAuthMethod))
cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath)
cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath)
cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath)
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
......@@ -426,6 +464,81 @@ func Load() (*Config, error) {
return &cfg, nil
}
// ValidateAbsoluteHTTPURL 校验一个绝对 http(s) URL(禁止 fragment)。
func ValidateAbsoluteHTTPURL(raw string) error {
raw = strings.TrimSpace(raw)
if raw == "" {
return fmt.Errorf("empty url")
}
u, err := url.Parse(raw)
if err != nil {
return err
}
if !u.IsAbs() {
return fmt.Errorf("must be absolute")
}
if !isHTTPScheme(u.Scheme) {
return fmt.Errorf("unsupported scheme: %s", u.Scheme)
}
if strings.TrimSpace(u.Host) == "" {
return fmt.Errorf("missing host")
}
if u.Fragment != "" {
return fmt.Errorf("must not include fragment")
}
return nil
}
// ValidateFrontendRedirectURL 校验前端回调地址:
// - 允许同源相对路径(以 / 开头)
// - 或绝对 http(s) URL(禁止 fragment)
func ValidateFrontendRedirectURL(raw string) error {
raw = strings.TrimSpace(raw)
if raw == "" {
return fmt.Errorf("empty url")
}
if strings.ContainsAny(raw, "\r\n") {
return fmt.Errorf("contains invalid characters")
}
if strings.HasPrefix(raw, "/") {
if strings.HasPrefix(raw, "//") {
return fmt.Errorf("must not start with //")
}
return nil
}
u, err := url.Parse(raw)
if err != nil {
return err
}
if !u.IsAbs() {
return fmt.Errorf("must be absolute http(s) url or relative path")
}
if !isHTTPScheme(u.Scheme) {
return fmt.Errorf("unsupported scheme: %s", u.Scheme)
}
if strings.TrimSpace(u.Host) == "" {
return fmt.Errorf("missing host")
}
if u.Fragment != "" {
return fmt.Errorf("must not include fragment")
}
return nil
}
func isHTTPScheme(scheme string) bool {
return strings.EqualFold(scheme, "http") || strings.EqualFold(scheme, "https")
}
func warnIfInsecureURL(field, raw string) {
u, err := url.Parse(strings.TrimSpace(raw))
if err != nil {
return
}
if strings.EqualFold(u.Scheme, "http") {
log.Printf("Warning: %s uses http scheme; use https in production to avoid token leakage.", field)
}
}
func setDefaults() {
viper.SetDefault("run_mode", RunModeStandard)
......@@ -475,6 +588,22 @@ func setDefaults() {
// Turnstile
viper.SetDefault("turnstile.required", false)
// LinuxDo Connect OAuth 登录(终端用户 SSO)
viper.SetDefault("linuxdo_connect.enabled", false)
viper.SetDefault("linuxdo_connect.client_id", "")
viper.SetDefault("linuxdo_connect.client_secret", "")
viper.SetDefault("linuxdo_connect.authorize_url", "https://connect.linux.do/oauth2/authorize")
viper.SetDefault("linuxdo_connect.token_url", "https://connect.linux.do/oauth2/token")
viper.SetDefault("linuxdo_connect.userinfo_url", "https://connect.linux.do/api/user")
viper.SetDefault("linuxdo_connect.scopes", "user")
viper.SetDefault("linuxdo_connect.redirect_url", "")
viper.SetDefault("linuxdo_connect.frontend_redirect_url", "/auth/linuxdo/callback")
viper.SetDefault("linuxdo_connect.token_auth_method", "client_secret_post")
viper.SetDefault("linuxdo_connect.use_pkce", false)
viper.SetDefault("linuxdo_connect.userinfo_email_path", "")
viper.SetDefault("linuxdo_connect.userinfo_id_path", "")
viper.SetDefault("linuxdo_connect.userinfo_username_path", "")
// Database
viper.SetDefault("database.host", "localhost")
viper.SetDefault("database.port", 5432)
......@@ -544,7 +673,7 @@ func setDefaults() {
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
viper.SetDefault("gateway.stream_keepalive_interval", 10)
viper.SetDefault("gateway.max_line_size", 10*1024*1024)
viper.SetDefault("gateway.max_line_size", 40*1024*1024)
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second)
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
......@@ -586,6 +715,60 @@ func (c *Config) Validate() error {
if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" {
return fmt.Errorf("security.csp.policy is required when CSP is enabled")
}
if c.LinuxDo.Enabled {
if strings.TrimSpace(c.LinuxDo.ClientID) == "" {
return fmt.Errorf("linuxdo_connect.client_id is required when linuxdo_connect.enabled=true")
}
if strings.TrimSpace(c.LinuxDo.AuthorizeURL) == "" {
return fmt.Errorf("linuxdo_connect.authorize_url is required when linuxdo_connect.enabled=true")
}
if strings.TrimSpace(c.LinuxDo.TokenURL) == "" {
return fmt.Errorf("linuxdo_connect.token_url is required when linuxdo_connect.enabled=true")
}
if strings.TrimSpace(c.LinuxDo.UserInfoURL) == "" {
return fmt.Errorf("linuxdo_connect.userinfo_url is required when linuxdo_connect.enabled=true")
}
if strings.TrimSpace(c.LinuxDo.RedirectURL) == "" {
return fmt.Errorf("linuxdo_connect.redirect_url is required when linuxdo_connect.enabled=true")
}
method := strings.ToLower(strings.TrimSpace(c.LinuxDo.TokenAuthMethod))
switch method {
case "", "client_secret_post", "client_secret_basic", "none":
default:
return fmt.Errorf("linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none")
}
if method == "none" && !c.LinuxDo.UsePKCE {
return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none")
}
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && strings.TrimSpace(c.LinuxDo.ClientSecret) == "" {
return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic")
}
if strings.TrimSpace(c.LinuxDo.FrontendRedirectURL) == "" {
return fmt.Errorf("linuxdo_connect.frontend_redirect_url is required when linuxdo_connect.enabled=true")
}
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.AuthorizeURL); err != nil {
return fmt.Errorf("linuxdo_connect.authorize_url invalid: %w", err)
}
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.TokenURL); err != nil {
return fmt.Errorf("linuxdo_connect.token_url invalid: %w", err)
}
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.UserInfoURL); err != nil {
return fmt.Errorf("linuxdo_connect.userinfo_url invalid: %w", err)
}
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.RedirectURL); err != nil {
return fmt.Errorf("linuxdo_connect.redirect_url invalid: %w", err)
}
if err := ValidateFrontendRedirectURL(c.LinuxDo.FrontendRedirectURL); err != nil {
return fmt.Errorf("linuxdo_connect.frontend_redirect_url invalid: %w", err)
}
warnIfInsecureURL("linuxdo_connect.authorize_url", c.LinuxDo.AuthorizeURL)
warnIfInsecureURL("linuxdo_connect.token_url", c.LinuxDo.TokenURL)
warnIfInsecureURL("linuxdo_connect.userinfo_url", c.LinuxDo.UserInfoURL)
warnIfInsecureURL("linuxdo_connect.redirect_url", c.LinuxDo.RedirectURL)
warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL)
}
if c.Billing.CircuitBreaker.Enabled {
if c.Billing.CircuitBreaker.FailureThreshold <= 0 {
return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive")
......
package config
import (
"strings"
"testing"
"time"
......@@ -90,3 +91,53 @@ func TestLoadDefaultSecurityToggles(t *testing.T) {
t.Fatalf("ResponseHeaders.Enabled = true, want false")
}
}
func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.LinuxDo.Enabled = true
cfg.LinuxDo.ClientID = "test-client"
cfg.LinuxDo.ClientSecret = "test-secret"
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
cfg.LinuxDo.TokenAuthMethod = "client_secret_post"
cfg.LinuxDo.UsePKCE = false
cfg.LinuxDo.FrontendRedirectURL = "javascript:alert(1)"
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error for javascript scheme, got nil")
}
if !strings.Contains(err.Error(), "linuxdo_connect.frontend_redirect_url") {
t.Fatalf("Validate() expected frontend_redirect_url error, got: %v", err)
}
}
func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
cfg.LinuxDo.Enabled = true
cfg.LinuxDo.ClientID = "test-client"
cfg.LinuxDo.ClientSecret = ""
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
cfg.LinuxDo.FrontendRedirectURL = "/auth/linuxdo/callback"
cfg.LinuxDo.TokenAuthMethod = "none"
cfg.LinuxDo.UsePKCE = false
err = cfg.Validate()
if err == nil {
t.Fatalf("Validate() expected error when token_auth_method=none and use_pkce=false, got nil")
}
if !strings.Contains(err.Error(), "linuxdo_connect.use_pkce") {
t.Fatalf("Validate() expected use_pkce error, got: %v", err)
}
}
......@@ -116,6 +116,7 @@ type BulkUpdateAccountsRequest struct {
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
Schedulable *bool `json:"schedulable"`
GroupIDs *[]int64 `json:"group_ids"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
......@@ -136,6 +137,11 @@ func (h *AccountHandler) List(c *gin.Context) {
accountType := c.Query("type")
status := c.Query("status")
search := c.Query("search")
// 标准化和验证 search 参数
search = strings.TrimSpace(search)
if len(search) > 100 {
search = search[:100]
}
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search)
if err != nil {
......@@ -655,6 +661,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
req.Concurrency != nil ||
req.Priority != nil ||
req.Status != "" ||
req.Schedulable != nil ||
req.GroupIDs != nil ||
len(req.Credentials) > 0 ||
len(req.Extra) > 0
......@@ -671,6 +678,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
Concurrency: req.Concurrency,
Priority: req.Priority,
Status: req.Status,
Schedulable: req.Schedulable,
GroupIDs: req.GroupIDs,
Credentials: req.Credentials,
Extra: req.Extra,
......
......@@ -2,6 +2,7 @@ package admin
import (
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
......@@ -67,6 +68,12 @@ func (h *GroupHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
platform := c.Query("platform")
status := c.Query("status")
search := c.Query("search")
// 标准化和验证 search 参数
search = strings.TrimSpace(search)
if len(search) > 100 {
search = search[:100]
}
isExclusiveStr := c.Query("is_exclusive")
var isExclusive *bool
......@@ -75,7 +82,7 @@ func (h *GroupHandler) List(c *gin.Context) {
isExclusive = &val
}
groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, isExclusive)
groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, search, isExclusive)
if err != nil {
response.ErrorFrom(c, err)
return
......
package admin
import (
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// PromoHandler handles admin promo code management
type PromoHandler struct {
promoService *service.PromoService
}
// NewPromoHandler creates a new admin promo handler
func NewPromoHandler(promoService *service.PromoService) *PromoHandler {
return &PromoHandler{
promoService: promoService,
}
}
// CreatePromoCodeRequest represents create promo code request
type CreatePromoCodeRequest struct {
Code string `json:"code"` // 可选,为空则自动生成
BonusAmount float64 `json:"bonus_amount" binding:"required,min=0"` // 赠送余额
MaxUses int `json:"max_uses" binding:"min=0"` // 最大使用次数,0=无限
ExpiresAt *int64 `json:"expires_at"` // 过期时间戳(秒)
Notes string `json:"notes"` // 备注
}
// UpdatePromoCodeRequest represents update promo code request
type UpdatePromoCodeRequest struct {
Code *string `json:"code"`
BonusAmount *float64 `json:"bonus_amount" binding:"omitempty,min=0"`
MaxUses *int `json:"max_uses" binding:"omitempty,min=0"`
Status *string `json:"status" binding:"omitempty,oneof=active disabled"`
ExpiresAt *int64 `json:"expires_at"`
Notes *string `json:"notes"`
}
// List handles listing all promo codes with pagination
// GET /api/v1/admin/promo-codes
func (h *PromoHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
status := c.Query("status")
search := strings.TrimSpace(c.Query("search"))
if len(search) > 100 {
search = search[:100]
}
params := pagination.PaginationParams{
Page: page,
PageSize: pageSize,
}
codes, paginationResult, err := h.promoService.List(c.Request.Context(), params, status, search)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.PromoCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.PromoCodeFromService(&codes[i]))
}
response.Paginated(c, out, paginationResult.Total, page, pageSize)
}
// GetByID handles getting a promo code by ID
// GET /api/v1/admin/promo-codes/:id
func (h *PromoHandler) GetByID(c *gin.Context) {
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid promo code ID")
return
}
code, err := h.promoService.GetByID(c.Request.Context(), codeID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.PromoCodeFromService(code))
}
// Create handles creating a new promo code
// POST /api/v1/admin/promo-codes
func (h *PromoHandler) Create(c *gin.Context) {
var req CreatePromoCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
input := &service.CreatePromoCodeInput{
Code: req.Code,
BonusAmount: req.BonusAmount,
MaxUses: req.MaxUses,
Notes: req.Notes,
}
if req.ExpiresAt != nil {
t := time.Unix(*req.ExpiresAt, 0)
input.ExpiresAt = &t
}
code, err := h.promoService.Create(c.Request.Context(), input)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.PromoCodeFromService(code))
}
// Update handles updating a promo code
// PUT /api/v1/admin/promo-codes/:id
func (h *PromoHandler) Update(c *gin.Context) {
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid promo code ID")
return
}
var req UpdatePromoCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
input := &service.UpdatePromoCodeInput{
Code: req.Code,
BonusAmount: req.BonusAmount,
MaxUses: req.MaxUses,
Status: req.Status,
Notes: req.Notes,
}
if req.ExpiresAt != nil {
if *req.ExpiresAt == 0 {
// 0 表示清除过期时间
input.ExpiresAt = nil
} else {
t := time.Unix(*req.ExpiresAt, 0)
input.ExpiresAt = &t
}
}
code, err := h.promoService.Update(c.Request.Context(), codeID, input)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.PromoCodeFromService(code))
}
// Delete handles deleting a promo code
// DELETE /api/v1/admin/promo-codes/:id
func (h *PromoHandler) Delete(c *gin.Context) {
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid promo code ID")
return
}
err = h.promoService.Delete(c.Request.Context(), codeID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Promo code deleted successfully"})
}
// GetUsages handles getting usage records for a promo code
// GET /api/v1/admin/promo-codes/:id/usages
func (h *PromoHandler) GetUsages(c *gin.Context) {
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid promo code ID")
return
}
page, pageSize := response.ParsePagination(c)
params := pagination.PaginationParams{
Page: page,
PageSize: pageSize,
}
usages, paginationResult, err := h.promoService.ListUsages(c.Request.Context(), codeID, params)
if err != nil {
response.ErrorFrom(c, err)
return
}
out := make([]dto.PromoCodeUsage, 0, len(usages))
for i := range usages {
out = append(out, *dto.PromoCodeUsageFromService(&usages[i]))
}
response.Paginated(c, out, paginationResult.Total, page, pageSize)
}
......@@ -51,6 +51,11 @@ func (h *ProxyHandler) List(c *gin.Context) {
protocol := c.Query("protocol")
status := c.Query("status")
search := c.Query("search")
// 标准化和验证 search 参数
search = strings.TrimSpace(search)
if len(search) > 100 {
search = search[:100]
}
proxies, total, err := h.adminService.ListProxiesWithAccountCount(c.Request.Context(), page, pageSize, protocol, status, search)
if err != nil {
......
......@@ -5,6 +5,7 @@ import (
"encoding/csv"
"fmt"
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
......@@ -41,6 +42,11 @@ func (h *RedeemHandler) List(c *gin.Context) {
codeType := c.Query("type")
status := c.Query("status")
search := c.Query("search")
// 标准化和验证 search 参数
search = strings.TrimSpace(search)
if len(search) > 100 {
search = search[:100]
}
codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search)
if err != nil {
......
......@@ -2,8 +2,10 @@ package admin
import (
"log"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
......@@ -50,6 +52,10 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
LinuxDoConnectEnabled: settings.LinuxDoConnectEnabled,
LinuxDoConnectClientID: settings.LinuxDoConnectClientID,
LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured,
LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
......@@ -88,6 +94,12 @@ type UpdateSettingsRequest struct {
TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKey string `json:"turnstile_secret_key"`
// LinuxDo Connect OAuth 登录(终端用户 SSO)
LinuxDoConnectEnabled bool `json:"linuxdo_connect_enabled"`
LinuxDoConnectClientID string `json:"linuxdo_connect_client_id"`
LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
// OEM设置
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
......@@ -165,6 +177,35 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
// LinuxDo Connect 参数验证
if req.LinuxDoConnectEnabled {
req.LinuxDoConnectClientID = strings.TrimSpace(req.LinuxDoConnectClientID)
req.LinuxDoConnectClientSecret = strings.TrimSpace(req.LinuxDoConnectClientSecret)
req.LinuxDoConnectRedirectURL = strings.TrimSpace(req.LinuxDoConnectRedirectURL)
if req.LinuxDoConnectClientID == "" {
response.BadRequest(c, "LinuxDo Client ID is required when enabled")
return
}
if req.LinuxDoConnectRedirectURL == "" {
response.BadRequest(c, "LinuxDo Redirect URL is required when enabled")
return
}
if err := config.ValidateAbsoluteHTTPURL(req.LinuxDoConnectRedirectURL); err != nil {
response.BadRequest(c, "LinuxDo Redirect URL must be an absolute http(s) URL")
return
}
// 如果未提供 client_secret,则保留现有值(如有)。
if req.LinuxDoConnectClientSecret == "" {
if previousSettings.LinuxDoConnectClientSecret == "" {
response.BadRequest(c, "LinuxDo Client Secret is required when enabled")
return
}
req.LinuxDoConnectClientSecret = previousSettings.LinuxDoConnectClientSecret
}
}
settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
......@@ -178,6 +219,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
TurnstileEnabled: req.TurnstileEnabled,
TurnstileSiteKey: req.TurnstileSiteKey,
TurnstileSecretKey: req.TurnstileSecretKey,
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
SiteName: req.SiteName,
SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle,
......@@ -222,6 +267,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
TurnstileEnabled: updatedSettings.TurnstileEnabled,
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
LinuxDoConnectEnabled: updatedSettings.LinuxDoConnectEnabled,
LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID,
LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured,
LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL,
SiteName: updatedSettings.SiteName,
SiteLogo: updatedSettings.SiteLogo,
SiteSubtitle: updatedSettings.SiteSubtitle,
......@@ -298,6 +347,18 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if req.TurnstileSecretKey != "" {
changed = append(changed, "turnstile_secret_key")
}
if before.LinuxDoConnectEnabled != after.LinuxDoConnectEnabled {
changed = append(changed, "linuxdo_connect_enabled")
}
if before.LinuxDoConnectClientID != after.LinuxDoConnectClientID {
changed = append(changed, "linuxdo_connect_client_id")
}
if req.LinuxDoConnectClientSecret != "" {
changed = append(changed, "linuxdo_connect_client_secret")
}
if before.LinuxDoConnectRedirectURL != after.LinuxDoConnectRedirectURL {
changed = append(changed, "linuxdo_connect_redirect_url")
}
if before.SiteName != after.SiteName {
changed = append(changed, "site_name")
}
......@@ -337,6 +398,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.FallbackModelAntigravity != after.FallbackModelAntigravity {
changed = append(changed, "fallback_model_antigravity")
}
if before.EnableIdentityPatch != after.EnableIdentityPatch {
changed = append(changed, "enable_identity_patch")
}
if before.IdentityPatchPrompt != after.IdentityPatchPrompt {
changed = append(changed, "identity_patch_prompt")
}
return changed
}
......
......@@ -2,6 +2,7 @@ package admin
import (
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
......@@ -63,10 +64,17 @@ type UpdateBalanceRequest struct {
func (h *UserHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
search := c.Query("search")
// 标准化和验证 search 参数
search = strings.TrimSpace(search)
if len(search) > 100 {
search = search[:100]
}
filters := service.UserListFilters{
Status: c.Query("status"),
Role: c.Query("role"),
Search: c.Query("search"),
Search: search,
Attributes: parseAttributeFilters(c),
}
......
......@@ -30,6 +30,8 @@ type CreateAPIKeyRequest struct {
Name string `json:"name" binding:"required"`
GroupID *int64 `json:"group_id"` // nullable
CustomKey *string `json:"custom_key"` // 可选的自定义key
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
}
// UpdateAPIKeyRequest represents the update API key request payload
......@@ -37,6 +39,8 @@ type UpdateAPIKeyRequest struct {
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
}
// List handles listing user's API keys with pagination
......@@ -113,6 +117,8 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
Name: req.Name,
GroupID: req.GroupID,
CustomKey: req.CustomKey,
IPWhitelist: req.IPWhitelist,
IPBlacklist: req.IPBlacklist,
}
key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq)
if err != nil {
......@@ -144,7 +150,10 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
return
}
svcReq := service.UpdateAPIKeyRequest{}
svcReq := service.UpdateAPIKeyRequest{
IPWhitelist: req.IPWhitelist,
IPBlacklist: req.IPBlacklist,
}
if req.Name != "" {
svcReq.Name = &req.Name
}
......
......@@ -15,14 +15,18 @@ type AuthHandler struct {
cfg *config.Config
authService *service.AuthService
userService *service.UserService
settingSvc *service.SettingService
promoService *service.PromoService
}
// NewAuthHandler creates a new AuthHandler
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService) *AuthHandler {
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService) *AuthHandler {
return &AuthHandler{
cfg: cfg,
authService: authService,
userService: userService,
settingSvc: settingService,
promoService: promoService,
}
}
......@@ -32,6 +36,7 @@ type RegisterRequest struct {
Password string `json:"password" binding:"required,min=6"`
VerifyCode string `json:"verify_code"`
TurnstileToken string `json:"turnstile_token"`
PromoCode string `json:"promo_code"` // 注册优惠码
}
// SendVerifyCodeRequest 发送验证码请求
......@@ -77,7 +82,7 @@ func (h *AuthHandler) Register(c *gin.Context) {
}
}
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode)
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode)
if err != nil {
response.ErrorFrom(c, err)
return
......@@ -172,3 +177,63 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
response.Success(c, UserResponse{User: dto.UserFromService(user), RunMode: runMode})
}
// ValidatePromoCodeRequest 验证优惠码请求
type ValidatePromoCodeRequest struct {
Code string `json:"code" binding:"required"`
}
// ValidatePromoCodeResponse 验证优惠码响应
type ValidatePromoCodeResponse struct {
Valid bool `json:"valid"`
BonusAmount float64 `json:"bonus_amount,omitempty"`
ErrorCode string `json:"error_code,omitempty"`
Message string `json:"message,omitempty"`
}
// ValidatePromoCode 验证优惠码(公开接口,注册前调用)
// POST /api/v1/auth/validate-promo-code
func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
var req ValidatePromoCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
promoCode, err := h.promoService.ValidatePromoCode(c.Request.Context(), req.Code)
if err != nil {
// 根据错误类型返回对应的错误码
errorCode := "PROMO_CODE_INVALID"
switch err {
case service.ErrPromoCodeNotFound:
errorCode = "PROMO_CODE_NOT_FOUND"
case service.ErrPromoCodeExpired:
errorCode = "PROMO_CODE_EXPIRED"
case service.ErrPromoCodeDisabled:
errorCode = "PROMO_CODE_DISABLED"
case service.ErrPromoCodeMaxUsed:
errorCode = "PROMO_CODE_MAX_USED"
case service.ErrPromoCodeAlreadyUsed:
errorCode = "PROMO_CODE_ALREADY_USED"
}
response.Success(c, ValidatePromoCodeResponse{
Valid: false,
ErrorCode: errorCode,
})
return
}
if promoCode == nil {
response.Success(c, ValidatePromoCodeResponse{
Valid: false,
ErrorCode: "PROMO_CODE_INVALID",
})
return
}
response.Success(c, ValidatePromoCodeResponse{
Valid: true,
BonusAmount: promoCode.BonusAmount,
})
}
package handler
import (
"context"
"encoding/base64"
"errors"
"fmt"
"log"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"unicode/utf8"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/imroc/req/v3"
"github.com/tidwall/gjson"
)
const (
linuxDoOAuthCookiePath = "/api/v1/auth/oauth/linuxdo"
linuxDoOAuthStateCookieName = "linuxdo_oauth_state"
linuxDoOAuthVerifierCookie = "linuxdo_oauth_verifier"
linuxDoOAuthRedirectCookie = "linuxdo_oauth_redirect"
linuxDoOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes
linuxDoOAuthDefaultRedirectTo = "/dashboard"
linuxDoOAuthDefaultFrontendCB = "/auth/linuxdo/callback"
linuxDoOAuthMaxRedirectLen = 2048
linuxDoOAuthMaxFragmentValueLen = 512
linuxDoOAuthMaxSubjectLen = 64 - len("linuxdo-")
)
type linuxDoTokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
Scope string `json:"scope,omitempty"`
}
type linuxDoTokenExchangeError struct {
StatusCode int
ProviderError string
ProviderDescription string
Body string
}
func (e *linuxDoTokenExchangeError) Error() string {
if e == nil {
return ""
}
parts := []string{fmt.Sprintf("token exchange status=%d", e.StatusCode)}
if strings.TrimSpace(e.ProviderError) != "" {
parts = append(parts, "error="+strings.TrimSpace(e.ProviderError))
}
if strings.TrimSpace(e.ProviderDescription) != "" {
parts = append(parts, "error_description="+strings.TrimSpace(e.ProviderDescription))
}
return strings.Join(parts, " ")
}
// LinuxDoOAuthStart 启动 LinuxDo Connect OAuth 登录流程。
// GET /api/v1/auth/oauth/linuxdo/start?redirect=/dashboard
func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) {
cfg, err := h.getLinuxDoOAuthConfig(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
state, err := oauth.GenerateState()
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err))
return
}
redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect"))
if redirectTo == "" {
redirectTo = linuxDoOAuthDefaultRedirectTo
}
secureCookie := isRequestHTTPS(c)
setCookie(c, linuxDoOAuthStateCookieName, encodeCookieValue(state), linuxDoOAuthCookieMaxAgeSec, secureCookie)
setCookie(c, linuxDoOAuthRedirectCookie, encodeCookieValue(redirectTo), linuxDoOAuthCookieMaxAgeSec, secureCookie)
codeChallenge := ""
if cfg.UsePKCE {
verifier, err := oauth.GenerateCodeVerifier()
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err))
return
}
codeChallenge = oauth.GenerateCodeChallenge(verifier)
setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie)
}
redirectURI := strings.TrimSpace(cfg.RedirectURL)
if redirectURI == "" {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured"))
return
}
authURL, err := buildLinuxDoAuthorizeURL(cfg, state, codeChallenge, redirectURI)
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err))
return
}
c.Redirect(http.StatusFound, authURL)
}
// LinuxDoOAuthCallback 处理 OAuth 回调:创建/登录用户,然后重定向到前端。
// GET /api/v1/auth/oauth/linuxdo/callback?code=...&state=...
func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
cfg, cfgErr := h.getLinuxDoOAuthConfig(c.Request.Context())
if cfgErr != nil {
response.ErrorFrom(c, cfgErr)
return
}
frontendCallback := strings.TrimSpace(cfg.FrontendRedirectURL)
if frontendCallback == "" {
frontendCallback = linuxDoOAuthDefaultFrontendCB
}
if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description"))
return
}
code := strings.TrimSpace(c.Query("code"))
state := strings.TrimSpace(c.Query("state"))
if code == "" || state == "" {
redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "")
return
}
secureCookie := isRequestHTTPS(c)
defer func() {
clearCookie(c, linuxDoOAuthStateCookieName, secureCookie)
clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie)
clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie)
}()
expectedState, err := readCookieDecoded(c, linuxDoOAuthStateCookieName)
if err != nil || expectedState == "" || state != expectedState {
redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "")
return
}
redirectTo, _ := readCookieDecoded(c, linuxDoOAuthRedirectCookie)
redirectTo = sanitizeFrontendRedirectPath(redirectTo)
if redirectTo == "" {
redirectTo = linuxDoOAuthDefaultRedirectTo
}
codeVerifier := ""
if cfg.UsePKCE {
codeVerifier, _ = readCookieDecoded(c, linuxDoOAuthVerifierCookie)
if codeVerifier == "" {
redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
return
}
}
redirectURI := strings.TrimSpace(cfg.RedirectURL)
if redirectURI == "" {
redirectOAuthError(c, frontendCallback, "config_error", "oauth redirect url not configured", "")
return
}
tokenResp, err := linuxDoExchangeCode(c.Request.Context(), cfg, code, redirectURI, codeVerifier)
if err != nil {
description := ""
var exchangeErr *linuxDoTokenExchangeError
if errors.As(err, &exchangeErr) && exchangeErr != nil {
log.Printf(
"[LinuxDo OAuth] token exchange failed: status=%d provider_error=%q provider_description=%q body=%s",
exchangeErr.StatusCode,
exchangeErr.ProviderError,
exchangeErr.ProviderDescription,
truncateLogValue(exchangeErr.Body, 2048),
)
description = exchangeErr.Error()
} else {
log.Printf("[LinuxDo OAuth] token exchange failed: %v", err)
description = err.Error()
}
redirectOAuthError(c, frontendCallback, "token_exchange_failed", "failed to exchange oauth code", singleLine(description))
return
}
email, username, subject, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp)
if err != nil {
log.Printf("[LinuxDo OAuth] userinfo fetch failed: %v", err)
redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "")
return
}
// 安全考虑:不要把第三方返回的 email 直接映射到本地账号(可能与本地邮箱用户冲突导致账号被接管)。
// 统一使用基于 subject 的稳定合成邮箱来做账号绑定。
if subject != "" {
email = linuxDoSyntheticEmail(subject)
}
jwtToken, _, err := h.authService.LoginOrRegisterOAuth(c.Request.Context(), email, username)
if err != nil {
// 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
return
}
fragment := url.Values{}
fragment.Set("access_token", jwtToken)
fragment.Set("token_type", "Bearer")
fragment.Set("redirect", redirectTo)
redirectWithFragment(c, frontendCallback, fragment)
}
func (h *AuthHandler) getLinuxDoOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) {
if h != nil && h.settingSvc != nil {
return h.settingSvc.GetLinuxDoConnectOAuthConfig(ctx)
}
if h == nil || h.cfg == nil {
return config.LinuxDoConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded")
}
if !h.cfg.LinuxDo.Enabled {
return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
}
return h.cfg.LinuxDo, nil
}
func linuxDoExchangeCode(
ctx context.Context,
cfg config.LinuxDoConnectConfig,
code string,
redirectURI string,
codeVerifier string,
) (*linuxDoTokenResponse, error) {
client := req.C().SetTimeout(30 * time.Second)
form := url.Values{}
form.Set("grant_type", "authorization_code")
form.Set("client_id", cfg.ClientID)
form.Set("code", code)
form.Set("redirect_uri", redirectURI)
if cfg.UsePKCE {
form.Set("code_verifier", codeVerifier)
}
r := client.R().
SetContext(ctx).
SetHeader("Accept", "application/json")
switch strings.ToLower(strings.TrimSpace(cfg.TokenAuthMethod)) {
case "", "client_secret_post":
form.Set("client_secret", cfg.ClientSecret)
case "client_secret_basic":
r.SetBasicAuth(cfg.ClientID, cfg.ClientSecret)
case "none":
default:
return nil, fmt.Errorf("unsupported token_auth_method: %s", cfg.TokenAuthMethod)
}
resp, err := r.SetFormDataFromValues(form).Post(cfg.TokenURL)
if err != nil {
return nil, fmt.Errorf("request token: %w", err)
}
body := strings.TrimSpace(resp.String())
if !resp.IsSuccessState() {
providerErr, providerDesc := parseOAuthProviderError(body)
return nil, &linuxDoTokenExchangeError{
StatusCode: resp.StatusCode,
ProviderError: providerErr,
ProviderDescription: providerDesc,
Body: body,
}
}
tokenResp, ok := parseLinuxDoTokenResponse(body)
if !ok || strings.TrimSpace(tokenResp.AccessToken) == "" {
return nil, &linuxDoTokenExchangeError{
StatusCode: resp.StatusCode,
Body: body,
}
}
if strings.TrimSpace(tokenResp.TokenType) == "" {
tokenResp.TokenType = "Bearer"
}
return tokenResp, nil
}
func linuxDoFetchUserInfo(
ctx context.Context,
cfg config.LinuxDoConnectConfig,
token *linuxDoTokenResponse,
) (email string, username string, subject string, err error) {
client := req.C().SetTimeout(30 * time.Second)
authorization, err := buildBearerAuthorization(token.TokenType, token.AccessToken)
if err != nil {
return "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err)
}
resp, err := client.R().
SetContext(ctx).
SetHeader("Accept", "application/json").
SetHeader("Authorization", authorization).
Get(cfg.UserInfoURL)
if err != nil {
return "", "", "", fmt.Errorf("request userinfo: %w", err)
}
if !resp.IsSuccessState() {
return "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode)
}
return linuxDoParseUserInfo(resp.String(), cfg)
}
func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, err error) {
email = firstNonEmpty(
getGJSON(body, cfg.UserInfoEmailPath),
getGJSON(body, "email"),
getGJSON(body, "user.email"),
getGJSON(body, "data.email"),
getGJSON(body, "attributes.email"),
)
username = firstNonEmpty(
getGJSON(body, cfg.UserInfoUsernamePath),
getGJSON(body, "username"),
getGJSON(body, "preferred_username"),
getGJSON(body, "name"),
getGJSON(body, "user.username"),
getGJSON(body, "user.name"),
)
subject = firstNonEmpty(
getGJSON(body, cfg.UserInfoIDPath),
getGJSON(body, "sub"),
getGJSON(body, "id"),
getGJSON(body, "user_id"),
getGJSON(body, "uid"),
getGJSON(body, "user.id"),
)
subject = strings.TrimSpace(subject)
if subject == "" {
return "", "", "", errors.New("userinfo missing id field")
}
if !isSafeLinuxDoSubject(subject) {
return "", "", "", errors.New("userinfo returned invalid id field")
}
email = strings.TrimSpace(email)
if email == "" {
// LinuxDo Connect 的 userinfo 可能不提供 email。为兼容现有用户模型(email 必填且唯一),使用稳定的合成邮箱。
email = linuxDoSyntheticEmail(subject)
}
username = strings.TrimSpace(username)
if username == "" {
username = "linuxdo_" + subject
}
return email, username, subject, nil
}
func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, codeChallenge string, redirectURI string) (string, error) {
u, err := url.Parse(cfg.AuthorizeURL)
if err != nil {
return "", fmt.Errorf("parse authorize_url: %w", err)
}
q := u.Query()
q.Set("response_type", "code")
q.Set("client_id", cfg.ClientID)
q.Set("redirect_uri", redirectURI)
if strings.TrimSpace(cfg.Scopes) != "" {
q.Set("scope", cfg.Scopes)
}
q.Set("state", state)
if cfg.UsePKCE {
q.Set("code_challenge", codeChallenge)
q.Set("code_challenge_method", "S256")
}
u.RawQuery = q.Encode()
return u.String(), nil
}
func redirectOAuthError(c *gin.Context, frontendCallback string, code string, message string, description string) {
fragment := url.Values{}
fragment.Set("error", truncateFragmentValue(code))
if strings.TrimSpace(message) != "" {
fragment.Set("error_message", truncateFragmentValue(message))
}
if strings.TrimSpace(description) != "" {
fragment.Set("error_description", truncateFragmentValue(description))
}
redirectWithFragment(c, frontendCallback, fragment)
}
func redirectWithFragment(c *gin.Context, frontendCallback string, fragment url.Values) {
u, err := url.Parse(frontendCallback)
if err != nil {
// 兜底:尽力跳转到默认页面,避免卡死在回调页。
c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo)
return
}
if u.Scheme != "" && !strings.EqualFold(u.Scheme, "http") && !strings.EqualFold(u.Scheme, "https") {
c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo)
return
}
u.Fragment = fragment.Encode()
c.Header("Cache-Control", "no-store")
c.Header("Pragma", "no-cache")
c.Redirect(http.StatusFound, u.String())
}
func firstNonEmpty(values ...string) string {
for _, v := range values {
v = strings.TrimSpace(v)
if v != "" {
return v
}
}
return ""
}
func parseOAuthProviderError(body string) (providerErr string, providerDesc string) {
body = strings.TrimSpace(body)
if body == "" {
return "", ""
}
providerErr = firstNonEmpty(
getGJSON(body, "error"),
getGJSON(body, "code"),
getGJSON(body, "error.code"),
)
providerDesc = firstNonEmpty(
getGJSON(body, "error_description"),
getGJSON(body, "error.message"),
getGJSON(body, "message"),
getGJSON(body, "detail"),
)
if providerErr != "" || providerDesc != "" {
return providerErr, providerDesc
}
values, err := url.ParseQuery(body)
if err != nil {
return "", ""
}
providerErr = firstNonEmpty(values.Get("error"), values.Get("code"))
providerDesc = firstNonEmpty(values.Get("error_description"), values.Get("error_message"), values.Get("message"))
return providerErr, providerDesc
}
func parseLinuxDoTokenResponse(body string) (*linuxDoTokenResponse, bool) {
body = strings.TrimSpace(body)
if body == "" {
return nil, false
}
accessToken := strings.TrimSpace(getGJSON(body, "access_token"))
if accessToken != "" {
tokenType := strings.TrimSpace(getGJSON(body, "token_type"))
refreshToken := strings.TrimSpace(getGJSON(body, "refresh_token"))
scope := strings.TrimSpace(getGJSON(body, "scope"))
expiresIn := gjson.Get(body, "expires_in").Int()
return &linuxDoTokenResponse{
AccessToken: accessToken,
TokenType: tokenType,
ExpiresIn: expiresIn,
RefreshToken: refreshToken,
Scope: scope,
}, true
}
values, err := url.ParseQuery(body)
if err != nil {
return nil, false
}
accessToken = strings.TrimSpace(values.Get("access_token"))
if accessToken == "" {
return nil, false
}
expiresIn := int64(0)
if raw := strings.TrimSpace(values.Get("expires_in")); raw != "" {
if v, err := strconv.ParseInt(raw, 10, 64); err == nil {
expiresIn = v
}
}
return &linuxDoTokenResponse{
AccessToken: accessToken,
TokenType: strings.TrimSpace(values.Get("token_type")),
ExpiresIn: expiresIn,
RefreshToken: strings.TrimSpace(values.Get("refresh_token")),
Scope: strings.TrimSpace(values.Get("scope")),
}, true
}
func getGJSON(body string, path string) string {
path = strings.TrimSpace(path)
if path == "" {
return ""
}
res := gjson.Get(body, path)
if !res.Exists() {
return ""
}
return res.String()
}
func truncateLogValue(value string, maxLen int) string {
value = strings.TrimSpace(value)
if value == "" || maxLen <= 0 {
return ""
}
if len(value) <= maxLen {
return value
}
value = value[:maxLen]
for !utf8.ValidString(value) {
value = value[:len(value)-1]
}
return value
}
func singleLine(value string) string {
value = strings.TrimSpace(value)
if value == "" {
return ""
}
return strings.Join(strings.Fields(value), " ")
}
func sanitizeFrontendRedirectPath(path string) string {
path = strings.TrimSpace(path)
if path == "" {
return ""
}
if len(path) > linuxDoOAuthMaxRedirectLen {
return ""
}
// 只允许同源相对路径(避免开放重定向)。
if !strings.HasPrefix(path, "/") {
return ""
}
if strings.HasPrefix(path, "//") {
return ""
}
if strings.Contains(path, "://") {
return ""
}
if strings.ContainsAny(path, "\r\n") {
return ""
}
return path
}
func isRequestHTTPS(c *gin.Context) bool {
if c.Request.TLS != nil {
return true
}
proto := strings.ToLower(strings.TrimSpace(c.GetHeader("X-Forwarded-Proto")))
return proto == "https"
}
func encodeCookieValue(value string) string {
return base64.RawURLEncoding.EncodeToString([]byte(value))
}
func decodeCookieValue(value string) (string, error) {
raw, err := base64.RawURLEncoding.DecodeString(value)
if err != nil {
return "", err
}
return string(raw), nil
}
func readCookieDecoded(c *gin.Context, name string) (string, error) {
ck, err := c.Request.Cookie(name)
if err != nil {
return "", err
}
return decodeCookieValue(ck.Value)
}
func setCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) {
http.SetCookie(c.Writer, &http.Cookie{
Name: name,
Value: value,
Path: linuxDoOAuthCookiePath,
MaxAge: maxAgeSec,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}
func clearCookie(c *gin.Context, name string, secure bool) {
http.SetCookie(c.Writer, &http.Cookie{
Name: name,
Value: "",
Path: linuxDoOAuthCookiePath,
MaxAge: -1,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}
func truncateFragmentValue(value string) string {
value = strings.TrimSpace(value)
if value == "" {
return ""
}
if len(value) > linuxDoOAuthMaxFragmentValueLen {
value = value[:linuxDoOAuthMaxFragmentValueLen]
for !utf8.ValidString(value) {
value = value[:len(value)-1]
}
}
return value
}
func buildBearerAuthorization(tokenType, accessToken string) (string, error) {
tokenType = strings.TrimSpace(tokenType)
if tokenType == "" {
tokenType = "Bearer"
}
if !strings.EqualFold(tokenType, "Bearer") {
return "", fmt.Errorf("unsupported token_type: %s", tokenType)
}
accessToken = strings.TrimSpace(accessToken)
if accessToken == "" {
return "", errors.New("missing access_token")
}
if strings.ContainsAny(accessToken, " \t\r\n") {
return "", errors.New("access_token contains whitespace")
}
return "Bearer " + accessToken, nil
}
func isSafeLinuxDoSubject(subject string) bool {
subject = strings.TrimSpace(subject)
if subject == "" || len(subject) > linuxDoOAuthMaxSubjectLen {
return false
}
for _, r := range subject {
switch {
case r >= '0' && r <= '9':
case r >= 'a' && r <= 'z':
case r >= 'A' && r <= 'Z':
case r == '_' || r == '-':
default:
return false
}
}
return true
}
func linuxDoSyntheticEmail(subject string) string {
subject = strings.TrimSpace(subject)
if subject == "" {
return ""
}
return "linuxdo-" + subject + service.LinuxDoConnectSyntheticEmailDomain
}
package handler
import (
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestSanitizeFrontendRedirectPath(t *testing.T) {
require.Equal(t, "/dashboard", sanitizeFrontendRedirectPath("/dashboard"))
require.Equal(t, "/dashboard", sanitizeFrontendRedirectPath(" /dashboard "))
require.Equal(t, "", sanitizeFrontendRedirectPath("dashboard"))
require.Equal(t, "", sanitizeFrontendRedirectPath("//evil.com"))
require.Equal(t, "", sanitizeFrontendRedirectPath("https://evil.com"))
require.Equal(t, "", sanitizeFrontendRedirectPath("/\nfoo"))
long := "/" + strings.Repeat("a", linuxDoOAuthMaxRedirectLen)
require.Equal(t, "", sanitizeFrontendRedirectPath(long))
}
func TestBuildBearerAuthorization(t *testing.T) {
auth, err := buildBearerAuthorization("", "token123")
require.NoError(t, err)
require.Equal(t, "Bearer token123", auth)
auth, err = buildBearerAuthorization("bearer", "token123")
require.NoError(t, err)
require.Equal(t, "Bearer token123", auth)
_, err = buildBearerAuthorization("MAC", "token123")
require.Error(t, err)
_, err = buildBearerAuthorization("Bearer", "token 123")
require.Error(t, err)
}
func TestLinuxDoParseUserInfoParsesIDAndUsername(t *testing.T) {
cfg := config.LinuxDoConnectConfig{
UserInfoURL: "https://connect.linux.do/api/user",
}
email, username, subject, err := linuxDoParseUserInfo(`{"id":123,"username":"alice"}`, cfg)
require.NoError(t, err)
require.Equal(t, "123", subject)
require.Equal(t, "alice", username)
require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email)
}
func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) {
cfg := config.LinuxDoConnectConfig{
UserInfoURL: "https://connect.linux.do/api/user",
}
email, username, subject, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg)
require.NoError(t, err)
require.Equal(t, "123", subject)
require.Equal(t, "linuxdo_123", username)
require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email)
}
func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) {
cfg := config.LinuxDoConnectConfig{
UserInfoURL: "https://connect.linux.do/api/user",
}
_, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg)
require.Error(t, err)
tooLong := strings.Repeat("a", linuxDoOAuthMaxSubjectLen+1)
_, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg)
require.Error(t, err)
}
func TestParseOAuthProviderErrorJSON(t *testing.T) {
code, desc := parseOAuthProviderError(`{"error":"invalid_client","error_description":"bad secret"}`)
require.Equal(t, "invalid_client", code)
require.Equal(t, "bad secret", desc)
}
func TestParseOAuthProviderErrorForm(t *testing.T) {
code, desc := parseOAuthProviderError("error=invalid_request&error_description=Missing+code_verifier")
require.Equal(t, "invalid_request", code)
require.Equal(t, "Missing code_verifier", desc)
}
func TestParseLinuxDoTokenResponseJSON(t *testing.T) {
token, ok := parseLinuxDoTokenResponse(`{"access_token":"t1","token_type":"Bearer","expires_in":3600,"scope":"user"}`)
require.True(t, ok)
require.Equal(t, "t1", token.AccessToken)
require.Equal(t, "Bearer", token.TokenType)
require.Equal(t, int64(3600), token.ExpiresIn)
require.Equal(t, "user", token.Scope)
}
func TestParseLinuxDoTokenResponseForm(t *testing.T) {
token, ok := parseLinuxDoTokenResponse("access_token=t2&token_type=bearer&expires_in=60")
require.True(t, ok)
require.Equal(t, "t2", token.AccessToken)
require.Equal(t, "bearer", token.TokenType)
require.Equal(t, int64(60), token.ExpiresIn)
}
func TestSingleLineStripsWhitespace(t *testing.T) {
require.Equal(t, "hello world", singleLine("hello\r\nworld"))
require.Equal(t, "", singleLine("\n\t\r"))
}
......@@ -59,6 +59,8 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
Name: k.Name,
GroupID: k.GroupID,
Status: k.Status,
IPWhitelist: k.IPWhitelist,
IPBlacklist: k.IPBlacklist,
CreatedAt: k.CreatedAt,
UpdatedAt: k.UpdatedAt,
User: UserFromServiceShallow(k.User),
......@@ -250,11 +252,12 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary {
// usageLogFromServiceBase is a helper that converts service UsageLog to DTO.
// The account parameter allows caller to control what Account info is included.
func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary) *UsageLog {
// The includeIPAddress parameter controls whether to include the IP address (admin-only).
func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, includeIPAddress bool) *UsageLog {
if l == nil {
return nil
}
return &UsageLog{
result := &UsageLog{
ID: l.ID,
UserID: l.UserID,
APIKeyID: l.APIKeyID,
......@@ -290,21 +293,26 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary) *Usag
Group: GroupFromServiceShallow(l.Group),
Subscription: UserSubscriptionFromService(l.Subscription),
}
// IP 地址仅对管理员可见
if includeIPAddress {
result.IPAddress = l.IPAddress
}
return result
}
// UsageLogFromService converts a service UsageLog to DTO for regular users.
// It excludes Account details - users should not see account information.
// It excludes Account details and IP address - users should not see these.
func UsageLogFromService(l *service.UsageLog) *UsageLog {
return usageLogFromServiceBase(l, nil)
return usageLogFromServiceBase(l, nil, false)
}
// UsageLogFromServiceAdmin converts a service UsageLog to DTO for admin users.
// It includes minimal Account info (ID, Name only).
// It includes minimal Account info (ID, Name only) and IP address.
func UsageLogFromServiceAdmin(l *service.UsageLog) *UsageLog {
if l == nil {
return nil
}
return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account))
return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account), true)
}
func SettingFromService(s *service.Setting) *Setting {
......@@ -362,3 +370,35 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult
Errors: r.Errors,
}
}
func PromoCodeFromService(pc *service.PromoCode) *PromoCode {
if pc == nil {
return nil
}
return &PromoCode{
ID: pc.ID,
Code: pc.Code,
BonusAmount: pc.BonusAmount,
MaxUses: pc.MaxUses,
UsedCount: pc.UsedCount,
Status: pc.Status,
ExpiresAt: pc.ExpiresAt,
Notes: pc.Notes,
CreatedAt: pc.CreatedAt,
UpdatedAt: pc.UpdatedAt,
}
}
func PromoCodeUsageFromService(u *service.PromoCodeUsage) *PromoCodeUsage {
if u == nil {
return nil
}
return &PromoCodeUsage{
ID: u.ID,
PromoCodeID: u.PromoCodeID,
UserID: u.UserID,
BonusAmount: u.BonusAmount,
UsedAt: u.UsedAt,
User: UserFromServiceShallow(u.User),
}
}
......@@ -17,6 +17,11 @@ type SystemSettings struct {
TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKeyConfigured bool `json:"turnstile_secret_key_configured"`
LinuxDoConnectEnabled bool `json:"linuxdo_connect_enabled"`
LinuxDoConnectClientID string `json:"linuxdo_connect_client_id"`
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
......@@ -50,5 +55,6 @@ type PublicSettings struct {
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version"`
}
......@@ -26,6 +26,8 @@ type APIKey struct {
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
Status string `json:"status"`
IPWhitelist []string `json:"ip_whitelist"`
IPBlacklist []string `json:"ip_blacklist"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
......@@ -187,6 +189,9 @@ type UsageLog struct {
// User-Agent
UserAgent *string `json:"user_agent"`
// IP 地址(仅管理员可见)
IPAddress *string `json:"ip_address,omitempty"`
CreatedAt time.Time `json:"created_at"`
User *User `json:"user,omitempty"`
......@@ -245,3 +250,28 @@ type BulkAssignResult struct {
Subscriptions []UserSubscription `json:"subscriptions"`
Errors []string `json:"errors"`
}
// PromoCode 注册优惠码
type PromoCode struct {
ID int64 `json:"id"`
Code string `json:"code"`
BonusAmount float64 `json:"bonus_amount"`
MaxUses int `json:"max_uses"`
UsedCount int `json:"used_count"`
Status string `json:"status"`
ExpiresAt *time.Time `json:"expires_at"`
Notes string `json:"notes"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// PromoCodeUsage 优惠码使用记录
type PromoCodeUsage struct {
ID int64 `json:"id"`
PromoCodeID int64 `json:"promo_code_id"`
UserID int64 `json:"user_id"`
BonusAmount float64 `json:"bonus_amount"`
UsedAt time.Time `json:"used_at"`
User *User `json:"user,omitempty"`
}
......@@ -15,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
......@@ -114,6 +115,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 获取 User-Agent
userAgent := c.Request.UserAgent()
// 获取客户端 IP
clientIP := ip.GetClientIP(c)
// 0. 检查wait队列是否已满
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
......@@ -273,7 +277,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 异步记录使用量(subscription已在函数开头获取)
go func(result *service.ForwardResult, usedAccount *service.Account, ua string) {
go func(result *service.ForwardResult, usedAccount *service.Account, ua string, cip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
......@@ -283,10 +287,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
Account: usedAccount,
Subscription: subscription,
UserAgent: ua,
IPAddress: cip,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account, userAgent)
}(result, account, userAgent, clientIP)
return
}
}
......@@ -401,7 +406,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 异步记录使用量(subscription已在函数开头获取)
go func(result *service.ForwardResult, usedAccount *service.Account, ua string) {
go func(result *service.ForwardResult, usedAccount *service.Account, ua string, cip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
......@@ -411,10 +416,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
Account: usedAccount,
Subscription: subscription,
UserAgent: ua,
IPAddress: cip,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account, userAgent)
}(result, account, userAgent, clientIP)
return
}
}
......
......@@ -12,6 +12,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
......@@ -167,6 +168,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 获取 User-Agent
userAgent := c.Request.UserAgent()
// 获取客户端 IP
clientIP := ip.GetClientIP(c)
// For Gemini native API, do not send Claude-style ping frames.
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone, 0)
......@@ -307,7 +311,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
// 6) record usage async
go func(result *service.ForwardResult, usedAccount *service.Account, ua string) {
go func(result *service.ForwardResult, usedAccount *service.Account, ua string, cip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
......@@ -317,10 +321,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
Account: usedAccount,
Subscription: subscription,
UserAgent: ua,
IPAddress: cip,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account, userAgent)
}(result, account, userAgent, clientIP)
return
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment