Commit 0b746501 authored by 陈曦's avatar 陈曦
Browse files

1. merge upstream v0.1.113 2.提交migration相关文件

parents 45061102 be7551b9
...@@ -3,12 +3,14 @@ package service ...@@ -3,12 +3,14 @@ package service
import ( import (
"context" "context"
"fmt" "fmt"
"math"
"strconv" "strconv"
"strings" "strings"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/internal/payment" "github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
) )
const ( const (
...@@ -21,6 +23,8 @@ const ( ...@@ -21,6 +23,8 @@ const (
SettingEnabledPaymentTypes = "ENABLED_PAYMENT_TYPES" SettingEnabledPaymentTypes = "ENABLED_PAYMENT_TYPES"
SettingLoadBalanceStrategy = "LOAD_BALANCE_STRATEGY" SettingLoadBalanceStrategy = "LOAD_BALANCE_STRATEGY"
SettingBalancePayDisabled = "BALANCE_PAYMENT_DISABLED" SettingBalancePayDisabled = "BALANCE_PAYMENT_DISABLED"
SettingBalanceRechargeMult = "BALANCE_RECHARGE_MULTIPLIER"
SettingRechargeFeeRate = "RECHARGE_FEE_RATE"
SettingProductNamePrefix = "PRODUCT_NAME_PREFIX" SettingProductNamePrefix = "PRODUCT_NAME_PREFIX"
SettingProductNameSuffix = "PRODUCT_NAME_SUFFIX" SettingProductNameSuffix = "PRODUCT_NAME_SUFFIX"
SettingHelpImageURL = "PAYMENT_HELP_IMAGE_URL" SettingHelpImageURL = "PAYMENT_HELP_IMAGE_URL"
...@@ -40,20 +44,22 @@ const ( ...@@ -40,20 +44,22 @@ const (
// PaymentConfig holds the payment system configuration. // PaymentConfig holds the payment system configuration.
type PaymentConfig struct { type PaymentConfig struct {
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
MinAmount float64 `json:"min_amount"` MinAmount float64 `json:"min_amount"`
MaxAmount float64 `json:"max_amount"` MaxAmount float64 `json:"max_amount"`
DailyLimit float64 `json:"daily_limit"` DailyLimit float64 `json:"daily_limit"`
OrderTimeoutMin int `json:"order_timeout_minutes"` OrderTimeoutMin int `json:"order_timeout_minutes"`
MaxPendingOrders int `json:"max_pending_orders"` MaxPendingOrders int `json:"max_pending_orders"`
EnabledTypes []string `json:"enabled_payment_types"` EnabledTypes []string `json:"enabled_payment_types"`
BalanceDisabled bool `json:"balance_disabled"` BalanceDisabled bool `json:"balance_disabled"`
LoadBalanceStrategy string `json:"load_balance_strategy"` BalanceRechargeMultiplier float64 `json:"balance_recharge_multiplier"`
ProductNamePrefix string `json:"product_name_prefix"` RechargeFeeRate float64 `json:"recharge_fee_rate"`
ProductNameSuffix string `json:"product_name_suffix"` LoadBalanceStrategy string `json:"load_balance_strategy"`
HelpImageURL string `json:"help_image_url"` ProductNamePrefix string `json:"product_name_prefix"`
HelpText string `json:"help_text"` ProductNameSuffix string `json:"product_name_suffix"`
StripePublishableKey string `json:"stripe_publishable_key,omitempty"` HelpImageURL string `json:"help_image_url"`
HelpText string `json:"help_text"`
StripePublishableKey string `json:"stripe_publishable_key,omitempty"`
// Cancel rate limit settings // Cancel rate limit settings
CancelRateLimitEnabled bool `json:"cancel_rate_limit_enabled"` CancelRateLimitEnabled bool `json:"cancel_rate_limit_enabled"`
...@@ -65,19 +71,21 @@ type PaymentConfig struct { ...@@ -65,19 +71,21 @@ type PaymentConfig struct {
// UpdatePaymentConfigRequest contains fields to update payment configuration. // UpdatePaymentConfigRequest contains fields to update payment configuration.
type UpdatePaymentConfigRequest struct { type UpdatePaymentConfigRequest struct {
Enabled *bool `json:"enabled"` Enabled *bool `json:"enabled"`
MinAmount *float64 `json:"min_amount"` MinAmount *float64 `json:"min_amount"`
MaxAmount *float64 `json:"max_amount"` MaxAmount *float64 `json:"max_amount"`
DailyLimit *float64 `json:"daily_limit"` DailyLimit *float64 `json:"daily_limit"`
OrderTimeoutMin *int `json:"order_timeout_minutes"` OrderTimeoutMin *int `json:"order_timeout_minutes"`
MaxPendingOrders *int `json:"max_pending_orders"` MaxPendingOrders *int `json:"max_pending_orders"`
EnabledTypes []string `json:"enabled_payment_types"` EnabledTypes []string `json:"enabled_payment_types"`
BalanceDisabled *bool `json:"balance_disabled"` BalanceDisabled *bool `json:"balance_disabled"`
LoadBalanceStrategy *string `json:"load_balance_strategy"` BalanceRechargeMultiplier *float64 `json:"balance_recharge_multiplier"`
ProductNamePrefix *string `json:"product_name_prefix"` RechargeFeeRate *float64 `json:"recharge_fee_rate"`
ProductNameSuffix *string `json:"product_name_suffix"` LoadBalanceStrategy *string `json:"load_balance_strategy"`
HelpImageURL *string `json:"help_image_url"` ProductNamePrefix *string `json:"product_name_prefix"`
HelpText *string `json:"help_text"` ProductNameSuffix *string `json:"product_name_suffix"`
HelpImageURL *string `json:"help_image_url"`
HelpText *string `json:"help_text"`
// Cancel rate limit settings // Cancel rate limit settings
CancelRateLimitEnabled *bool `json:"cancel_rate_limit_enabled"` CancelRateLimitEnabled *bool `json:"cancel_rate_limit_enabled"`
...@@ -105,26 +113,28 @@ type MethodLimitsResponse struct { ...@@ -105,26 +113,28 @@ type MethodLimitsResponse struct {
} }
type CreateProviderInstanceRequest struct { type CreateProviderInstanceRequest struct {
ProviderKey string `json:"provider_key"` ProviderKey string `json:"provider_key"`
Name string `json:"name"` Name string `json:"name"`
Config map[string]string `json:"config"` Config map[string]string `json:"config"`
SupportedTypes []string `json:"supported_types"` SupportedTypes []string `json:"supported_types"`
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
PaymentMode string `json:"payment_mode"` PaymentMode string `json:"payment_mode"`
SortOrder int `json:"sort_order"` SortOrder int `json:"sort_order"`
Limits string `json:"limits"` Limits string `json:"limits"`
RefundEnabled bool `json:"refund_enabled"` RefundEnabled bool `json:"refund_enabled"`
AllowUserRefund bool `json:"allow_user_refund"`
} }
type UpdateProviderInstanceRequest struct { type UpdateProviderInstanceRequest struct {
Name *string `json:"name"` Name *string `json:"name"`
Config map[string]string `json:"config"` Config map[string]string `json:"config"`
SupportedTypes []string `json:"supported_types"` SupportedTypes []string `json:"supported_types"`
Enabled *bool `json:"enabled"` Enabled *bool `json:"enabled"`
PaymentMode *string `json:"payment_mode"` PaymentMode *string `json:"payment_mode"`
SortOrder *int `json:"sort_order"` SortOrder *int `json:"sort_order"`
Limits *string `json:"limits"` Limits *string `json:"limits"`
RefundEnabled *bool `json:"refund_enabled"` RefundEnabled *bool `json:"refund_enabled"`
AllowUserRefund *bool `json:"allow_user_refund"`
} }
type CreatePlanRequest struct { type CreatePlanRequest struct {
GroupID int64 `json:"group_id"` GroupID int64 `json:"group_id"`
...@@ -181,7 +191,7 @@ func (s *PaymentConfigService) GetPaymentConfig(ctx context.Context) (*PaymentCo ...@@ -181,7 +191,7 @@ func (s *PaymentConfigService) GetPaymentConfig(ctx context.Context) (*PaymentCo
keys := []string{ keys := []string{
SettingPaymentEnabled, SettingMinRechargeAmount, SettingMaxRechargeAmount, SettingPaymentEnabled, SettingMinRechargeAmount, SettingMaxRechargeAmount,
SettingDailyRechargeLimit, SettingOrderTimeoutMinutes, SettingMaxPendingOrders, SettingDailyRechargeLimit, SettingOrderTimeoutMinutes, SettingMaxPendingOrders,
SettingEnabledPaymentTypes, SettingBalancePayDisabled, SettingLoadBalanceStrategy, SettingEnabledPaymentTypes, SettingBalancePayDisabled, SettingBalanceRechargeMult, SettingRechargeFeeRate, SettingLoadBalanceStrategy,
SettingProductNamePrefix, SettingProductNameSuffix, SettingProductNamePrefix, SettingProductNameSuffix,
SettingHelpImageURL, SettingHelpText, SettingHelpImageURL, SettingHelpText,
SettingCancelRateLimitOn, SettingCancelRateLimitMax, SettingCancelRateLimitOn, SettingCancelRateLimitMax,
...@@ -199,18 +209,20 @@ func (s *PaymentConfigService) GetPaymentConfig(ctx context.Context) (*PaymentCo ...@@ -199,18 +209,20 @@ func (s *PaymentConfigService) GetPaymentConfig(ctx context.Context) (*PaymentCo
func (s *PaymentConfigService) parsePaymentConfig(vals map[string]string) *PaymentConfig { func (s *PaymentConfigService) parsePaymentConfig(vals map[string]string) *PaymentConfig {
cfg := &PaymentConfig{ cfg := &PaymentConfig{
Enabled: vals[SettingPaymentEnabled] == "true", Enabled: vals[SettingPaymentEnabled] == "true",
MinAmount: pcParseFloat(vals[SettingMinRechargeAmount], 1), MinAmount: pcParseFloat(vals[SettingMinRechargeAmount], 1),
MaxAmount: pcParseFloat(vals[SettingMaxRechargeAmount], 0), MaxAmount: pcParseFloat(vals[SettingMaxRechargeAmount], 0),
DailyLimit: pcParseFloat(vals[SettingDailyRechargeLimit], 0), DailyLimit: pcParseFloat(vals[SettingDailyRechargeLimit], 0),
OrderTimeoutMin: pcParseInt(vals[SettingOrderTimeoutMinutes], defaultOrderTimeoutMin), OrderTimeoutMin: pcParseInt(vals[SettingOrderTimeoutMinutes], defaultOrderTimeoutMin),
MaxPendingOrders: pcParseInt(vals[SettingMaxPendingOrders], defaultMaxPendingOrders), MaxPendingOrders: pcParseInt(vals[SettingMaxPendingOrders], defaultMaxPendingOrders),
BalanceDisabled: vals[SettingBalancePayDisabled] == "true", BalanceDisabled: vals[SettingBalancePayDisabled] == "true",
LoadBalanceStrategy: vals[SettingLoadBalanceStrategy], BalanceRechargeMultiplier: normalizeBalanceRechargeMultiplier(pcParseFloat(vals[SettingBalanceRechargeMult], defaultBalanceRechargeMultiplier)),
ProductNamePrefix: vals[SettingProductNamePrefix], RechargeFeeRate: pcParseFloat(vals[SettingRechargeFeeRate], 0),
ProductNameSuffix: vals[SettingProductNameSuffix], LoadBalanceStrategy: vals[SettingLoadBalanceStrategy],
HelpImageURL: vals[SettingHelpImageURL], ProductNamePrefix: vals[SettingProductNamePrefix],
HelpText: vals[SettingHelpText], ProductNameSuffix: vals[SettingProductNameSuffix],
HelpImageURL: vals[SettingHelpImageURL],
HelpText: vals[SettingHelpText],
CancelRateLimitEnabled: vals[SettingCancelRateLimitOn] == "true", CancelRateLimitEnabled: vals[SettingCancelRateLimitOn] == "true",
CancelRateLimitMax: pcParseInt(vals[SettingCancelRateLimitMax], 10), CancelRateLimitMax: pcParseInt(vals[SettingCancelRateLimitMax], 10),
...@@ -254,6 +266,21 @@ func (s *PaymentConfigService) getStripePublishableKey(ctx context.Context) stri ...@@ -254,6 +266,21 @@ func (s *PaymentConfigService) getStripePublishableKey(ctx context.Context) stri
// nil-check before serialisation — this is inherent to patch-style update patterns // nil-check before serialisation — this is inherent to patch-style update patterns
// and cannot be meaningfully decomposed without introducing unnecessary abstraction. // and cannot be meaningfully decomposed without introducing unnecessary abstraction.
func (s *PaymentConfigService) UpdatePaymentConfig(ctx context.Context, req UpdatePaymentConfigRequest) error { func (s *PaymentConfigService) UpdatePaymentConfig(ctx context.Context, req UpdatePaymentConfigRequest) error {
if req.BalanceRechargeMultiplier != nil {
if math.IsNaN(*req.BalanceRechargeMultiplier) || math.IsInf(*req.BalanceRechargeMultiplier, 0) || *req.BalanceRechargeMultiplier <= 0 {
return infraerrors.BadRequest("INVALID_BALANCE_RECHARGE_MULTIPLIER", "balance recharge multiplier must be greater than 0")
}
}
if req.RechargeFeeRate != nil {
v := *req.RechargeFeeRate
if math.IsNaN(v) || math.IsInf(v, 0) || v < 0 || v > 100 {
return infraerrors.BadRequest("INVALID_RECHARGE_FEE_RATE", "recharge fee rate must be between 0 and 100")
}
// Enforce max 2 decimal places
if math.Round(v*100) != v*100 {
return infraerrors.BadRequest("INVALID_RECHARGE_FEE_RATE", "recharge fee rate allows at most 2 decimal places")
}
}
m := map[string]string{ m := map[string]string{
SettingPaymentEnabled: formatBoolOrEmpty(req.Enabled), SettingPaymentEnabled: formatBoolOrEmpty(req.Enabled),
SettingMinRechargeAmount: formatPositiveFloat(req.MinAmount), SettingMinRechargeAmount: formatPositiveFloat(req.MinAmount),
...@@ -262,6 +289,8 @@ func (s *PaymentConfigService) UpdatePaymentConfig(ctx context.Context, req Upda ...@@ -262,6 +289,8 @@ func (s *PaymentConfigService) UpdatePaymentConfig(ctx context.Context, req Upda
SettingOrderTimeoutMinutes: formatPositiveInt(req.OrderTimeoutMin), SettingOrderTimeoutMinutes: formatPositiveInt(req.OrderTimeoutMin),
SettingMaxPendingOrders: formatPositiveInt(req.MaxPendingOrders), SettingMaxPendingOrders: formatPositiveInt(req.MaxPendingOrders),
SettingBalancePayDisabled: formatBoolOrEmpty(req.BalanceDisabled), SettingBalancePayDisabled: formatBoolOrEmpty(req.BalanceDisabled),
SettingBalanceRechargeMult: formatPositiveFloat(req.BalanceRechargeMultiplier),
SettingRechargeFeeRate: formatNonNegativeFloat(req.RechargeFeeRate),
SettingLoadBalanceStrategy: derefStr(req.LoadBalanceStrategy), SettingLoadBalanceStrategy: derefStr(req.LoadBalanceStrategy),
SettingProductNamePrefix: derefStr(req.ProductNamePrefix), SettingProductNamePrefix: derefStr(req.ProductNamePrefix),
SettingProductNameSuffix: derefStr(req.ProductNameSuffix), SettingProductNameSuffix: derefStr(req.ProductNameSuffix),
...@@ -295,6 +324,13 @@ func formatPositiveFloat(v *float64) string { ...@@ -295,6 +324,13 @@ func formatPositiveFloat(v *float64) string {
return strconv.FormatFloat(*v, 'f', 2, 64) return strconv.FormatFloat(*v, 'f', 2, 64)
} }
func formatNonNegativeFloat(v *float64) string {
if v == nil || *v < 0 {
return ""
}
return strconv.FormatFloat(*v, 'f', 2, 64)
}
func formatPositiveInt(v *int) string { func formatPositiveInt(v *int) string {
if v == nil || *v <= 0 { if v == nil || *v <= 0 {
return "" return ""
......
...@@ -216,7 +216,11 @@ func (s *PaymentService) markCompleted(ctx context.Context, o *dbent.PaymentOrde ...@@ -216,7 +216,11 @@ func (s *PaymentService) markCompleted(ctx context.Context, o *dbent.PaymentOrde
if err != nil { if err != nil {
return fmt.Errorf("mark completed: %w", err) return fmt.Errorf("mark completed: %w", err)
} }
s.writeAuditLog(ctx, o.ID, auditAction, "system", map[string]any{"rechargeCode": o.RechargeCode, "amount": o.Amount}) s.writeAuditLog(ctx, o.ID, auditAction, "system", map[string]any{
"rechargeCode": o.RechargeCode,
"creditedAmount": o.Amount,
"payAmount": o.PayAmount,
})
return nil return nil
} }
......
...@@ -43,18 +43,22 @@ func (s *PaymentService) CreateOrder(ctx context.Context, req CreateOrderRequest ...@@ -43,18 +43,22 @@ func (s *PaymentService) CreateOrder(ctx context.Context, req CreateOrderRequest
if user.Status != payment.EntityStatusActive { if user.Status != payment.EntityStatusActive {
return nil, infraerrors.Forbidden("USER_INACTIVE", "user account is disabled") return nil, infraerrors.Forbidden("USER_INACTIVE", "user account is disabled")
} }
amount := req.Amount orderAmount := req.Amount
limitAmount := req.Amount
if plan != nil { if plan != nil {
amount = plan.Price orderAmount = plan.Price
limitAmount = plan.Price
} else if req.OrderType == payment.OrderTypeBalance {
orderAmount = calculateCreditedBalance(req.Amount, cfg.BalanceRechargeMultiplier)
} }
feeRate := s.getFeeRate(req.PaymentType) feeRate := cfg.RechargeFeeRate
payAmountStr := payment.CalculatePayAmount(amount, feeRate) payAmountStr := payment.CalculatePayAmount(limitAmount, feeRate)
payAmount, _ := strconv.ParseFloat(payAmountStr, 64) payAmount, _ := strconv.ParseFloat(payAmountStr, 64)
order, err := s.createOrderInTx(ctx, req, user, plan, cfg, amount, feeRate, payAmount) order, err := s.createOrderInTx(ctx, req, user, plan, cfg, orderAmount, limitAmount, feeRate, payAmount)
if err != nil { if err != nil {
return nil, err return nil, err
} }
resp, err := s.invokeProvider(ctx, order, req, cfg, payAmountStr, payAmount, plan) resp, err := s.invokeProvider(ctx, order, req, cfg, limitAmount, payAmountStr, payAmount, plan)
if err != nil { if err != nil {
_, _ = s.entClient.PaymentOrder.UpdateOneID(order.ID). _, _ = s.entClient.PaymentOrder.UpdateOneID(order.ID).
SetStatus(OrderStatusFailed). SetStatus(OrderStatusFailed).
...@@ -99,7 +103,7 @@ func (s *PaymentService) validateSubOrder(ctx context.Context, req CreateOrderRe ...@@ -99,7 +103,7 @@ func (s *PaymentService) validateSubOrder(ctx context.Context, req CreateOrderRe
return plan, nil return plan, nil
} }
func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderRequest, user *User, plan *dbent.SubscriptionPlan, cfg *PaymentConfig, amount, feeRate, payAmount float64) (*dbent.PaymentOrder, error) { func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderRequest, user *User, plan *dbent.SubscriptionPlan, cfg *PaymentConfig, orderAmount, limitAmount, feeRate, payAmount float64) (*dbent.PaymentOrder, error) {
tx, err := s.entClient.Tx(ctx) tx, err := s.entClient.Tx(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("begin transaction: %w", err) return nil, fmt.Errorf("begin transaction: %w", err)
...@@ -108,7 +112,7 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq ...@@ -108,7 +112,7 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
if err := s.checkPendingLimit(ctx, tx, req.UserID, cfg.MaxPendingOrders); err != nil { if err := s.checkPendingLimit(ctx, tx, req.UserID, cfg.MaxPendingOrders); err != nil {
return nil, err return nil, err
} }
if err := s.checkDailyLimit(ctx, tx, req.UserID, amount, cfg.DailyLimit); err != nil { if err := s.checkDailyLimit(ctx, tx, req.UserID, limitAmount, cfg.DailyLimit); err != nil {
return nil, err return nil, err
} }
tm := cfg.OrderTimeoutMin tm := cfg.OrderTimeoutMin
...@@ -121,7 +125,7 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq ...@@ -121,7 +125,7 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
SetUserEmail(user.Email). SetUserEmail(user.Email).
SetUserName(user.Username). SetUserName(user.Username).
SetNillableUserNotes(psNilIfEmpty(user.Notes)). SetNillableUserNotes(psNilIfEmpty(user.Notes)).
SetAmount(amount). SetAmount(orderAmount).
SetPayAmount(payAmount). SetPayAmount(payAmount).
SetFeeRate(feeRate). SetFeeRate(feeRate).
SetRechargeCode(""). SetRechargeCode("").
...@@ -180,6 +184,10 @@ func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, user ...@@ -180,6 +184,10 @@ func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, user
} }
var used float64 var used float64
for _, o := range orders { for _, o := range orders {
if o.OrderType == payment.OrderTypeBalance {
used += o.PayAmount
continue
}
used += o.Amount used += o.Amount
} }
if used+amount > limit { if used+amount > limit {
...@@ -188,7 +196,7 @@ func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, user ...@@ -188,7 +196,7 @@ func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, user
return nil return nil
} }
func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.PaymentOrder, req CreateOrderRequest, cfg *PaymentConfig, payAmountStr string, payAmount float64, plan *dbent.SubscriptionPlan) (*CreateOrderResponse, error) { func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.PaymentOrder, req CreateOrderRequest, cfg *PaymentConfig, limitAmount float64, payAmountStr string, payAmount float64, plan *dbent.SubscriptionPlan) (*CreateOrderResponse, error) {
// Select an instance across all providers that support the requested payment type. // Select an instance across all providers that support the requested payment type.
// This enables cross-provider load balancing (e.g. EasyPay + Alipay direct for "alipay"). // This enables cross-provider load balancing (e.g. EasyPay + Alipay direct for "alipay").
sel, err := s.loadBalancer.SelectInstance(ctx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount) sel, err := s.loadBalancer.SelectInstance(ctx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount)
...@@ -202,7 +210,7 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen ...@@ -202,7 +210,7 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
if err != nil { if err != nil {
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", "payment method is temporarily unavailable") return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", "payment method is temporarily unavailable")
} }
subject := s.buildPaymentSubject(plan, payAmountStr, cfg) subject := s.buildPaymentSubject(plan, limitAmount, cfg)
outTradeNo := order.OutTradeNo outTradeNo := order.OutTradeNo
pr, err := prov.CreatePayment(ctx, payment.CreatePaymentRequest{OrderID: outTradeNo, Amount: payAmountStr, PaymentType: req.PaymentType, Subject: subject, ClientIP: req.ClientIP, IsMobile: req.IsMobile, InstanceSubMethods: sel.SupportedTypes}) pr, err := prov.CreatePayment(ctx, payment.CreatePaymentRequest{OrderID: outTradeNo, Amount: payAmountStr, PaymentType: req.PaymentType, Subject: subject, ClientIP: req.ClientIP, IsMobile: req.IsMobile, InstanceSubMethods: sel.SupportedTypes})
if err != nil { if err != nil {
...@@ -213,23 +221,30 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen ...@@ -213,23 +221,30 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
if err != nil { if err != nil {
return nil, fmt.Errorf("update order with payment details: %w", err) return nil, fmt.Errorf("update order with payment details: %w", err)
} }
s.writeAuditLog(ctx, order.ID, "ORDER_CREATED", fmt.Sprintf("user:%d", req.UserID), map[string]any{"amount": req.Amount, "paymentType": req.PaymentType, "orderType": req.OrderType}) s.writeAuditLog(ctx, order.ID, "ORDER_CREATED", fmt.Sprintf("user:%d", req.UserID), map[string]any{
"paymentAmount": req.Amount,
"creditedAmount": order.Amount,
"payAmount": order.PayAmount,
"paymentType": req.PaymentType,
"orderType": req.OrderType,
})
return &CreateOrderResponse{OrderID: order.ID, Amount: order.Amount, PayAmount: payAmount, FeeRate: order.FeeRate, Status: OrderStatusPending, PaymentType: req.PaymentType, PayURL: pr.PayURL, QRCode: pr.QRCode, ClientSecret: pr.ClientSecret, ExpiresAt: order.ExpiresAt, PaymentMode: sel.PaymentMode}, nil return &CreateOrderResponse{OrderID: order.ID, Amount: order.Amount, PayAmount: payAmount, FeeRate: order.FeeRate, Status: OrderStatusPending, PaymentType: req.PaymentType, PayURL: pr.PayURL, QRCode: pr.QRCode, ClientSecret: pr.ClientSecret, ExpiresAt: order.ExpiresAt, PaymentMode: sel.PaymentMode}, nil
} }
func (s *PaymentService) buildPaymentSubject(plan *dbent.SubscriptionPlan, payAmountStr string, cfg *PaymentConfig) string { func (s *PaymentService) buildPaymentSubject(plan *dbent.SubscriptionPlan, limitAmount float64, cfg *PaymentConfig) string {
if plan != nil { if plan != nil {
if plan.ProductName != "" { if plan.ProductName != "" {
return plan.ProductName return plan.ProductName
} }
return "Sub2API Subscription " + plan.Name return "Sub2API Subscription " + plan.Name
} }
amountStr := strconv.FormatFloat(limitAmount, 'f', 2, 64)
pf := strings.TrimSpace(cfg.ProductNamePrefix) pf := strings.TrimSpace(cfg.ProductNamePrefix)
sf := strings.TrimSpace(cfg.ProductNameSuffix) sf := strings.TrimSpace(cfg.ProductNameSuffix)
if pf != "" || sf != "" { if pf != "" || sf != "" {
return strings.TrimSpace(pf + " " + payAmountStr + " " + sf) return strings.TrimSpace(pf + " " + amountStr + " " + sf)
} }
return "Sub2API " + payAmountStr + " CNY" return "Sub2API " + amountStr + " CNY"
} }
// --- Order Queries --- // --- Order Queries ---
......
...@@ -2,6 +2,7 @@ package service ...@@ -2,6 +2,7 @@ package service
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"log/slog" "log/slog"
"math" "math"
...@@ -17,6 +18,19 @@ import ( ...@@ -17,6 +18,19 @@ import (
// --- Refund Flow --- // --- Refund Flow ---
// getOrderProviderInstance looks up the provider instance that processed this order.
// Returns nil, nil for legacy orders without provider_instance_id.
func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) {
if o.ProviderInstanceID == nil || *o.ProviderInstanceID == "" {
return nil, nil
}
instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64)
if err != nil {
return nil, nil
}
return s.entClient.PaymentProviderInstance.Get(ctx, instID)
}
func (s *PaymentService) RequestRefund(ctx context.Context, oid, uid int64, reason string) error { func (s *PaymentService) RequestRefund(ctx context.Context, oid, uid int64, reason string) error {
o, err := s.validateRefundRequest(ctx, oid, uid) o, err := s.validateRefundRequest(ctx, oid, uid)
if err != nil { if err != nil {
...@@ -57,6 +71,14 @@ func (s *PaymentService) validateRefundRequest(ctx context.Context, oid, uid int ...@@ -57,6 +71,14 @@ func (s *PaymentService) validateRefundRequest(ctx context.Context, oid, uid int
if o.Status != OrderStatusCompleted { if o.Status != OrderStatusCompleted {
return nil, infraerrors.BadRequest("INVALID_STATUS", "only completed orders can request refund") return nil, infraerrors.BadRequest("INVALID_STATUS", "only completed orders can request refund")
} }
// Check provider instance allows user refund
inst, err := s.getOrderProviderInstance(ctx, o)
if err != nil || inst == nil {
return nil, infraerrors.Forbidden("USER_REFUND_DISABLED", "refund is not available for this order")
}
if !inst.AllowUserRefund {
return nil, infraerrors.Forbidden("USER_REFUND_DISABLED", "user refund is not enabled for this provider")
}
return o, nil return o, nil
} }
...@@ -69,6 +91,19 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float ...@@ -69,6 +91,19 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float
if !psSliceContains(ok, o.Status) { if !psSliceContains(ok, o.Status) {
return nil, nil, infraerrors.BadRequest("INVALID_STATUS", "order status does not allow refund") return nil, nil, infraerrors.BadRequest("INVALID_STATUS", "order status does not allow refund")
} }
// Check provider instance allows admin refund
inst, instErr := s.getOrderProviderInstance(ctx, o)
if instErr != nil {
slog.Warn("refund: provider instance lookup failed", "orderID", oid, "error", instErr)
return nil, nil, infraerrors.InternalServer("PROVIDER_LOOKUP_FAILED", "failed to look up payment provider for this order")
}
if inst == nil {
// Legacy order without provider_instance_id — block refund
return nil, nil, infraerrors.Forbidden("REFUND_DISABLED", "refund is not available for this order")
}
if !inst.RefundEnabled {
return nil, nil, infraerrors.Forbidden("REFUND_DISABLED", "refund is not enabled for this provider")
}
if math.IsNaN(amt) || math.IsInf(amt, 0) { if math.IsNaN(amt) || math.IsInf(amt, 0) {
return nil, nil, infraerrors.BadRequest("INVALID_AMOUNT", "invalid refund amount") return nil, nil, infraerrors.BadRequest("INVALID_AMOUNT", "invalid refund amount")
} }
...@@ -78,11 +113,7 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float ...@@ -78,11 +113,7 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float
if amt-o.Amount > amountToleranceCNY { if amt-o.Amount > amountToleranceCNY {
return nil, nil, infraerrors.BadRequest("REFUND_AMOUNT_EXCEEDED", "refund amount exceeds recharge") return nil, nil, infraerrors.BadRequest("REFUND_AMOUNT_EXCEEDED", "refund amount exceeds recharge")
} }
// Full refund: use actual pay_amount for gateway (includes fees) ga := calculateGatewayRefundAmount(o.Amount, o.PayAmount, amt)
ga := amt
if math.Abs(amt-o.Amount) <= amountToleranceCNY {
ga = o.PayAmount
}
rr := strings.TrimSpace(reason) rr := strings.TrimSpace(reason)
if rr == "" && o.RefundRequestReason != nil { if rr == "" && o.RefundRequestReason != nil {
rr = *o.RefundRequestReason rr = *o.RefundRequestReason
...@@ -150,11 +181,17 @@ func (s *PaymentService) ExecuteRefund(ctx context.Context, p *RefundPlan) (*Ref ...@@ -150,11 +181,17 @@ func (s *PaymentService) ExecuteRefund(ctx context.Context, p *RefundPlan) (*Ref
if !s.hasAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED") { if !s.hasAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED") {
_, err := s.subscriptionSvc.ExtendSubscription(ctx, p.SubscriptionID, -p.SubDaysToDeduct) _, err := s.subscriptionSvc.ExtendSubscription(ctx, p.SubscriptionID, -p.SubDaysToDeduct)
if err != nil { if err != nil {
// If deducting would expire the subscription, revoke it entirely if errors.Is(err, ErrAdjustWouldExpire) {
slog.Info("subscription deduction would expire, revoking", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct) // Deduction would expire the subscription — revoke it entirely
if revokeErr := s.subscriptionSvc.RevokeSubscription(ctx, p.SubscriptionID); revokeErr != nil { slog.Info("subscription deduction would expire, revoking", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct)
if revokeErr := s.subscriptionSvc.RevokeSubscription(ctx, p.SubscriptionID); revokeErr != nil {
s.restoreStatus(ctx, p)
return nil, fmt.Errorf("revoke subscription: %w", revokeErr)
}
} else {
// Other errors (DB failure, not found) — abort refund
s.restoreStatus(ctx, p) s.restoreStatus(ctx, p)
return nil, fmt.Errorf("revoke subscription: %w", revokeErr) return nil, fmt.Errorf("deduct subscription days: %w", err)
} }
} }
} else { } else {
......
...@@ -288,8 +288,6 @@ func psComputeValidityDays(days int, unit string) int { ...@@ -288,8 +288,6 @@ func psComputeValidityDays(days int, unit string) int {
} }
} }
func (s *PaymentService) getFeeRate(_ string) float64 { return 0 }
func psStartOfDayUTC(t time.Time) time.Time { func psStartOfDayUTC(t time.Time) time.Time {
y, m, d := t.UTC().Date() y, m, d := t.UTC().Date()
return time.Date(y, m, d, 0, 0, 0, 0, time.UTC) return time.Date(y, m, d, 0, 0, 0, 0, time.UTC)
......
...@@ -102,6 +102,8 @@ func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *t ...@@ -102,6 +102,8 @@ func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *t
}) })
} }
// TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError
// OpenAI OAuth 401 缓存失效出错时仍走 temp_unschedulable
func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) { func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) {
repo := &rateLimitAccountRepoStub{} repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{err: errors.New("boom")} invalidator := &tokenCacheInvalidatorRecorder{err: errors.New("boom")}
...@@ -109,7 +111,7 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin ...@@ -109,7 +111,7 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin
service.SetTokenCacheInvalidator(invalidator) service.SetTokenCacheInvalidator(invalidator)
account := &Account{ account := &Account{
ID: 101, ID: 101,
Platform: PlatformGemini, Platform: PlatformOpenAI,
Type: AccountTypeOAuth, Type: AccountTypeOAuth,
} }
......
...@@ -99,13 +99,19 @@ type DefaultSubscriptionGroupReader interface { ...@@ -99,13 +99,19 @@ type DefaultSubscriptionGroupReader interface {
GetByID(ctx context.Context, id int64) (*Group, error) GetByID(ctx context.Context, id int64) (*Group, error)
} }
// WebSearchManagerBuilder creates a websearch.Manager from config (injected by infra layer).
// proxyURLs maps proxy ID to resolved URL for provider-level proxy support.
type WebSearchManagerBuilder func(cfg *WebSearchEmulationConfig, proxyURLs map[int64]string)
// SettingService 系统设置服务 // SettingService 系统设置服务
type SettingService struct { type SettingService struct {
settingRepo SettingRepository settingRepo SettingRepository
defaultSubGroupReader DefaultSubscriptionGroupReader defaultSubGroupReader DefaultSubscriptionGroupReader
cfg *config.Config proxyRepo ProxyRepository // for resolving websearch provider proxy URLs
onUpdate func() // Callback when settings are updated (for cache invalidation) cfg *config.Config
version string // Application version onUpdate func() // Callback when settings are updated (for cache invalidation)
version string // Application version
webSearchManagerBuilder WebSearchManagerBuilder
} }
// NewSettingService 创建系统设置服务实例 // NewSettingService 创建系统设置服务实例
...@@ -121,6 +127,11 @@ func (s *SettingService) SetDefaultSubscriptionGroupReader(reader DefaultSubscri ...@@ -121,6 +127,11 @@ func (s *SettingService) SetDefaultSubscriptionGroupReader(reader DefaultSubscri
s.defaultSubGroupReader = reader s.defaultSubGroupReader = reader
} }
// SetProxyRepository injects a proxy repo for resolving websearch provider proxy URLs.
func (s *SettingService) SetProxyRepository(repo ProxyRepository) {
s.proxyRepo = repo
}
// GetAllSettings 获取所有系统设置 // GetAllSettings 获取所有系统设置
func (s *SettingService) GetAllSettings(ctx context.Context) (*SystemSettings, error) { func (s *SettingService) GetAllSettings(ctx context.Context) (*SystemSettings, error) {
settings, err := s.settingRepo.GetAll(ctx) settings, err := s.settingRepo.GetAll(ctx)
...@@ -168,9 +179,13 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ...@@ -168,9 +179,13 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyCustomEndpoints, SettingKeyCustomEndpoints,
SettingKeyLinuxDoConnectEnabled, SettingKeyLinuxDoConnectEnabled,
SettingKeyBackendModeEnabled, SettingKeyBackendModeEnabled,
SettingPaymentEnabled,
SettingKeyOIDCConnectEnabled, SettingKeyOIDCConnectEnabled,
SettingKeyOIDCConnectProviderName, SettingKeyOIDCConnectProviderName,
SettingPaymentEnabled, SettingKeyBalanceLowNotifyEnabled,
SettingKeyBalanceLowNotifyThreshold,
SettingKeyBalanceLowNotifyRechargeURL,
SettingKeyAccountQuotaNotifyEnabled,
} }
settings, err := s.settingRepo.GetMultiple(ctx, keys) settings, err := s.settingRepo.GetMultiple(ctx, keys)
...@@ -209,6 +224,11 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ...@@ -209,6 +224,11 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
settings[SettingKeyTablePageSizeOptions], settings[SettingKeyTablePageSizeOptions],
) )
var balanceLowNotifyThreshold float64
if v, err := strconv.ParseFloat(settings[SettingKeyBalanceLowNotifyThreshold], 64); err == nil && v >= 0 {
balanceLowNotifyThreshold = v
}
return &PublicSettings{ return &PublicSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: emailVerifyEnabled, EmailVerifyEnabled: emailVerifyEnabled,
...@@ -235,9 +255,13 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ...@@ -235,9 +255,13 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
CustomEndpoints: settings[SettingKeyCustomEndpoints], CustomEndpoints: settings[SettingKeyCustomEndpoints],
LinuxDoOAuthEnabled: linuxDoEnabled, LinuxDoOAuthEnabled: linuxDoEnabled,
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
PaymentEnabled: settings[SettingPaymentEnabled] == "true",
OIDCOAuthEnabled: oidcEnabled, OIDCOAuthEnabled: oidcEnabled,
OIDCOAuthProviderName: oidcProviderName, OIDCOAuthProviderName: oidcProviderName,
PaymentEnabled: settings[SettingPaymentEnabled] == "true", BalanceLowNotifyEnabled: settings[SettingKeyBalanceLowNotifyEnabled] == "true",
AccountQuotaNotifyEnabled: settings[SettingKeyAccountQuotaNotifyEnabled] == "true",
BalanceLowNotifyThreshold: balanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: settings[SettingKeyBalanceLowNotifyRechargeURL],
}, nil }, nil
} }
...@@ -287,10 +311,14 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any ...@@ -287,10 +311,14 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomEndpoints json.RawMessage `json:"custom_endpoints"` CustomEndpoints json.RawMessage `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"` BackendModeEnabled bool `json:"backend_mode_enabled"`
PaymentEnabled bool `json:"payment_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"` OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"` OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
PaymentEnabled bool `json:"payment_enabled"`
Version string `json:"version,omitempty"` Version string `json:"version,omitempty"`
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
}{ }{
RegistrationEnabled: settings.RegistrationEnabled, RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled,
...@@ -317,10 +345,14 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any ...@@ -317,10 +345,14 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints), CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
BackendModeEnabled: settings.BackendModeEnabled, BackendModeEnabled: settings.BackendModeEnabled,
PaymentEnabled: settings.PaymentEnabled,
OIDCOAuthEnabled: settings.OIDCOAuthEnabled, OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
OIDCOAuthProviderName: settings.OIDCOAuthProviderName, OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
PaymentEnabled: settings.PaymentEnabled,
Version: s.version, Version: s.version,
BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
}, nil }, nil
} }
...@@ -595,6 +627,13 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet ...@@ -595,6 +627,13 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyEnableMetadataPassthrough] = strconv.FormatBool(settings.EnableMetadataPassthrough) updates[SettingKeyEnableMetadataPassthrough] = strconv.FormatBool(settings.EnableMetadataPassthrough)
updates[SettingKeyEnableCCHSigning] = strconv.FormatBool(settings.EnableCCHSigning) updates[SettingKeyEnableCCHSigning] = strconv.FormatBool(settings.EnableCCHSigning)
// Balance low notification
updates[SettingKeyBalanceLowNotifyEnabled] = strconv.FormatBool(settings.BalanceLowNotifyEnabled)
updates[SettingKeyBalanceLowNotifyThreshold] = strconv.FormatFloat(settings.BalanceLowNotifyThreshold, 'f', 8, 64)
updates[SettingKeyBalanceLowNotifyRechargeURL] = settings.BalanceLowNotifyRechargeURL
updates[SettingKeyAccountQuotaNotifyEnabled] = strconv.FormatBool(settings.AccountQuotaNotifyEnabled)
updates[SettingKeyAccountQuotaNotifyEmails] = MarshalNotifyEmails(settings.AccountQuotaNotifyEmails)
err = s.settingRepo.SetMultiple(ctx, updates) err = s.settingRepo.SetMultiple(ctx, updates)
if err == nil { if err == nil {
// 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口 // 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口
...@@ -1217,6 +1256,30 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin ...@@ -1217,6 +1256,30 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.EnableMetadataPassthrough = settings[SettingKeyEnableMetadataPassthrough] == "true" result.EnableMetadataPassthrough = settings[SettingKeyEnableMetadataPassthrough] == "true"
result.EnableCCHSigning = settings[SettingKeyEnableCCHSigning] == "true" result.EnableCCHSigning = settings[SettingKeyEnableCCHSigning] == "true"
// Web search emulation: quick enabled check from the JSON config
if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" {
var wsCfg WebSearchEmulationConfig
if err := json.Unmarshal([]byte(raw), &wsCfg); err == nil {
result.WebSearchEmulationEnabled = wsCfg.Enabled && len(wsCfg.Providers) > 0
}
}
// Balance low notification
result.BalanceLowNotifyEnabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true"
if v, err := strconv.ParseFloat(settings[SettingKeyBalanceLowNotifyThreshold], 64); err == nil && v >= 0 {
result.BalanceLowNotifyThreshold = v
}
result.BalanceLowNotifyRechargeURL = settings[SettingKeyBalanceLowNotifyRechargeURL]
// Account quota notification
result.AccountQuotaNotifyEnabled = settings[SettingKeyAccountQuotaNotifyEnabled] == "true"
if raw := strings.TrimSpace(settings[SettingKeyAccountQuotaNotifyEmails]); raw != "" {
result.AccountQuotaNotifyEmails = ParseNotifyEmails(raw)
}
if result.AccountQuotaNotifyEmails == nil {
result.AccountQuotaNotifyEmails = []NotifyEmailEntry{}
}
return result return result
} }
......
...@@ -66,7 +66,7 @@ func TestSettingService_GetPublicSettings_ExposesRegistrationEmailSuffixWhitelis ...@@ -66,7 +66,7 @@ func TestSettingService_GetPublicSettings_ExposesRegistrationEmailSuffixWhitelis
func TestSettingService_GetPublicSettings_ExposesTablePreferences(t *testing.T) { func TestSettingService_GetPublicSettings_ExposesTablePreferences(t *testing.T) {
repo := &settingPublicRepoStub{ repo := &settingPublicRepoStub{
values: map[string]string{ values: map[string]string{
SettingKeyTableDefaultPageSize: "50", SettingKeyTableDefaultPageSize: "50",
SettingKeyTablePageSizeOptions: "[20,50,100]", SettingKeyTablePageSizeOptions: "[20,50,100]",
}, },
} }
......
...@@ -208,7 +208,7 @@ func TestSettingService_UpdateSettings_TablePreferences(t *testing.T) { ...@@ -208,7 +208,7 @@ func TestSettingService_UpdateSettings_TablePreferences(t *testing.T) {
svc := NewSettingService(repo, &config.Config{}) svc := NewSettingService(repo, &config.Config{})
err := svc.UpdateSettings(context.Background(), &SystemSettings{ err := svc.UpdateSettings(context.Background(), &SystemSettings{
TableDefaultPageSize: 50, TableDefaultPageSize: 50,
TablePageSizeOptions: []int{20, 50, 100}, TablePageSizeOptions: []int{20, 50, 100},
}) })
require.NoError(t, err) require.NoError(t, err)
...@@ -216,7 +216,7 @@ func TestSettingService_UpdateSettings_TablePreferences(t *testing.T) { ...@@ -216,7 +216,7 @@ func TestSettingService_UpdateSettings_TablePreferences(t *testing.T) {
require.Equal(t, "[20,50,100]", repo.updates[SettingKeyTablePageSizeOptions]) require.Equal(t, "[20,50,100]", repo.updates[SettingKeyTablePageSizeOptions])
err = svc.UpdateSettings(context.Background(), &SystemSettings{ err = svc.UpdateSettings(context.Background(), &SystemSettings{
TableDefaultPageSize: 1000, TableDefaultPageSize: 1000,
TablePageSizeOptions: []int{20, 100}, TablePageSizeOptions: []int{20, 100},
}) })
require.NoError(t, err) require.NoError(t, err)
......
...@@ -106,6 +106,18 @@ type SystemSettings struct { ...@@ -106,6 +106,18 @@ type SystemSettings struct {
EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true) EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true)
EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false) EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false)
EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false) EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false)
// Web Search Emulation
WebSearchEmulationEnabled bool // 是否启用 web search 模拟
// Balance low notification
BalanceLowNotifyEnabled bool
BalanceLowNotifyThreshold float64
BalanceLowNotifyRechargeURL string
// Account quota notification
AccountQuotaNotifyEnabled bool
AccountQuotaNotifyEmails []NotifyEmailEntry
} }
type DefaultSubscriptionSetting struct { type DefaultSubscriptionSetting struct {
...@@ -141,10 +153,15 @@ type PublicSettings struct { ...@@ -141,10 +153,15 @@ type PublicSettings struct {
LinuxDoOAuthEnabled bool LinuxDoOAuthEnabled bool
BackendModeEnabled bool BackendModeEnabled bool
PaymentEnabled bool
OIDCOAuthEnabled bool OIDCOAuthEnabled bool
OIDCOAuthProviderName string OIDCOAuthProviderName string
PaymentEnabled bool
Version string Version string
BalanceLowNotifyEnabled bool
AccountQuotaNotifyEnabled bool
BalanceLowNotifyThreshold float64
BalanceLowNotifyRechargeURL string
} }
// StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制) // StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制)
......
...@@ -100,9 +100,22 @@ func valueOrZero(v *int64) int64 { ...@@ -100,9 +100,22 @@ func valueOrZero(v *int64) int64 {
return *v return *v
} }
// AccountQuotaState holds the post-increment quota state returned by the DB transaction.
// All values are post-update (i.e., already include the increment).
type AccountQuotaState struct {
TotalUsed float64
TotalLimit float64
DailyUsed float64
DailyLimit float64
WeeklyUsed float64
WeeklyLimit float64
}
type UsageBillingApplyResult struct { type UsageBillingApplyResult struct {
Applied bool Applied bool
APIKeyQuotaExhausted bool APIKeyQuotaExhausted bool
NewBalance *float64 // post-deduction balance (nil = no balance deduction)
QuotaState *AccountQuotaState // post-increment quota state (nil = no quota increment)
} }
type UsageBillingRepository interface { type UsageBillingRepository interface {
......
...@@ -146,6 +146,8 @@ type UsageLog struct { ...@@ -146,6 +146,8 @@ type UsageLog struct {
RateMultiplier float64 RateMultiplier float64
// AccountRateMultiplier 账号计费倍率快照(nil 表示历史数据,按 1.0 处理) // AccountRateMultiplier 账号计费倍率快照(nil 表示历史数据,按 1.0 处理)
AccountRateMultiplier *float64 AccountRateMultiplier *float64
// AccountStatsCost 账号统计定价预计算费用(nil = 使用默认公式 total_cost × account_rate_multiplier)
AccountStatsCost *float64
BillingType int8 BillingType int8
RequestType RequestType RequestType RequestType
......
...@@ -30,6 +30,13 @@ type User struct { ...@@ -30,6 +30,13 @@ type User struct {
TotpEnabled bool // 是否启用 TOTP TotpEnabled bool // 是否启用 TOTP
TotpEnabledAt *time.Time // TOTP 启用时间 TotpEnabledAt *time.Time // TOTP 启用时间
// 余额不足通知
BalanceNotifyEnabled bool
BalanceNotifyThresholdType string // "fixed" (default) | "percentage"
BalanceNotifyThreshold *float64
BalanceNotifyExtraEmails []NotifyEmailEntry
TotalRecharged float64
APIKeys []APIKey APIKeys []APIKey
Subscriptions []UserSubscription Subscriptions []UserSubscription
} }
......
...@@ -2,8 +2,10 @@ package service ...@@ -2,8 +2,10 @@ package service
import ( import (
"context" "context"
"crypto/subtle"
"fmt" "fmt"
"log" "log/slog"
"strings"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
...@@ -11,9 +13,18 @@ import ( ...@@ -11,9 +13,18 @@ import (
) )
var ( var (
ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found") ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found")
ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect") ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect")
ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions") ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
ErrNotifyCodeUserRateLimit = infraerrors.TooManyRequests("NOTIFY_CODE_USER_RATE_LIMIT", "too many verification codes requested, please try again later")
)
const (
maxNotifyEmails = 3 // Maximum number of notification emails per user
// User-level rate limiting for notify email verification codes
notifyCodeUserRateLimit = 5
notifyCodeUserRateWindow = 10 * time.Minute
) )
// UserListFilters contains all filter options for listing users // UserListFilters contains all filter options for listing users
...@@ -58,9 +69,11 @@ type UserRepository interface { ...@@ -58,9 +69,11 @@ type UserRepository interface {
// UpdateProfileRequest 更新用户资料请求 // UpdateProfileRequest 更新用户资料请求
type UpdateProfileRequest struct { type UpdateProfileRequest struct {
Email *string `json:"email"` Email *string `json:"email"`
Username *string `json:"username"` Username *string `json:"username"`
Concurrency *int `json:"concurrency"` Concurrency *int `json:"concurrency"`
BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
} }
// ChangePasswordRequest 修改密码请求 // ChangePasswordRequest 修改密码请求
...@@ -72,14 +85,16 @@ type ChangePasswordRequest struct { ...@@ -72,14 +85,16 @@ type ChangePasswordRequest struct {
// UserService 用户服务 // UserService 用户服务
type UserService struct { type UserService struct {
userRepo UserRepository userRepo UserRepository
settingRepo SettingRepository
authCacheInvalidator APIKeyAuthCacheInvalidator authCacheInvalidator APIKeyAuthCacheInvalidator
billingCache BillingCache billingCache BillingCache
} }
// NewUserService 创建用户服务实例 // NewUserService 创建用户服务实例
func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCache BillingCache) *UserService { func NewUserService(userRepo UserRepository, settingRepo SettingRepository, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCache BillingCache) *UserService {
return &UserService{ return &UserService{
userRepo: userRepo, userRepo: userRepo,
settingRepo: settingRepo,
authCacheInvalidator: authCacheInvalidator, authCacheInvalidator: authCacheInvalidator,
billingCache: billingCache, billingCache: billingCache,
} }
...@@ -132,6 +147,17 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat ...@@ -132,6 +147,17 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
user.Concurrency = *req.Concurrency user.Concurrency = *req.Concurrency
} }
if req.BalanceNotifyEnabled != nil {
user.BalanceNotifyEnabled = *req.BalanceNotifyEnabled
}
if req.BalanceNotifyThreshold != nil {
if *req.BalanceNotifyThreshold <= 0 {
user.BalanceNotifyThreshold = nil // clear to system default
} else {
user.BalanceNotifyThreshold = req.BalanceNotifyThreshold
}
}
if err := s.userRepo.Update(ctx, user); err != nil { if err := s.userRepo.Update(ctx, user); err != nil {
return nil, fmt.Errorf("update user: %w", err) return nil, fmt.Errorf("update user: %w", err)
} }
...@@ -198,10 +224,15 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl ...@@ -198,10 +224,15 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl
} }
if s.billingCache != nil { if s.billingCache != nil {
go func() { go func() {
defer func() {
if r := recover(); r != nil {
slog.Error("panic in balance cache invalidation", "user_id", userID, "recover", r)
}
}()
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
if err := s.billingCache.InvalidateUserBalance(cacheCtx, userID); err != nil { if err := s.billingCache.InvalidateUserBalance(cacheCtx, userID); err != nil {
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err) slog.Error("invalidate user balance cache failed", "user_id", userID, "error", err)
} }
}() }()
} }
...@@ -248,3 +279,229 @@ func (s *UserService) Delete(ctx context.Context, userID int64) error { ...@@ -248,3 +279,229 @@ func (s *UserService) Delete(ctx context.Context, userID int64) error {
} }
return nil return nil
} }
// SendNotifyEmailCode sends a verification code to the extra notification email.
func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, email string, emailService *EmailService, cache EmailCache) error {
if err := checkNotifyCodeRateLimit(ctx, cache, userID, email); err != nil {
return err
}
code, err := emailService.GenerateVerifyCode()
if err != nil {
return fmt.Errorf("generate code: %w", err)
}
// Send email first — if SMTP fails, don't write cache or increment counters,
// so the user is not locked out by cooldown/rate-limit for a code they never received.
if err := s.sendNotifyVerifyEmail(ctx, emailService, email, code); err != nil {
return err
}
if err := saveNotifyVerifyCode(ctx, cache, email, code); err != nil {
return err
}
// Increment user-level counter after successful save
if _, err := cache.IncrNotifyCodeUserRate(ctx, userID, notifyCodeUserRateWindow); err != nil {
slog.Error("failed to increment notify code user rate", "user_id", userID, "error", err)
}
return nil
}
// checkNotifyCodeRateLimit checks both email cooldown and user-level rate limit.
func checkNotifyCodeRateLimit(ctx context.Context, cache EmailCache, userID int64, email string) error {
existing, err := cache.GetNotifyVerifyCode(ctx, email)
if err == nil && existing != nil {
if time.Since(existing.CreatedAt) < verifyCodeCooldown {
return ErrVerifyCodeTooFrequent
}
}
count, err := cache.GetNotifyCodeUserRate(ctx, userID)
if err == nil && count >= notifyCodeUserRateLimit {
return ErrNotifyCodeUserRateLimit
}
return nil
}
// saveNotifyVerifyCode saves the verification code to cache.
func saveNotifyVerifyCode(ctx context.Context, cache EmailCache, email, code string) error {
data := &VerificationCodeData{
Code: code,
Attempts: 0,
CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(verifyCodeTTL),
}
if err := cache.SetNotifyVerifyCode(ctx, email, data, verifyCodeTTL); err != nil {
return fmt.Errorf("save verify code: %w", err)
}
return nil
}
// sendNotifyVerifyEmail builds and sends the verification email.
func (s *UserService) sendNotifyVerifyEmail(ctx context.Context, emailService *EmailService, email, code string) error {
siteName := "Sub2API"
if s.settingRepo != nil {
if name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName); err == nil && name != "" {
siteName = name
}
}
subject := fmt.Sprintf("[%s] 通知邮箱验证码 / Notification Email Verification", siteName)
body := buildNotifyVerifyEmailBody(code, siteName)
return emailService.SendEmail(ctx, email, subject, body)
}
// VerifyAndAddNotifyEmail verifies the code and adds the email to user's extra emails.
func (s *UserService) VerifyAndAddNotifyEmail(ctx context.Context, userID int64, email, code string, cache EmailCache) error {
if err := verifyNotifyCode(ctx, cache, email, code); err != nil {
return err
}
_ = cache.DeleteNotifyVerifyCode(ctx, email)
return s.addOrVerifyNotifyEmail(ctx, userID, email)
}
// verifyNotifyCode validates the verification code against the cached data.
func verifyNotifyCode(ctx context.Context, cache EmailCache, email, code string) error {
data, err := cache.GetNotifyVerifyCode(ctx, email)
if err != nil || data == nil {
return ErrInvalidVerifyCode
}
if data.Attempts >= maxVerifyCodeAttempts {
return ErrVerifyCodeMaxAttempts
}
if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 {
data.Attempts++
remaining := time.Until(data.ExpiresAt)
if remaining <= 0 {
return ErrInvalidVerifyCode
}
if err := cache.SetNotifyVerifyCode(ctx, email, data, remaining); err != nil {
slog.Error("failed to update notify verify code attempts", "email", email, "error", err)
}
if data.Attempts >= maxVerifyCodeAttempts {
return ErrVerifyCodeMaxAttempts
}
return ErrInvalidVerifyCode
}
return nil
}
// addOrVerifyNotifyEmail adds the email to user's extra notification emails or marks it as verified.
// Note: concurrent calls for the same user could race on the read-modify-write of
// BalanceNotifyExtraEmails. The window is small (requires two verify flows completing
// simultaneously), and the worst case is a duplicate entry which is harmless.
func (s *UserService) addOrVerifyNotifyEmail(ctx context.Context, userID int64, email string) error {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return err
}
for i, e := range user.BalanceNotifyExtraEmails {
if strings.EqualFold(e.Email, email) {
if !e.Verified {
user.BalanceNotifyExtraEmails[i].Verified = true
return s.userRepo.Update(ctx, user)
}
return nil // Already verified
}
}
if len(user.BalanceNotifyExtraEmails) >= maxNotifyEmails {
return infraerrors.BadRequest("TOO_MANY_NOTIFY_EMAILS", fmt.Sprintf("maximum %d notification emails allowed", maxNotifyEmails))
}
user.BalanceNotifyExtraEmails = append(user.BalanceNotifyExtraEmails, NotifyEmailEntry{
Email: email,
Disabled: false,
Verified: true,
})
return s.userRepo.Update(ctx, user)
}
// RemoveNotifyEmail removes an email from user's extra notification emails.
func (s *UserService) RemoveNotifyEmail(ctx context.Context, userID int64, email string) error {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return err
}
filtered := make([]NotifyEmailEntry, 0, len(user.BalanceNotifyExtraEmails))
found := false
for _, e := range user.BalanceNotifyExtraEmails {
if strings.EqualFold(e.Email, email) {
found = true
} else {
filtered = append(filtered, e)
}
}
if !found {
return infraerrors.BadRequest("EMAIL_NOT_FOUND", "notification email not found")
}
user.BalanceNotifyExtraEmails = filtered
return s.userRepo.Update(ctx, user)
}
// ToggleNotifyEmail toggles the disabled state of a notification email entry.
func (s *UserService) ToggleNotifyEmail(ctx context.Context, userID int64, email string, disabled bool) error {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return err
}
found := false
for i, e := range user.BalanceNotifyExtraEmails {
if strings.EqualFold(e.Email, email) {
user.BalanceNotifyExtraEmails[i].Disabled = disabled
found = true
break
}
}
if !found {
return infraerrors.BadRequest("EMAIL_NOT_FOUND", "notification email not found")
}
return s.userRepo.Update(ctx, user)
}
// notifyVerifyEmailTemplate is the HTML template for notify email verification.
// Format args: siteName, code.
const notifyVerifyEmailTemplate = `<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 600px; margin: 0 auto; background-color: #ffffff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
.header { background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 30px; text-align: center; }
.header h1 { margin: 0; font-size: 24px; }
.content { padding: 40px 30px; text-align: center; }
.code { font-size: 36px; font-weight: bold; letter-spacing: 8px; color: #333; background-color: #f8f9fa; padding: 20px 30px; border-radius: 8px; display: inline-block; margin: 20px 0; font-family: monospace; }
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; }
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>%s</h1>
</div>
<div class="content">
<p style="font-size: 18px; color: #333;">通知邮箱验证码 / Notification Email Verification</p>
<div class="code">%s</div>
<div class="info">
<p>您正在添加额外的通知邮箱,请输入此验证码完成验证。</p>
<p>You are adding an extra notification email. Please enter this code to verify.</p>
<p>此验证码将在 <strong>15 分钟</strong>后失效。</p>
<p>This code will expire in <strong>15 minutes</strong>.</p>
<p>如果您没有请求此验证码,请忽略此邮件。</p>
<p>If you did not request this code, please ignore this email.</p>
</div>
</div>
<div class="footer">
<p>此邮件由系统自动发送,请勿回复。/ This is an automated message, please do not reply.</p>
</div>
</div>
</body>
</html>`
// buildNotifyVerifyEmailBody builds the HTML email body for notify email verification.
func buildNotifyVerifyEmailBody(code, siteName string) string {
return fmt.Sprintf(notifyVerifyEmailTemplate, siteName, code)
}
...@@ -46,12 +46,12 @@ func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int ...@@ -46,12 +46,12 @@ func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int
return 0, nil return 0, nil
} }
func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil } func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil }
func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil }
func (m *mockUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { func (m *mockUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
return nil return nil
} }
func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil }
func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil }
// --- mock: APIKeyAuthCacheInvalidator --- // --- mock: APIKeyAuthCacheInvalidator ---
...@@ -117,7 +117,7 @@ func (m *mockBillingCache) InvalidateAPIKeyRateLimit(context.Context, int64) err ...@@ -117,7 +117,7 @@ func (m *mockBillingCache) InvalidateAPIKeyRateLimit(context.Context, int64) err
func TestUpdateBalance_Success(t *testing.T) { func TestUpdateBalance_Success(t *testing.T) {
repo := &mockUserRepo{} repo := &mockUserRepo{}
cache := &mockBillingCache{} cache := &mockBillingCache{}
svc := NewUserService(repo, nil, cache) svc := NewUserService(repo, nil, nil, cache)
err := svc.UpdateBalance(context.Background(), 42, 100.0) err := svc.UpdateBalance(context.Background(), 42, 100.0)
require.NoError(t, err) require.NoError(t, err)
...@@ -134,7 +134,7 @@ func TestUpdateBalance_Success(t *testing.T) { ...@@ -134,7 +134,7 @@ func TestUpdateBalance_Success(t *testing.T) {
func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) { func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) {
repo := &mockUserRepo{} repo := &mockUserRepo{}
svc := NewUserService(repo, nil, nil) // billingCache = nil svc := NewUserService(repo, nil, nil, nil) // billingCache = nil
err := svc.UpdateBalance(context.Background(), 1, 50.0) err := svc.UpdateBalance(context.Background(), 1, 50.0)
require.NoError(t, err, "billingCache 为 nil 时不应 panic") require.NoError(t, err, "billingCache 为 nil 时不应 panic")
...@@ -143,7 +143,7 @@ func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) { ...@@ -143,7 +143,7 @@ func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) {
func TestUpdateBalance_CacheFailure_DoesNotAffectReturn(t *testing.T) { func TestUpdateBalance_CacheFailure_DoesNotAffectReturn(t *testing.T) {
repo := &mockUserRepo{} repo := &mockUserRepo{}
cache := &mockBillingCache{invalidateErr: errors.New("redis connection refused")} cache := &mockBillingCache{invalidateErr: errors.New("redis connection refused")}
svc := NewUserService(repo, nil, cache) svc := NewUserService(repo, nil, nil, cache)
err := svc.UpdateBalance(context.Background(), 99, 200.0) err := svc.UpdateBalance(context.Background(), 99, 200.0)
require.NoError(t, err, "缓存失效失败不应影响主流程返回值") require.NoError(t, err, "缓存失效失败不应影响主流程返回值")
...@@ -157,7 +157,7 @@ func TestUpdateBalance_CacheFailure_DoesNotAffectReturn(t *testing.T) { ...@@ -157,7 +157,7 @@ func TestUpdateBalance_CacheFailure_DoesNotAffectReturn(t *testing.T) {
func TestUpdateBalance_RepoError_ReturnsError(t *testing.T) { func TestUpdateBalance_RepoError_ReturnsError(t *testing.T) {
repo := &mockUserRepo{updateBalanceErr: errors.New("database error")} repo := &mockUserRepo{updateBalanceErr: errors.New("database error")}
cache := &mockBillingCache{} cache := &mockBillingCache{}
svc := NewUserService(repo, nil, cache) svc := NewUserService(repo, nil, nil, cache)
err := svc.UpdateBalance(context.Background(), 1, 100.0) err := svc.UpdateBalance(context.Background(), 1, 100.0)
require.Error(t, err, "repo 失败时应返回错误") require.Error(t, err, "repo 失败时应返回错误")
...@@ -173,7 +173,7 @@ func TestUpdateBalance_WithAuthCacheInvalidator(t *testing.T) { ...@@ -173,7 +173,7 @@ func TestUpdateBalance_WithAuthCacheInvalidator(t *testing.T) {
repo := &mockUserRepo{} repo := &mockUserRepo{}
auth := &mockAuthCacheInvalidator{} auth := &mockAuthCacheInvalidator{}
cache := &mockBillingCache{} cache := &mockBillingCache{}
svc := NewUserService(repo, auth, cache) svc := NewUserService(repo, nil, auth, cache)
err := svc.UpdateBalance(context.Background(), 77, 300.0) err := svc.UpdateBalance(context.Background(), 77, 300.0)
require.NoError(t, err) require.NoError(t, err)
...@@ -194,7 +194,7 @@ func TestNewUserService_FieldsAssignment(t *testing.T) { ...@@ -194,7 +194,7 @@ func TestNewUserService_FieldsAssignment(t *testing.T) {
auth := &mockAuthCacheInvalidator{} auth := &mockAuthCacheInvalidator{}
cache := &mockBillingCache{} cache := &mockBillingCache{}
svc := NewUserService(repo, auth, cache) svc := NewUserService(repo, nil, auth, cache)
require.NotNil(t, svc) require.NotNil(t, svc)
require.Equal(t, repo, svc.userRepo) require.Equal(t, repo, svc.userRepo)
require.Equal(t, auth, svc.authCacheInvalidator) require.Equal(t, auth, svc.authCacheInvalidator)
......
package service
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"sync/atomic"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
"golang.org/x/sync/singleflight"
)
// WebSearchEmulationConfig holds the global web search emulation configuration.
type WebSearchEmulationConfig struct {
Enabled bool `json:"enabled"`
Providers []WebSearchProviderConfig `json:"providers"`
}
// WebSearchProviderConfig describes a single search provider (Brave or Tavily).
type WebSearchProviderConfig struct {
Type string `json:"type"` // websearch.ProviderTypeBrave | Tavily
APIKey string `json:"api_key,omitempty"` // secret — omitted in API responses
APIKeyConfigured bool `json:"api_key_configured"` // read-only mask
QuotaLimit *int64 `json:"quota_limit"` // nil = unlimited, >0 = limited
SubscribedAt *int64 `json:"subscribed_at,omitempty"` // subscription start (unix seconds); quota resets monthly
QuotaUsed int64 `json:"quota_used,omitempty"` // read-only: current usage from Redis
ProxyID *int64 `json:"proxy_id"` // optional proxy association
ExpiresAt *int64 `json:"expires_at,omitempty"` // optional expiration timestamp
}
// --- Validation ---
const maxWebSearchProviders = 10
var validProviderTypes = map[string]bool{
websearch.ProviderTypeBrave: true,
websearch.ProviderTypeTavily: true,
}
func validateWebSearchConfig(cfg *WebSearchEmulationConfig) error {
if cfg == nil {
return nil
}
if len(cfg.Providers) > maxWebSearchProviders {
return fmt.Errorf("too many providers (max %d)", maxWebSearchProviders)
}
seen := make(map[string]bool, len(cfg.Providers))
for i, p := range cfg.Providers {
if !validProviderTypes[p.Type] {
return fmt.Errorf("provider[%d]: invalid type %q", i, p.Type)
}
if p.QuotaLimit != nil && *p.QuotaLimit < 0 {
return fmt.Errorf("provider[%d]: quota_limit must be > 0 or null", i)
}
if seen[p.Type] {
return fmt.Errorf("provider[%d]: duplicate type %q", i, p.Type)
}
seen[p.Type] = true
}
return nil
}
// --- In-process cache (same pattern as gateway forwarding settings) ---
const sfKeyWebSearchConfig = "web_search_emulation_config"
type cachedWebSearchEmulationConfig struct {
config *WebSearchEmulationConfig
expiresAt int64 // unix nano
}
var webSearchEmulationCache atomic.Value // *cachedWebSearchEmulationConfig
var webSearchEmulationSF singleflight.Group
const (
webSearchEmulationCacheTTL = 60 * time.Second
webSearchEmulationErrorTTL = 5 * time.Second
webSearchEmulationDBTimeout = 5 * time.Second
)
// GetWebSearchEmulationConfig returns the configuration with in-process cache + singleflight.
func (s *SettingService) GetWebSearchEmulationConfig(ctx context.Context) (*WebSearchEmulationConfig, error) {
if cached := webSearchEmulationCache.Load(); cached != nil {
if c, ok := cached.(*cachedWebSearchEmulationConfig); ok && time.Now().UnixNano() < c.expiresAt {
return c.config, nil
}
}
result, err, _ := webSearchEmulationSF.Do(sfKeyWebSearchConfig, func() (any, error) {
return s.loadWebSearchConfigFromDB()
})
if err != nil {
return &WebSearchEmulationConfig{}, err
}
if cfg, ok := result.(*WebSearchEmulationConfig); ok {
return cfg, nil
}
return &WebSearchEmulationConfig{}, nil
}
func (s *SettingService) loadWebSearchConfigFromDB() (*WebSearchEmulationConfig, error) {
dbCtx, cancel := context.WithTimeout(context.Background(), webSearchEmulationDBTimeout)
defer cancel()
raw, err := s.settingRepo.GetValue(dbCtx, SettingKeyWebSearchEmulationConfig)
if err != nil {
webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{
config: &WebSearchEmulationConfig{},
expiresAt: time.Now().Add(webSearchEmulationErrorTTL).UnixNano(),
})
return &WebSearchEmulationConfig{}, err
}
cfg := parseWebSearchConfigJSON(raw)
webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{
config: cfg,
expiresAt: time.Now().Add(webSearchEmulationCacheTTL).UnixNano(),
})
return cfg, nil
}
func parseWebSearchConfigJSON(raw string) *WebSearchEmulationConfig {
cfg := &WebSearchEmulationConfig{}
if raw == "" {
return cfg
}
if err := json.Unmarshal([]byte(raw), cfg); err != nil {
slog.Warn("websearch: failed to parse config JSON", "error", err)
return &WebSearchEmulationConfig{}
}
return cfg
}
// SaveWebSearchEmulationConfig validates and persists the configuration.
// Empty API keys in the input are preserved from the existing config.
func (s *SettingService) SaveWebSearchEmulationConfig(ctx context.Context, cfg *WebSearchEmulationConfig) error {
if err := validateWebSearchConfig(cfg); err != nil {
return infraerrors.BadRequest("INVALID_WEB_SEARCH_CONFIG", err.Error())
}
s.mergeExistingAPIKeys(ctx, cfg)
// After merge, validate all enabled providers have API keys
if cfg.Enabled {
for _, p := range cfg.Providers {
if p.APIKey == "" {
return infraerrors.BadRequest("MISSING_API_KEY",
fmt.Sprintf("provider %s has no API key configured", p.Type))
}
}
}
data, err := json.Marshal(cfg)
if err != nil {
return fmt.Errorf("websearch: marshal config: %w", err)
}
if err := s.settingRepo.Set(ctx, SettingKeyWebSearchEmulationConfig, string(data)); err != nil {
return fmt.Errorf("websearch: save config: %w", err)
}
// Invalidate: forget singleflight first, then store new value
webSearchEmulationSF.Forget(sfKeyWebSearchConfig)
webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{
config: cfg,
expiresAt: time.Now().Add(webSearchEmulationCacheTTL).UnixNano(),
})
// Hot-reload: rebuild the global Manager with new config
s.rebuildWebSearchManager(ctx)
return nil
}
// mergeExistingAPIKeys preserves API keys from the current config when incoming value is empty.
func (s *SettingService) mergeExistingAPIKeys(ctx context.Context, cfg *WebSearchEmulationConfig) {
existing, _ := s.getWebSearchEmulationConfigRaw(ctx)
if existing == nil || cfg == nil {
return
}
existingByType := make(map[string]string, len(existing.Providers))
for _, p := range existing.Providers {
if p.APIKey != "" {
existingByType[p.Type] = p.APIKey
}
}
for i := range cfg.Providers {
if cfg.Providers[i].APIKey == "" {
if key, ok := existingByType[cfg.Providers[i].Type]; ok {
cfg.Providers[i].APIKey = key
}
}
}
}
func (s *SettingService) getWebSearchEmulationConfigRaw(ctx context.Context) (*WebSearchEmulationConfig, error) {
raw, err := s.settingRepo.GetValue(ctx, SettingKeyWebSearchEmulationConfig)
if err != nil {
return nil, err
}
return parseWebSearchConfigJSON(raw), nil
}
// IsWebSearchEmulationEnabled is a quick check for whether the global switch is on.
func (s *SettingService) IsWebSearchEmulationEnabled(ctx context.Context) bool {
cfg, err := s.GetWebSearchEmulationConfig(ctx)
if err != nil {
return false
}
return cfg.Enabled && len(cfg.Providers) > 0
}
// SetWebSearchManagerBuilder injects a callback that creates and wires a websearch.Manager.
// The infra layer (main/wire) provides this builder, keeping redis out of the service layer.
// Triggers initial build.
func (s *SettingService) SetWebSearchManagerBuilder(ctx context.Context, builder WebSearchManagerBuilder) {
s.webSearchManagerBuilder = builder
s.rebuildWebSearchManager(ctx)
}
// rebuildWebSearchManager reads the current config, resolves proxy URLs, and invokes the builder.
func (s *SettingService) rebuildWebSearchManager(ctx context.Context) {
if s.webSearchManagerBuilder == nil {
return
}
cfg, err := s.GetWebSearchEmulationConfig(ctx)
if err != nil {
SetWebSearchManager(nil)
return
}
proxyURLs := s.resolveProviderProxyURLs(ctx, cfg)
s.webSearchManagerBuilder(cfg, proxyURLs)
}
// resolveProviderProxyURLs collects proxy IDs from providers and resolves them to URLs.
func (s *SettingService) resolveProviderProxyURLs(ctx context.Context, cfg *WebSearchEmulationConfig) map[int64]string {
if cfg == nil || s.proxyRepo == nil {
return nil
}
var ids []int64
for _, p := range cfg.Providers {
if p.ProxyID != nil && *p.ProxyID > 0 {
ids = append(ids, *p.ProxyID)
}
}
if len(ids) == 0 {
return nil
}
proxies, err := s.proxyRepo.ListByIDs(ctx, ids)
if err != nil {
slog.Warn("websearch: failed to resolve proxy URLs", "error", err)
return nil
}
result := make(map[int64]string, len(proxies))
for _, px := range proxies {
result[px.ID] = px.URL()
}
return result
}
// WebSearchTestResult holds the result of a search test.
type WebSearchTestResult struct {
Provider string `json:"provider"`
Results []websearch.SearchResult `json:"results"`
Query string `json:"query"`
}
// TestWebSearch executes a test search using the currently configured Manager.
// Uses Manager.TestSearch which bypasses quota tracking.
const testSearchTimeout = 15 * time.Second
func TestWebSearch(ctx context.Context, query string) (*WebSearchTestResult, error) {
mgr := getWebSearchManager()
if mgr == nil {
return nil, fmt.Errorf("web search: manager not initialized, save config first")
}
testCtx, cancel := context.WithTimeout(ctx, testSearchTimeout)
defer cancel()
resp, providerName, err := mgr.TestSearch(testCtx, websearch.SearchRequest{
Query: query,
MaxResults: webSearchDefaultMaxResults,
})
if err != nil {
return nil, err
}
return &WebSearchTestResult{
Provider: providerName,
Results: resp.Results,
Query: resp.Query,
}, nil
}
// PopulateWebSearchUsage returns a copy with quota usage populated from Redis (api_key kept as-is).
func PopulateWebSearchUsage(ctx context.Context, cfg *WebSearchEmulationConfig) *WebSearchEmulationConfig {
if cfg == nil {
return nil
}
out := *cfg
out.Providers = make([]WebSearchProviderConfig, len(cfg.Providers))
mgr := getWebSearchManager()
for i, p := range cfg.Providers {
out.Providers[i] = p
out.Providers[i].APIKeyConfigured = p.APIKey != ""
if mgr != nil {
used, _ := mgr.GetUsage(ctx, p.Type)
out.Providers[i].QuotaUsed = used
}
}
return &out
}
// ResetWebSearchUsage deletes the Redis quota key for the given provider type.
func ResetWebSearchUsage(ctx context.Context, providerType string) error {
mgr := getWebSearchManager()
if mgr == nil {
return fmt.Errorf("web search manager not initialized")
}
return mgr.ResetUsage(ctx, providerType)
}
// SanitizeWebSearchConfig returns a copy with api_key fields masked and quota usage populated.
func SanitizeWebSearchConfig(ctx context.Context, cfg *WebSearchEmulationConfig) *WebSearchEmulationConfig {
if cfg == nil {
return nil
}
out := *cfg
out.Providers = make([]WebSearchProviderConfig, len(cfg.Providers))
// Load usage from the global Manager (reads from Redis)
mgr := getWebSearchManager()
for i, p := range cfg.Providers {
out.Providers[i] = p
out.Providers[i].APIKeyConfigured = p.APIKey != ""
out.Providers[i].APIKey = "" // never return the secret
// Populate quota usage from Redis
if mgr != nil {
used, _ := mgr.GetUsage(ctx, p.Type)
out.Providers[i].QuotaUsed = used
}
}
return &out
}
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
"github.com/stretchr/testify/require"
)
// --- validateWebSearchConfig ---
func TestValidateWebSearchConfig_Nil(t *testing.T) {
require.NoError(t, validateWebSearchConfig(nil))
}
func TestValidateWebSearchConfig_Valid(t *testing.T) {
cfg := &WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{
{Type: "brave", QuotaLimit: int64Ptr(1000)},
{Type: "tavily", QuotaLimit: int64Ptr(500)},
},
}
require.NoError(t, validateWebSearchConfig(cfg))
}
func TestValidateWebSearchConfig_TooManyProviders(t *testing.T) {
cfg := &WebSearchEmulationConfig{Providers: make([]WebSearchProviderConfig, 11)}
for i := range cfg.Providers {
cfg.Providers[i] = WebSearchProviderConfig{Type: "brave"}
}
err := validateWebSearchConfig(cfg)
require.ErrorContains(t, err, "too many providers")
}
func TestValidateWebSearchConfig_InvalidType(t *testing.T) {
cfg := &WebSearchEmulationConfig{
Providers: []WebSearchProviderConfig{{Type: "bing"}},
}
require.ErrorContains(t, validateWebSearchConfig(cfg), "invalid type")
}
func TestValidateWebSearchConfig_NegativeQuotaLimit(t *testing.T) {
cfg := &WebSearchEmulationConfig{
Providers: []WebSearchProviderConfig{{Type: "brave", QuotaLimit: int64Ptr(-1)}},
}
require.ErrorContains(t, validateWebSearchConfig(cfg), "quota_limit must be > 0 or null")
}
func TestValidateWebSearchConfig_DuplicateType(t *testing.T) {
cfg := &WebSearchEmulationConfig{
Providers: []WebSearchProviderConfig{
{Type: "brave"},
{Type: "brave"},
},
}
require.ErrorContains(t, validateWebSearchConfig(cfg), "duplicate type")
}
func TestValidateWebSearchConfig_NilQuotaLimit(t *testing.T) {
cfg := &WebSearchEmulationConfig{
Providers: []WebSearchProviderConfig{{Type: "brave", QuotaLimit: nil}},
}
require.NoError(t, validateWebSearchConfig(cfg))
}
// --- parseWebSearchConfigJSON ---
func TestParseWebSearchConfigJSON_ValidJSON(t *testing.T) {
raw := `{"enabled":true,"providers":[{"type":"brave","api_key":"sk-xxx"}]}`
cfg := parseWebSearchConfigJSON(raw)
require.True(t, cfg.Enabled)
require.Len(t, cfg.Providers, 1)
require.Equal(t, "brave", cfg.Providers[0].Type)
}
func TestParseWebSearchConfigJSON_EmptyString(t *testing.T) {
cfg := parseWebSearchConfigJSON("")
require.False(t, cfg.Enabled)
require.Empty(t, cfg.Providers)
}
func TestParseWebSearchConfigJSON_InvalidJSON(t *testing.T) {
cfg := parseWebSearchConfigJSON("not{json")
require.False(t, cfg.Enabled)
require.Empty(t, cfg.Providers)
}
func TestParseWebSearchConfigJSON_BackwardCompatibility(t *testing.T) {
// Old config with priority and quota_refresh_interval should parse without error
raw := `{"enabled":true,"providers":[{"type":"brave","priority":1,"quota_refresh_interval":"monthly","quota_limit":1000}]}`
cfg := parseWebSearchConfigJSON(raw)
require.True(t, cfg.Enabled)
require.Len(t, cfg.Providers, 1)
require.Equal(t, int64(1000), *cfg.Providers[0].QuotaLimit)
}
// --- SanitizeWebSearchConfig ---
func TestSanitizeWebSearchConfig_MaskAPIKey(t *testing.T) {
cfg := &WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{
{Type: "brave", APIKey: "sk-secret-xxx"},
},
}
out := SanitizeWebSearchConfig(context.Background(), cfg)
require.Equal(t, "", out.Providers[0].APIKey)
require.True(t, out.Providers[0].APIKeyConfigured)
}
func TestSanitizeWebSearchConfig_NoAPIKey(t *testing.T) {
cfg := &WebSearchEmulationConfig{
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: ""}},
}
out := SanitizeWebSearchConfig(context.Background(), cfg)
require.Equal(t, "", out.Providers[0].APIKey)
require.False(t, out.Providers[0].APIKeyConfigured)
}
func TestSanitizeWebSearchConfig_Nil(t *testing.T) {
require.Nil(t, SanitizeWebSearchConfig(context.Background(), nil))
}
func TestSanitizeWebSearchConfig_PreservesOtherFields(t *testing.T) {
cfg := &WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{
{Type: "brave", APIKey: "secret", QuotaLimit: int64Ptr(1000)},
},
}
out := SanitizeWebSearchConfig(context.Background(), cfg)
require.True(t, out.Enabled)
require.Equal(t, int64(1000), *out.Providers[0].QuotaLimit)
}
func TestSanitizeWebSearchConfig_DoesNotMutateOriginal(t *testing.T) {
cfg := &WebSearchEmulationConfig{
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "secret"}},
}
_ = SanitizeWebSearchConfig(context.Background(), cfg)
require.Equal(t, "secret", cfg.Providers[0].APIKey)
}
// --- PopulateWebSearchUsage ---
func TestPopulateWebSearchUsage_NilInput(t *testing.T) {
require.Nil(t, PopulateWebSearchUsage(context.Background(), nil))
}
func TestPopulateWebSearchUsage_NoManager_QuotaUsedZero(t *testing.T) {
// Ensure no global manager is set
SetWebSearchManager(nil)
defer SetWebSearchManager(nil)
cfg := &WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{
{Type: "brave", APIKey: "sk-key", QuotaLimit: int64Ptr(1000)},
},
}
out := PopulateWebSearchUsage(context.Background(), cfg)
require.NotNil(t, out)
require.Len(t, out.Providers, 1)
require.Equal(t, int64(0), out.Providers[0].QuotaUsed)
}
func TestPopulateWebSearchUsage_APIKeyConfigured_True(t *testing.T) {
SetWebSearchManager(nil)
defer SetWebSearchManager(nil)
cfg := &WebSearchEmulationConfig{
Providers: []WebSearchProviderConfig{
{Type: "brave", APIKey: "sk-key"},
},
}
out := PopulateWebSearchUsage(context.Background(), cfg)
require.True(t, out.Providers[0].APIKeyConfigured)
}
func TestPopulateWebSearchUsage_APIKeyConfigured_False(t *testing.T) {
SetWebSearchManager(nil)
defer SetWebSearchManager(nil)
cfg := &WebSearchEmulationConfig{
Providers: []WebSearchProviderConfig{
{Type: "brave", APIKey: ""},
},
}
out := PopulateWebSearchUsage(context.Background(), cfg)
require.False(t, out.Providers[0].APIKeyConfigured)
}
func TestPopulateWebSearchUsage_NilQuotaLimit(t *testing.T) {
SetWebSearchManager(nil)
defer SetWebSearchManager(nil)
cfg := &WebSearchEmulationConfig{
Providers: []WebSearchProviderConfig{
{Type: "brave", APIKey: "sk-key", QuotaLimit: nil},
},
}
out := PopulateWebSearchUsage(context.Background(), cfg)
require.Nil(t, out.Providers[0].QuotaLimit)
}
func TestPopulateWebSearchUsage_NonNilQuotaLimit(t *testing.T) {
SetWebSearchManager(nil)
defer SetWebSearchManager(nil)
cfg := &WebSearchEmulationConfig{
Providers: []WebSearchProviderConfig{
{Type: "brave", APIKey: "sk-key", QuotaLimit: int64Ptr(500)},
},
}
out := PopulateWebSearchUsage(context.Background(), cfg)
require.NotNil(t, out.Providers[0].QuotaLimit)
require.Equal(t, int64(500), *out.Providers[0].QuotaLimit)
}
func TestPopulateWebSearchUsage_WithManager_NilRedis(t *testing.T) {
// Manager with nil Redis returns 0 usage without error
mgr := websearch.NewManager([]websearch.ProviderConfig{
{Type: "brave", APIKey: "k"},
}, nil)
SetWebSearchManager(mgr)
defer SetWebSearchManager(nil)
cfg := &WebSearchEmulationConfig{
Providers: []WebSearchProviderConfig{
{Type: "brave", APIKey: "sk-key", QuotaLimit: int64Ptr(1000)},
},
}
out := PopulateWebSearchUsage(context.Background(), cfg)
require.Equal(t, int64(0), out.Providers[0].QuotaUsed)
require.True(t, out.Providers[0].APIKeyConfigured)
}
func TestPopulateWebSearchUsage_DoesNotMutateOriginal(t *testing.T) {
SetWebSearchManager(nil)
defer SetWebSearchManager(nil)
cfg := &WebSearchEmulationConfig{
Providers: []WebSearchProviderConfig{
{Type: "brave", APIKey: "secret", QuotaLimit: int64Ptr(100)},
},
}
_ = PopulateWebSearchUsage(context.Background(), cfg)
// Original should be unchanged
require.Equal(t, "secret", cfg.Providers[0].APIKey)
require.Equal(t, int64(0), cfg.Providers[0].QuotaUsed)
}
// --- ResetWebSearchUsage ---
func TestResetWebSearchUsage_NilManager(t *testing.T) {
SetWebSearchManager(nil)
defer SetWebSearchManager(nil)
err := ResetWebSearchUsage(context.Background(), "brave")
require.Error(t, err)
require.Contains(t, err.Error(), "not initialized")
}
...@@ -373,10 +373,11 @@ func ProvideBackupService( ...@@ -373,10 +373,11 @@ func ProvideBackupService(
return svc return svc
} }
// ProvideSettingService wires SettingService with group reader for default subscription validation. // ProvideSettingService wires SettingService with group reader and proxy repo.
func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupRepository, cfg *config.Config) *SettingService { func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupRepository, proxyRepo ProxyRepository, cfg *config.Config) *SettingService {
svc := NewSettingService(settingRepo, cfg) svc := NewSettingService(settingRepo, cfg)
svc.SetDefaultSubscriptionGroupReader(groupRepo) svc.SetDefaultSubscriptionGroupReader(groupRepo)
svc.SetProxyRepository(proxyRepo)
return svc return svc
} }
...@@ -465,6 +466,7 @@ var ProviderSet = wire.NewSet( ...@@ -465,6 +466,7 @@ var ProviderSet = wire.NewSet(
ProvidePaymentConfigService, ProvidePaymentConfigService,
NewPaymentService, NewPaymentService,
ProvidePaymentOrderExpiryService, ProvidePaymentOrderExpiryService,
ProvideBalanceNotifyService,
) )
// ProvidePaymentConfigService wraps NewPaymentConfigService to accept the named // ProvidePaymentConfigService wraps NewPaymentConfigService to accept the named
...@@ -473,6 +475,11 @@ func ProvidePaymentConfigService(entClient *dbent.Client, settingRepo SettingRep ...@@ -473,6 +475,11 @@ func ProvidePaymentConfigService(entClient *dbent.Client, settingRepo SettingRep
return NewPaymentConfigService(entClient, settingRepo, []byte(key)) return NewPaymentConfigService(entClient, settingRepo, []byte(key))
} }
// ProvideBalanceNotifyService creates BalanceNotifyService
func ProvideBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository, accountRepo AccountRepository) *BalanceNotifyService {
return NewBalanceNotifyService(emailService, settingRepo, accountRepo)
}
// ProvidePaymentOrderExpiryService creates and starts PaymentOrderExpiryService. // ProvidePaymentOrderExpiryService creates and starts PaymentOrderExpiryService.
func ProvidePaymentOrderExpiryService(paymentSvc *PaymentService) *PaymentOrderExpiryService { func ProvidePaymentOrderExpiryService(paymentSvc *PaymentService) *PaymentOrderExpiryService {
svc := NewPaymentOrderExpiryService(paymentSvc, 60*time.Second) svc := NewPaymentOrderExpiryService(paymentSvc, 60*time.Second)
......
...@@ -10,6 +10,8 @@ import ( ...@@ -10,6 +10,8 @@ import (
"io" "io"
"io/fs" "io/fs"
"net/http" "net/http"
"os"
"path/filepath"
"strings" "strings"
"time" "time"
...@@ -32,11 +34,12 @@ type PublicSettingsProvider interface { ...@@ -32,11 +34,12 @@ type PublicSettingsProvider interface {
// FrontendServer serves the embedded frontend with settings injection // FrontendServer serves the embedded frontend with settings injection
type FrontendServer struct { type FrontendServer struct {
distFS fs.FS distFS fs.FS
fileServer http.Handler fileServer http.Handler
baseHTML []byte baseHTML []byte
cache *HTMLCache cache *HTMLCache
settings PublicSettingsProvider settings PublicSettingsProvider
overrideDir string // local file override directory
} }
// NewFrontendServer creates a new frontend server with settings injection // NewFrontendServer creates a new frontend server with settings injection
...@@ -62,11 +65,12 @@ func NewFrontendServer(settingsProvider PublicSettingsProvider) (*FrontendServer ...@@ -62,11 +65,12 @@ func NewFrontendServer(settingsProvider PublicSettingsProvider) (*FrontendServer
cache.SetBaseHTML(baseHTML) cache.SetBaseHTML(baseHTML)
return &FrontendServer{ return &FrontendServer{
distFS: distFS, distFS: distFS,
fileServer: http.FileServer(http.FS(distFS)), fileServer: http.FileServer(http.FS(distFS)),
baseHTML: baseHTML, baseHTML: baseHTML,
cache: cache, cache: cache,
settings: settingsProvider, settings: settingsProvider,
overrideDir: filepath.Join("data", "public"),
}, nil }, nil
} }
...@@ -99,6 +103,11 @@ func (s *FrontendServer) Middleware() gin.HandlerFunc { ...@@ -99,6 +103,11 @@ func (s *FrontendServer) Middleware() gin.HandlerFunc {
return return
} }
// Try local override first
if s.tryServeOverride(c, cleanPath) {
return
}
// Serve static files normally // Serve static files normally
s.fileServer.ServeHTTP(c.Writer, c.Request) s.fileServer.ServeHTTP(c.Writer, c.Request)
c.Abort() c.Abort()
...@@ -114,6 +123,22 @@ func (s *FrontendServer) fileExists(path string) bool { ...@@ -114,6 +123,22 @@ func (s *FrontendServer) fileExists(path string) bool {
return true return true
} }
// tryServeOverride checks if a local override file exists and serves it.
// Files in overrideDir take precedence over embedded files.
func (s *FrontendServer) tryServeOverride(c *gin.Context, cleanPath string) bool {
if s.overrideDir == "" {
return false
}
filePath := filepath.Join(s.overrideDir, filepath.Clean("/"+cleanPath))
info, err := os.Stat(filePath)
if err != nil || info.IsDir() {
return false
}
c.File(filePath)
c.Abort()
return true
}
func (s *FrontendServer) serveIndexHTML(c *gin.Context) { func (s *FrontendServer) serveIndexHTML(c *gin.Context) {
// Get nonce from context (generated by SecurityHeaders middleware) // Get nonce from context (generated by SecurityHeaders middleware)
nonce := middleware.GetNonceFromContext(c) nonce := middleware.GetNonceFromContext(c)
...@@ -226,6 +251,7 @@ func ServeEmbeddedFrontend() gin.HandlerFunc { ...@@ -226,6 +251,7 @@ func ServeEmbeddedFrontend() gin.HandlerFunc {
panic("failed to get dist subdirectory: " + err.Error()) panic("failed to get dist subdirectory: " + err.Error())
} }
fileServer := http.FileServer(http.FS(distFS)) fileServer := http.FileServer(http.FS(distFS))
overrideDir := filepath.Join("data", "public")
return func(c *gin.Context) { return func(c *gin.Context) {
path := c.Request.URL.Path path := c.Request.URL.Path
...@@ -242,6 +268,10 @@ func ServeEmbeddedFrontend() gin.HandlerFunc { ...@@ -242,6 +268,10 @@ func ServeEmbeddedFrontend() gin.HandlerFunc {
if file, err := distFS.Open(cleanPath); err == nil { if file, err := distFS.Open(cleanPath); err == nil {
_ = file.Close() _ = file.Close()
// Try local override first
if tryServeOverrideFile(c, overrideDir, cleanPath) {
return
}
fileServer.ServeHTTP(c.Writer, c.Request) fileServer.ServeHTTP(c.Writer, c.Request)
c.Abort() c.Abort()
return return
...@@ -251,6 +281,21 @@ func ServeEmbeddedFrontend() gin.HandlerFunc { ...@@ -251,6 +281,21 @@ func ServeEmbeddedFrontend() gin.HandlerFunc {
} }
} }
// tryServeOverrideFile is a standalone version of tryServeOverride for legacy usage.
func tryServeOverrideFile(c *gin.Context, overrideDir, cleanPath string) bool {
if overrideDir == "" {
return false
}
filePath := filepath.Join(overrideDir, filepath.Clean("/"+cleanPath))
info, err := os.Stat(filePath)
if err != nil || info.IsDir() {
return false
}
c.File(filePath)
c.Abort()
return true
}
func shouldBypassEmbeddedFrontend(path string) bool { func shouldBypassEmbeddedFrontend(path string) bool {
trimmed := strings.TrimSpace(path) trimmed := strings.TrimSpace(path)
return strings.HasPrefix(trimmed, "/api/") || return strings.HasPrefix(trimmed, "/api/") ||
......
-- 在测试库执行完migration_release之后,执行这个语句,阻止token刷新
-- 切记不可在生产库执行
UPDATE accounts SET schedulable = false,credentials = NUll WHERE type = 'oauth';
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment