Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
0b746501
Commit
0b746501
authored
Apr 16, 2026
by
陈曦
Browse files
1. merge upstream v0.1.113 2.提交migration相关文件
parents
45061102
be7551b9
Changes
225
Hide whitespace changes
Inline
Side-by-side
backend/internal/service/payment_config_service.go
View file @
0b746501
...
...
@@ -3,12 +3,14 @@ package service
import
(
"context"
"fmt"
"math"
"strconv"
"strings"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
const
(
...
...
@@ -21,6 +23,8 @@ const (
SettingEnabledPaymentTypes
=
"ENABLED_PAYMENT_TYPES"
SettingLoadBalanceStrategy
=
"LOAD_BALANCE_STRATEGY"
SettingBalancePayDisabled
=
"BALANCE_PAYMENT_DISABLED"
SettingBalanceRechargeMult
=
"BALANCE_RECHARGE_MULTIPLIER"
SettingRechargeFeeRate
=
"RECHARGE_FEE_RATE"
SettingProductNamePrefix
=
"PRODUCT_NAME_PREFIX"
SettingProductNameSuffix
=
"PRODUCT_NAME_SUFFIX"
SettingHelpImageURL
=
"PAYMENT_HELP_IMAGE_URL"
...
...
@@ -40,20 +44,22 @@ const (
// PaymentConfig holds the payment system configuration.
type
PaymentConfig
struct
{
Enabled
bool
`json:"enabled"`
MinAmount
float64
`json:"min_amount"`
MaxAmount
float64
`json:"max_amount"`
DailyLimit
float64
`json:"daily_limit"`
OrderTimeoutMin
int
`json:"order_timeout_minutes"`
MaxPendingOrders
int
`json:"max_pending_orders"`
EnabledTypes
[]
string
`json:"enabled_payment_types"`
BalanceDisabled
bool
`json:"balance_disabled"`
LoadBalanceStrategy
string
`json:"load_balance_strategy"`
ProductNamePrefix
string
`json:"product_name_prefix"`
ProductNameSuffix
string
`json:"product_name_suffix"`
HelpImageURL
string
`json:"help_image_url"`
HelpText
string
`json:"help_text"`
StripePublishableKey
string
`json:"stripe_publishable_key,omitempty"`
Enabled
bool
`json:"enabled"`
MinAmount
float64
`json:"min_amount"`
MaxAmount
float64
`json:"max_amount"`
DailyLimit
float64
`json:"daily_limit"`
OrderTimeoutMin
int
`json:"order_timeout_minutes"`
MaxPendingOrders
int
`json:"max_pending_orders"`
EnabledTypes
[]
string
`json:"enabled_payment_types"`
BalanceDisabled
bool
`json:"balance_disabled"`
BalanceRechargeMultiplier
float64
`json:"balance_recharge_multiplier"`
RechargeFeeRate
float64
`json:"recharge_fee_rate"`
LoadBalanceStrategy
string
`json:"load_balance_strategy"`
ProductNamePrefix
string
`json:"product_name_prefix"`
ProductNameSuffix
string
`json:"product_name_suffix"`
HelpImageURL
string
`json:"help_image_url"`
HelpText
string
`json:"help_text"`
StripePublishableKey
string
`json:"stripe_publishable_key,omitempty"`
// Cancel rate limit settings
CancelRateLimitEnabled
bool
`json:"cancel_rate_limit_enabled"`
...
...
@@ -65,19 +71,21 @@ type PaymentConfig struct {
// UpdatePaymentConfigRequest contains fields to update payment configuration.
type
UpdatePaymentConfigRequest
struct
{
Enabled
*
bool
`json:"enabled"`
MinAmount
*
float64
`json:"min_amount"`
MaxAmount
*
float64
`json:"max_amount"`
DailyLimit
*
float64
`json:"daily_limit"`
OrderTimeoutMin
*
int
`json:"order_timeout_minutes"`
MaxPendingOrders
*
int
`json:"max_pending_orders"`
EnabledTypes
[]
string
`json:"enabled_payment_types"`
BalanceDisabled
*
bool
`json:"balance_disabled"`
LoadBalanceStrategy
*
string
`json:"load_balance_strategy"`
ProductNamePrefix
*
string
`json:"product_name_prefix"`
ProductNameSuffix
*
string
`json:"product_name_suffix"`
HelpImageURL
*
string
`json:"help_image_url"`
HelpText
*
string
`json:"help_text"`
Enabled
*
bool
`json:"enabled"`
MinAmount
*
float64
`json:"min_amount"`
MaxAmount
*
float64
`json:"max_amount"`
DailyLimit
*
float64
`json:"daily_limit"`
OrderTimeoutMin
*
int
`json:"order_timeout_minutes"`
MaxPendingOrders
*
int
`json:"max_pending_orders"`
EnabledTypes
[]
string
`json:"enabled_payment_types"`
BalanceDisabled
*
bool
`json:"balance_disabled"`
BalanceRechargeMultiplier
*
float64
`json:"balance_recharge_multiplier"`
RechargeFeeRate
*
float64
`json:"recharge_fee_rate"`
LoadBalanceStrategy
*
string
`json:"load_balance_strategy"`
ProductNamePrefix
*
string
`json:"product_name_prefix"`
ProductNameSuffix
*
string
`json:"product_name_suffix"`
HelpImageURL
*
string
`json:"help_image_url"`
HelpText
*
string
`json:"help_text"`
// Cancel rate limit settings
CancelRateLimitEnabled
*
bool
`json:"cancel_rate_limit_enabled"`
...
...
@@ -105,26 +113,28 @@ type MethodLimitsResponse struct {
}
type
CreateProviderInstanceRequest
struct
{
ProviderKey
string
`json:"provider_key"`
Name
string
`json:"name"`
Config
map
[
string
]
string
`json:"config"`
SupportedTypes
[]
string
`json:"supported_types"`
Enabled
bool
`json:"enabled"`
PaymentMode
string
`json:"payment_mode"`
SortOrder
int
`json:"sort_order"`
Limits
string
`json:"limits"`
RefundEnabled
bool
`json:"refund_enabled"`
ProviderKey
string
`json:"provider_key"`
Name
string
`json:"name"`
Config
map
[
string
]
string
`json:"config"`
SupportedTypes
[]
string
`json:"supported_types"`
Enabled
bool
`json:"enabled"`
PaymentMode
string
`json:"payment_mode"`
SortOrder
int
`json:"sort_order"`
Limits
string
`json:"limits"`
RefundEnabled
bool
`json:"refund_enabled"`
AllowUserRefund
bool
`json:"allow_user_refund"`
}
type
UpdateProviderInstanceRequest
struct
{
Name
*
string
`json:"name"`
Config
map
[
string
]
string
`json:"config"`
SupportedTypes
[]
string
`json:"supported_types"`
Enabled
*
bool
`json:"enabled"`
PaymentMode
*
string
`json:"payment_mode"`
SortOrder
*
int
`json:"sort_order"`
Limits
*
string
`json:"limits"`
RefundEnabled
*
bool
`json:"refund_enabled"`
Name
*
string
`json:"name"`
Config
map
[
string
]
string
`json:"config"`
SupportedTypes
[]
string
`json:"supported_types"`
Enabled
*
bool
`json:"enabled"`
PaymentMode
*
string
`json:"payment_mode"`
SortOrder
*
int
`json:"sort_order"`
Limits
*
string
`json:"limits"`
RefundEnabled
*
bool
`json:"refund_enabled"`
AllowUserRefund
*
bool
`json:"allow_user_refund"`
}
type
CreatePlanRequest
struct
{
GroupID
int64
`json:"group_id"`
...
...
@@ -181,7 +191,7 @@ func (s *PaymentConfigService) GetPaymentConfig(ctx context.Context) (*PaymentCo
keys
:=
[]
string
{
SettingPaymentEnabled
,
SettingMinRechargeAmount
,
SettingMaxRechargeAmount
,
SettingDailyRechargeLimit
,
SettingOrderTimeoutMinutes
,
SettingMaxPendingOrders
,
SettingEnabledPaymentTypes
,
SettingBalancePayDisabled
,
SettingLoadBalanceStrategy
,
SettingEnabledPaymentTypes
,
SettingBalancePayDisabled
,
SettingBalanceRechargeMult
,
SettingRechargeFeeRate
,
SettingLoadBalanceStrategy
,
SettingProductNamePrefix
,
SettingProductNameSuffix
,
SettingHelpImageURL
,
SettingHelpText
,
SettingCancelRateLimitOn
,
SettingCancelRateLimitMax
,
...
...
@@ -199,18 +209,20 @@ func (s *PaymentConfigService) GetPaymentConfig(ctx context.Context) (*PaymentCo
func
(
s
*
PaymentConfigService
)
parsePaymentConfig
(
vals
map
[
string
]
string
)
*
PaymentConfig
{
cfg
:=
&
PaymentConfig
{
Enabled
:
vals
[
SettingPaymentEnabled
]
==
"true"
,
MinAmount
:
pcParseFloat
(
vals
[
SettingMinRechargeAmount
],
1
),
MaxAmount
:
pcParseFloat
(
vals
[
SettingMaxRechargeAmount
],
0
),
DailyLimit
:
pcParseFloat
(
vals
[
SettingDailyRechargeLimit
],
0
),
OrderTimeoutMin
:
pcParseInt
(
vals
[
SettingOrderTimeoutMinutes
],
defaultOrderTimeoutMin
),
MaxPendingOrders
:
pcParseInt
(
vals
[
SettingMaxPendingOrders
],
defaultMaxPendingOrders
),
BalanceDisabled
:
vals
[
SettingBalancePayDisabled
]
==
"true"
,
LoadBalanceStrategy
:
vals
[
SettingLoadBalanceStrategy
],
ProductNamePrefix
:
vals
[
SettingProductNamePrefix
],
ProductNameSuffix
:
vals
[
SettingProductNameSuffix
],
HelpImageURL
:
vals
[
SettingHelpImageURL
],
HelpText
:
vals
[
SettingHelpText
],
Enabled
:
vals
[
SettingPaymentEnabled
]
==
"true"
,
MinAmount
:
pcParseFloat
(
vals
[
SettingMinRechargeAmount
],
1
),
MaxAmount
:
pcParseFloat
(
vals
[
SettingMaxRechargeAmount
],
0
),
DailyLimit
:
pcParseFloat
(
vals
[
SettingDailyRechargeLimit
],
0
),
OrderTimeoutMin
:
pcParseInt
(
vals
[
SettingOrderTimeoutMinutes
],
defaultOrderTimeoutMin
),
MaxPendingOrders
:
pcParseInt
(
vals
[
SettingMaxPendingOrders
],
defaultMaxPendingOrders
),
BalanceDisabled
:
vals
[
SettingBalancePayDisabled
]
==
"true"
,
BalanceRechargeMultiplier
:
normalizeBalanceRechargeMultiplier
(
pcParseFloat
(
vals
[
SettingBalanceRechargeMult
],
defaultBalanceRechargeMultiplier
)),
RechargeFeeRate
:
pcParseFloat
(
vals
[
SettingRechargeFeeRate
],
0
),
LoadBalanceStrategy
:
vals
[
SettingLoadBalanceStrategy
],
ProductNamePrefix
:
vals
[
SettingProductNamePrefix
],
ProductNameSuffix
:
vals
[
SettingProductNameSuffix
],
HelpImageURL
:
vals
[
SettingHelpImageURL
],
HelpText
:
vals
[
SettingHelpText
],
CancelRateLimitEnabled
:
vals
[
SettingCancelRateLimitOn
]
==
"true"
,
CancelRateLimitMax
:
pcParseInt
(
vals
[
SettingCancelRateLimitMax
],
10
),
...
...
@@ -254,6 +266,21 @@ func (s *PaymentConfigService) getStripePublishableKey(ctx context.Context) stri
// nil-check before serialisation — this is inherent to patch-style update patterns
// and cannot be meaningfully decomposed without introducing unnecessary abstraction.
func
(
s
*
PaymentConfigService
)
UpdatePaymentConfig
(
ctx
context
.
Context
,
req
UpdatePaymentConfigRequest
)
error
{
if
req
.
BalanceRechargeMultiplier
!=
nil
{
if
math
.
IsNaN
(
*
req
.
BalanceRechargeMultiplier
)
||
math
.
IsInf
(
*
req
.
BalanceRechargeMultiplier
,
0
)
||
*
req
.
BalanceRechargeMultiplier
<=
0
{
return
infraerrors
.
BadRequest
(
"INVALID_BALANCE_RECHARGE_MULTIPLIER"
,
"balance recharge multiplier must be greater than 0"
)
}
}
if
req
.
RechargeFeeRate
!=
nil
{
v
:=
*
req
.
RechargeFeeRate
if
math
.
IsNaN
(
v
)
||
math
.
IsInf
(
v
,
0
)
||
v
<
0
||
v
>
100
{
return
infraerrors
.
BadRequest
(
"INVALID_RECHARGE_FEE_RATE"
,
"recharge fee rate must be between 0 and 100"
)
}
// Enforce max 2 decimal places
if
math
.
Round
(
v
*
100
)
!=
v
*
100
{
return
infraerrors
.
BadRequest
(
"INVALID_RECHARGE_FEE_RATE"
,
"recharge fee rate allows at most 2 decimal places"
)
}
}
m
:=
map
[
string
]
string
{
SettingPaymentEnabled
:
formatBoolOrEmpty
(
req
.
Enabled
),
SettingMinRechargeAmount
:
formatPositiveFloat
(
req
.
MinAmount
),
...
...
@@ -262,6 +289,8 @@ func (s *PaymentConfigService) UpdatePaymentConfig(ctx context.Context, req Upda
SettingOrderTimeoutMinutes
:
formatPositiveInt
(
req
.
OrderTimeoutMin
),
SettingMaxPendingOrders
:
formatPositiveInt
(
req
.
MaxPendingOrders
),
SettingBalancePayDisabled
:
formatBoolOrEmpty
(
req
.
BalanceDisabled
),
SettingBalanceRechargeMult
:
formatPositiveFloat
(
req
.
BalanceRechargeMultiplier
),
SettingRechargeFeeRate
:
formatNonNegativeFloat
(
req
.
RechargeFeeRate
),
SettingLoadBalanceStrategy
:
derefStr
(
req
.
LoadBalanceStrategy
),
SettingProductNamePrefix
:
derefStr
(
req
.
ProductNamePrefix
),
SettingProductNameSuffix
:
derefStr
(
req
.
ProductNameSuffix
),
...
...
@@ -295,6 +324,13 @@ func formatPositiveFloat(v *float64) string {
return
strconv
.
FormatFloat
(
*
v
,
'f'
,
2
,
64
)
}
func
formatNonNegativeFloat
(
v
*
float64
)
string
{
if
v
==
nil
||
*
v
<
0
{
return
""
}
return
strconv
.
FormatFloat
(
*
v
,
'f'
,
2
,
64
)
}
func
formatPositiveInt
(
v
*
int
)
string
{
if
v
==
nil
||
*
v
<=
0
{
return
""
...
...
backend/internal/service/payment_fulfillment.go
View file @
0b746501
...
...
@@ -216,7 +216,11 @@ func (s *PaymentService) markCompleted(ctx context.Context, o *dbent.PaymentOrde
if
err
!=
nil
{
return
fmt
.
Errorf
(
"mark completed: %w"
,
err
)
}
s
.
writeAuditLog
(
ctx
,
o
.
ID
,
auditAction
,
"system"
,
map
[
string
]
any
{
"rechargeCode"
:
o
.
RechargeCode
,
"amount"
:
o
.
Amount
})
s
.
writeAuditLog
(
ctx
,
o
.
ID
,
auditAction
,
"system"
,
map
[
string
]
any
{
"rechargeCode"
:
o
.
RechargeCode
,
"creditedAmount"
:
o
.
Amount
,
"payAmount"
:
o
.
PayAmount
,
})
return
nil
}
...
...
backend/internal/service/payment_order.go
View file @
0b746501
...
...
@@ -43,18 +43,22 @@ func (s *PaymentService) CreateOrder(ctx context.Context, req CreateOrderRequest
if
user
.
Status
!=
payment
.
EntityStatusActive
{
return
nil
,
infraerrors
.
Forbidden
(
"USER_INACTIVE"
,
"user account is disabled"
)
}
amount
:=
req
.
Amount
orderAmount
:=
req
.
Amount
limitAmount
:=
req
.
Amount
if
plan
!=
nil
{
amount
=
plan
.
Price
orderAmount
=
plan
.
Price
limitAmount
=
plan
.
Price
}
else
if
req
.
OrderType
==
payment
.
OrderTypeBalance
{
orderAmount
=
calculateCreditedBalance
(
req
.
Amount
,
cfg
.
BalanceRechargeMultiplier
)
}
feeRate
:=
s
.
ge
t
FeeRate
(
req
.
PaymentType
)
payAmountStr
:=
payment
.
CalculatePayAmount
(
a
mount
,
feeRate
)
feeRate
:=
cfg
.
Rechar
geFeeRate
payAmountStr
:=
payment
.
CalculatePayAmount
(
limitA
mount
,
feeRate
)
payAmount
,
_
:=
strconv
.
ParseFloat
(
payAmountStr
,
64
)
order
,
err
:=
s
.
createOrderInTx
(
ctx
,
req
,
user
,
plan
,
cfg
,
a
mount
,
feeRate
,
payAmount
)
order
,
err
:=
s
.
createOrderInTx
(
ctx
,
req
,
user
,
plan
,
cfg
,
orderAmount
,
limitA
mount
,
feeRate
,
payAmount
)
if
err
!=
nil
{
return
nil
,
err
}
resp
,
err
:=
s
.
invokeProvider
(
ctx
,
order
,
req
,
cfg
,
payAmountStr
,
payAmount
,
plan
)
resp
,
err
:=
s
.
invokeProvider
(
ctx
,
order
,
req
,
cfg
,
limitAmount
,
payAmountStr
,
payAmount
,
plan
)
if
err
!=
nil
{
_
,
_
=
s
.
entClient
.
PaymentOrder
.
UpdateOneID
(
order
.
ID
)
.
SetStatus
(
OrderStatusFailed
)
.
...
...
@@ -99,7 +103,7 @@ func (s *PaymentService) validateSubOrder(ctx context.Context, req CreateOrderRe
return
plan
,
nil
}
func
(
s
*
PaymentService
)
createOrderInTx
(
ctx
context
.
Context
,
req
CreateOrderRequest
,
user
*
User
,
plan
*
dbent
.
SubscriptionPlan
,
cfg
*
PaymentConfig
,
a
mount
,
feeRate
,
payAmount
float64
)
(
*
dbent
.
PaymentOrder
,
error
)
{
func
(
s
*
PaymentService
)
createOrderInTx
(
ctx
context
.
Context
,
req
CreateOrderRequest
,
user
*
User
,
plan
*
dbent
.
SubscriptionPlan
,
cfg
*
PaymentConfig
,
orderAmount
,
limitA
mount
,
feeRate
,
payAmount
float64
)
(
*
dbent
.
PaymentOrder
,
error
)
{
tx
,
err
:=
s
.
entClient
.
Tx
(
ctx
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"begin transaction: %w"
,
err
)
...
...
@@ -108,7 +112,7 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
if
err
:=
s
.
checkPendingLimit
(
ctx
,
tx
,
req
.
UserID
,
cfg
.
MaxPendingOrders
);
err
!=
nil
{
return
nil
,
err
}
if
err
:=
s
.
checkDailyLimit
(
ctx
,
tx
,
req
.
UserID
,
a
mount
,
cfg
.
DailyLimit
);
err
!=
nil
{
if
err
:=
s
.
checkDailyLimit
(
ctx
,
tx
,
req
.
UserID
,
limitA
mount
,
cfg
.
DailyLimit
);
err
!=
nil
{
return
nil
,
err
}
tm
:=
cfg
.
OrderTimeoutMin
...
...
@@ -121,7 +125,7 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
SetUserEmail
(
user
.
Email
)
.
SetUserName
(
user
.
Username
)
.
SetNillableUserNotes
(
psNilIfEmpty
(
user
.
Notes
))
.
SetAmount
(
a
mount
)
.
SetAmount
(
orderA
mount
)
.
SetPayAmount
(
payAmount
)
.
SetFeeRate
(
feeRate
)
.
SetRechargeCode
(
""
)
.
...
...
@@ -180,6 +184,10 @@ func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, user
}
var
used
float64
for
_
,
o
:=
range
orders
{
if
o
.
OrderType
==
payment
.
OrderTypeBalance
{
used
+=
o
.
PayAmount
continue
}
used
+=
o
.
Amount
}
if
used
+
amount
>
limit
{
...
...
@@ -188,7 +196,7 @@ func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, user
return
nil
}
func
(
s
*
PaymentService
)
invokeProvider
(
ctx
context
.
Context
,
order
*
dbent
.
PaymentOrder
,
req
CreateOrderRequest
,
cfg
*
PaymentConfig
,
payAmountStr
string
,
payAmount
float64
,
plan
*
dbent
.
SubscriptionPlan
)
(
*
CreateOrderResponse
,
error
)
{
func
(
s
*
PaymentService
)
invokeProvider
(
ctx
context
.
Context
,
order
*
dbent
.
PaymentOrder
,
req
CreateOrderRequest
,
cfg
*
PaymentConfig
,
limitAmount
float64
,
payAmountStr
string
,
payAmount
float64
,
plan
*
dbent
.
SubscriptionPlan
)
(
*
CreateOrderResponse
,
error
)
{
// Select an instance across all providers that support the requested payment type.
// This enables cross-provider load balancing (e.g. EasyPay + Alipay direct for "alipay").
sel
,
err
:=
s
.
loadBalancer
.
SelectInstance
(
ctx
,
""
,
req
.
PaymentType
,
payment
.
Strategy
(
cfg
.
LoadBalanceStrategy
),
payAmount
)
...
...
@@ -202,7 +210,7 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
if
err
!=
nil
{
return
nil
,
infraerrors
.
ServiceUnavailable
(
"PAYMENT_GATEWAY_ERROR"
,
"payment method is temporarily unavailable"
)
}
subject
:=
s
.
buildPaymentSubject
(
plan
,
pay
Amount
Str
,
cfg
)
subject
:=
s
.
buildPaymentSubject
(
plan
,
limit
Amount
,
cfg
)
outTradeNo
:=
order
.
OutTradeNo
pr
,
err
:=
prov
.
CreatePayment
(
ctx
,
payment
.
CreatePaymentRequest
{
OrderID
:
outTradeNo
,
Amount
:
payAmountStr
,
PaymentType
:
req
.
PaymentType
,
Subject
:
subject
,
ClientIP
:
req
.
ClientIP
,
IsMobile
:
req
.
IsMobile
,
InstanceSubMethods
:
sel
.
SupportedTypes
})
if
err
!=
nil
{
...
...
@@ -213,23 +221,30 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"update order with payment details: %w"
,
err
)
}
s
.
writeAuditLog
(
ctx
,
order
.
ID
,
"ORDER_CREATED"
,
fmt
.
Sprintf
(
"user:%d"
,
req
.
UserID
),
map
[
string
]
any
{
"amount"
:
req
.
Amount
,
"paymentType"
:
req
.
PaymentType
,
"orderType"
:
req
.
OrderType
})
s
.
writeAuditLog
(
ctx
,
order
.
ID
,
"ORDER_CREATED"
,
fmt
.
Sprintf
(
"user:%d"
,
req
.
UserID
),
map
[
string
]
any
{
"paymentAmount"
:
req
.
Amount
,
"creditedAmount"
:
order
.
Amount
,
"payAmount"
:
order
.
PayAmount
,
"paymentType"
:
req
.
PaymentType
,
"orderType"
:
req
.
OrderType
,
})
return
&
CreateOrderResponse
{
OrderID
:
order
.
ID
,
Amount
:
order
.
Amount
,
PayAmount
:
payAmount
,
FeeRate
:
order
.
FeeRate
,
Status
:
OrderStatusPending
,
PaymentType
:
req
.
PaymentType
,
PayURL
:
pr
.
PayURL
,
QRCode
:
pr
.
QRCode
,
ClientSecret
:
pr
.
ClientSecret
,
ExpiresAt
:
order
.
ExpiresAt
,
PaymentMode
:
sel
.
PaymentMode
},
nil
}
func
(
s
*
PaymentService
)
buildPaymentSubject
(
plan
*
dbent
.
SubscriptionPlan
,
pay
Amount
Str
string
,
cfg
*
PaymentConfig
)
string
{
func
(
s
*
PaymentService
)
buildPaymentSubject
(
plan
*
dbent
.
SubscriptionPlan
,
limit
Amount
float64
,
cfg
*
PaymentConfig
)
string
{
if
plan
!=
nil
{
if
plan
.
ProductName
!=
""
{
return
plan
.
ProductName
}
return
"Sub2API Subscription "
+
plan
.
Name
}
amountStr
:=
strconv
.
FormatFloat
(
limitAmount
,
'f'
,
2
,
64
)
pf
:=
strings
.
TrimSpace
(
cfg
.
ProductNamePrefix
)
sf
:=
strings
.
TrimSpace
(
cfg
.
ProductNameSuffix
)
if
pf
!=
""
||
sf
!=
""
{
return
strings
.
TrimSpace
(
pf
+
" "
+
payA
mountStr
+
" "
+
sf
)
return
strings
.
TrimSpace
(
pf
+
" "
+
a
mountStr
+
" "
+
sf
)
}
return
"Sub2API "
+
payA
mountStr
+
" CNY"
return
"Sub2API "
+
a
mountStr
+
" CNY"
}
// --- Order Queries ---
...
...
backend/internal/service/payment_refund.go
View file @
0b746501
...
...
@@ -2,6 +2,7 @@ package service
import
(
"context"
"errors"
"fmt"
"log/slog"
"math"
...
...
@@ -17,6 +18,19 @@ import (
// --- Refund Flow ---
// getOrderProviderInstance looks up the provider instance that processed this order.
// Returns nil, nil for legacy orders without provider_instance_id.
func
(
s
*
PaymentService
)
getOrderProviderInstance
(
ctx
context
.
Context
,
o
*
dbent
.
PaymentOrder
)
(
*
dbent
.
PaymentProviderInstance
,
error
)
{
if
o
.
ProviderInstanceID
==
nil
||
*
o
.
ProviderInstanceID
==
""
{
return
nil
,
nil
}
instID
,
err
:=
strconv
.
ParseInt
(
*
o
.
ProviderInstanceID
,
10
,
64
)
if
err
!=
nil
{
return
nil
,
nil
}
return
s
.
entClient
.
PaymentProviderInstance
.
Get
(
ctx
,
instID
)
}
func
(
s
*
PaymentService
)
RequestRefund
(
ctx
context
.
Context
,
oid
,
uid
int64
,
reason
string
)
error
{
o
,
err
:=
s
.
validateRefundRequest
(
ctx
,
oid
,
uid
)
if
err
!=
nil
{
...
...
@@ -57,6 +71,14 @@ func (s *PaymentService) validateRefundRequest(ctx context.Context, oid, uid int
if
o
.
Status
!=
OrderStatusCompleted
{
return
nil
,
infraerrors
.
BadRequest
(
"INVALID_STATUS"
,
"only completed orders can request refund"
)
}
// Check provider instance allows user refund
inst
,
err
:=
s
.
getOrderProviderInstance
(
ctx
,
o
)
if
err
!=
nil
||
inst
==
nil
{
return
nil
,
infraerrors
.
Forbidden
(
"USER_REFUND_DISABLED"
,
"refund is not available for this order"
)
}
if
!
inst
.
AllowUserRefund
{
return
nil
,
infraerrors
.
Forbidden
(
"USER_REFUND_DISABLED"
,
"user refund is not enabled for this provider"
)
}
return
o
,
nil
}
...
...
@@ -69,6 +91,19 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float
if
!
psSliceContains
(
ok
,
o
.
Status
)
{
return
nil
,
nil
,
infraerrors
.
BadRequest
(
"INVALID_STATUS"
,
"order status does not allow refund"
)
}
// Check provider instance allows admin refund
inst
,
instErr
:=
s
.
getOrderProviderInstance
(
ctx
,
o
)
if
instErr
!=
nil
{
slog
.
Warn
(
"refund: provider instance lookup failed"
,
"orderID"
,
oid
,
"error"
,
instErr
)
return
nil
,
nil
,
infraerrors
.
InternalServer
(
"PROVIDER_LOOKUP_FAILED"
,
"failed to look up payment provider for this order"
)
}
if
inst
==
nil
{
// Legacy order without provider_instance_id — block refund
return
nil
,
nil
,
infraerrors
.
Forbidden
(
"REFUND_DISABLED"
,
"refund is not available for this order"
)
}
if
!
inst
.
RefundEnabled
{
return
nil
,
nil
,
infraerrors
.
Forbidden
(
"REFUND_DISABLED"
,
"refund is not enabled for this provider"
)
}
if
math
.
IsNaN
(
amt
)
||
math
.
IsInf
(
amt
,
0
)
{
return
nil
,
nil
,
infraerrors
.
BadRequest
(
"INVALID_AMOUNT"
,
"invalid refund amount"
)
}
...
...
@@ -78,11 +113,7 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float
if
amt
-
o
.
Amount
>
amountToleranceCNY
{
return
nil
,
nil
,
infraerrors
.
BadRequest
(
"REFUND_AMOUNT_EXCEEDED"
,
"refund amount exceeds recharge"
)
}
// Full refund: use actual pay_amount for gateway (includes fees)
ga
:=
amt
if
math
.
Abs
(
amt
-
o
.
Amount
)
<=
amountToleranceCNY
{
ga
=
o
.
PayAmount
}
ga
:=
calculateGatewayRefundAmount
(
o
.
Amount
,
o
.
PayAmount
,
amt
)
rr
:=
strings
.
TrimSpace
(
reason
)
if
rr
==
""
&&
o
.
RefundRequestReason
!=
nil
{
rr
=
*
o
.
RefundRequestReason
...
...
@@ -150,11 +181,17 @@ func (s *PaymentService) ExecuteRefund(ctx context.Context, p *RefundPlan) (*Ref
if
!
s
.
hasAuditLog
(
ctx
,
p
.
OrderID
,
"REFUND_ROLLBACK_FAILED"
)
{
_
,
err
:=
s
.
subscriptionSvc
.
ExtendSubscription
(
ctx
,
p
.
SubscriptionID
,
-
p
.
SubDaysToDeduct
)
if
err
!=
nil
{
// If deducting would expire the subscription, revoke it entirely
slog
.
Info
(
"subscription deduction would expire, revoking"
,
"orderID"
,
p
.
OrderID
,
"subID"
,
p
.
SubscriptionID
,
"days"
,
p
.
SubDaysToDeduct
)
if
revokeErr
:=
s
.
subscriptionSvc
.
RevokeSubscription
(
ctx
,
p
.
SubscriptionID
);
revokeErr
!=
nil
{
if
errors
.
Is
(
err
,
ErrAdjustWouldExpire
)
{
// Deduction would expire the subscription — revoke it entirely
slog
.
Info
(
"subscription deduction would expire, revoking"
,
"orderID"
,
p
.
OrderID
,
"subID"
,
p
.
SubscriptionID
,
"days"
,
p
.
SubDaysToDeduct
)
if
revokeErr
:=
s
.
subscriptionSvc
.
RevokeSubscription
(
ctx
,
p
.
SubscriptionID
);
revokeErr
!=
nil
{
s
.
restoreStatus
(
ctx
,
p
)
return
nil
,
fmt
.
Errorf
(
"revoke subscription: %w"
,
revokeErr
)
}
}
else
{
// Other errors (DB failure, not found) — abort refund
s
.
restoreStatus
(
ctx
,
p
)
return
nil
,
fmt
.
Errorf
(
"
revoke
subscription: %w"
,
revokeE
rr
)
return
nil
,
fmt
.
Errorf
(
"
deduct
subscription
days
: %w"
,
e
rr
)
}
}
}
else
{
...
...
backend/internal/service/payment_service.go
View file @
0b746501
...
...
@@ -288,8 +288,6 @@ func psComputeValidityDays(days int, unit string) int {
}
}
func
(
s
*
PaymentService
)
getFeeRate
(
_
string
)
float64
{
return
0
}
func
psStartOfDayUTC
(
t
time
.
Time
)
time
.
Time
{
y
,
m
,
d
:=
t
.
UTC
()
.
Date
()
return
time
.
Date
(
y
,
m
,
d
,
0
,
0
,
0
,
0
,
time
.
UTC
)
...
...
backend/internal/service/ratelimit_service_401_test.go
View file @
0b746501
...
...
@@ -102,6 +102,8 @@ func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *t
})
}
// TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError
// OpenAI OAuth 401 缓存失效出错时仍走 temp_unschedulable
func
TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError
(
t
*
testing
.
T
)
{
repo
:=
&
rateLimitAccountRepoStub
{}
invalidator
:=
&
tokenCacheInvalidatorRecorder
{
err
:
errors
.
New
(
"boom"
)}
...
...
@@ -109,7 +111,7 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin
service
.
SetTokenCacheInvalidator
(
invalidator
)
account
:=
&
Account
{
ID
:
101
,
Platform
:
Platform
Gemini
,
Platform
:
Platform
OpenAI
,
Type
:
AccountTypeOAuth
,
}
...
...
backend/internal/service/setting_service.go
View file @
0b746501
...
...
@@ -99,13 +99,19 @@ type DefaultSubscriptionGroupReader interface {
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
}
// WebSearchManagerBuilder creates a websearch.Manager from config (injected by infra layer).
// proxyURLs maps proxy ID to resolved URL for provider-level proxy support.
type
WebSearchManagerBuilder
func
(
cfg
*
WebSearchEmulationConfig
,
proxyURLs
map
[
int64
]
string
)
// SettingService 系统设置服务
type
SettingService
struct
{
settingRepo
SettingRepository
defaultSubGroupReader
DefaultSubscriptionGroupReader
cfg
*
config
.
Config
onUpdate
func
()
// Callback when settings are updated (for cache invalidation)
version
string
// Application version
settingRepo
SettingRepository
defaultSubGroupReader
DefaultSubscriptionGroupReader
proxyRepo
ProxyRepository
// for resolving websearch provider proxy URLs
cfg
*
config
.
Config
onUpdate
func
()
// Callback when settings are updated (for cache invalidation)
version
string
// Application version
webSearchManagerBuilder
WebSearchManagerBuilder
}
// NewSettingService 创建系统设置服务实例
...
...
@@ -121,6 +127,11 @@ func (s *SettingService) SetDefaultSubscriptionGroupReader(reader DefaultSubscri
s
.
defaultSubGroupReader
=
reader
}
// SetProxyRepository injects a proxy repo for resolving websearch provider proxy URLs.
func
(
s
*
SettingService
)
SetProxyRepository
(
repo
ProxyRepository
)
{
s
.
proxyRepo
=
repo
}
// GetAllSettings 获取所有系统设置
func
(
s
*
SettingService
)
GetAllSettings
(
ctx
context
.
Context
)
(
*
SystemSettings
,
error
)
{
settings
,
err
:=
s
.
settingRepo
.
GetAll
(
ctx
)
...
...
@@ -168,9 +179,13 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyCustomEndpoints
,
SettingKeyLinuxDoConnectEnabled
,
SettingKeyBackendModeEnabled
,
SettingPaymentEnabled
,
SettingKeyOIDCConnectEnabled
,
SettingKeyOIDCConnectProviderName
,
SettingPaymentEnabled
,
SettingKeyBalanceLowNotifyEnabled
,
SettingKeyBalanceLowNotifyThreshold
,
SettingKeyBalanceLowNotifyRechargeURL
,
SettingKeyAccountQuotaNotifyEnabled
,
}
settings
,
err
:=
s
.
settingRepo
.
GetMultiple
(
ctx
,
keys
)
...
...
@@ -209,6 +224,11 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
settings
[
SettingKeyTablePageSizeOptions
],
)
var
balanceLowNotifyThreshold
float64
if
v
,
err
:=
strconv
.
ParseFloat
(
settings
[
SettingKeyBalanceLowNotifyThreshold
],
64
);
err
==
nil
&&
v
>=
0
{
balanceLowNotifyThreshold
=
v
}
return
&
PublicSettings
{
RegistrationEnabled
:
settings
[
SettingKeyRegistrationEnabled
]
==
"true"
,
EmailVerifyEnabled
:
emailVerifyEnabled
,
...
...
@@ -235,9 +255,13 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
CustomEndpoints
:
settings
[
SettingKeyCustomEndpoints
],
LinuxDoOAuthEnabled
:
linuxDoEnabled
,
BackendModeEnabled
:
settings
[
SettingKeyBackendModeEnabled
]
==
"true"
,
PaymentEnabled
:
settings
[
SettingPaymentEnabled
]
==
"true"
,
OIDCOAuthEnabled
:
oidcEnabled
,
OIDCOAuthProviderName
:
oidcProviderName
,
PaymentEnabled
:
settings
[
SettingPaymentEnabled
]
==
"true"
,
BalanceLowNotifyEnabled
:
settings
[
SettingKeyBalanceLowNotifyEnabled
]
==
"true"
,
AccountQuotaNotifyEnabled
:
settings
[
SettingKeyAccountQuotaNotifyEnabled
]
==
"true"
,
BalanceLowNotifyThreshold
:
balanceLowNotifyThreshold
,
BalanceLowNotifyRechargeURL
:
settings
[
SettingKeyBalanceLowNotifyRechargeURL
],
},
nil
}
...
...
@@ -287,10 +311,14 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomEndpoints
json
.
RawMessage
`json:"custom_endpoints"`
LinuxDoOAuthEnabled
bool
`json:"linuxdo_oauth_enabled"`
BackendModeEnabled
bool
`json:"backend_mode_enabled"`
PaymentEnabled
bool
`json:"payment_enabled"`
OIDCOAuthEnabled
bool
`json:"oidc_oauth_enabled"`
OIDCOAuthProviderName
string
`json:"oidc_oauth_provider_name"`
PaymentEnabled
bool
`json:"payment_enabled"`
Version
string
`json:"version,omitempty"`
BalanceLowNotifyEnabled
bool
`json:"balance_low_notify_enabled"`
AccountQuotaNotifyEnabled
bool
`json:"account_quota_notify_enabled"`
BalanceLowNotifyThreshold
float64
`json:"balance_low_notify_threshold"`
BalanceLowNotifyRechargeURL
string
`json:"balance_low_notify_recharge_url"`
}{
RegistrationEnabled
:
settings
.
RegistrationEnabled
,
EmailVerifyEnabled
:
settings
.
EmailVerifyEnabled
,
...
...
@@ -317,10 +345,14 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomEndpoints
:
safeRawJSONArray
(
settings
.
CustomEndpoints
),
LinuxDoOAuthEnabled
:
settings
.
LinuxDoOAuthEnabled
,
BackendModeEnabled
:
settings
.
BackendModeEnabled
,
PaymentEnabled
:
settings
.
PaymentEnabled
,
OIDCOAuthEnabled
:
settings
.
OIDCOAuthEnabled
,
OIDCOAuthProviderName
:
settings
.
OIDCOAuthProviderName
,
PaymentEnabled
:
settings
.
PaymentEnabled
,
Version
:
s
.
version
,
BalanceLowNotifyEnabled
:
settings
.
BalanceLowNotifyEnabled
,
AccountQuotaNotifyEnabled
:
settings
.
AccountQuotaNotifyEnabled
,
BalanceLowNotifyThreshold
:
settings
.
BalanceLowNotifyThreshold
,
BalanceLowNotifyRechargeURL
:
settings
.
BalanceLowNotifyRechargeURL
,
},
nil
}
...
...
@@ -595,6 +627,13 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates
[
SettingKeyEnableMetadataPassthrough
]
=
strconv
.
FormatBool
(
settings
.
EnableMetadataPassthrough
)
updates
[
SettingKeyEnableCCHSigning
]
=
strconv
.
FormatBool
(
settings
.
EnableCCHSigning
)
// Balance low notification
updates
[
SettingKeyBalanceLowNotifyEnabled
]
=
strconv
.
FormatBool
(
settings
.
BalanceLowNotifyEnabled
)
updates
[
SettingKeyBalanceLowNotifyThreshold
]
=
strconv
.
FormatFloat
(
settings
.
BalanceLowNotifyThreshold
,
'f'
,
8
,
64
)
updates
[
SettingKeyBalanceLowNotifyRechargeURL
]
=
settings
.
BalanceLowNotifyRechargeURL
updates
[
SettingKeyAccountQuotaNotifyEnabled
]
=
strconv
.
FormatBool
(
settings
.
AccountQuotaNotifyEnabled
)
updates
[
SettingKeyAccountQuotaNotifyEmails
]
=
MarshalNotifyEmails
(
settings
.
AccountQuotaNotifyEmails
)
err
=
s
.
settingRepo
.
SetMultiple
(
ctx
,
updates
)
if
err
==
nil
{
// 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口
...
...
@@ -1217,6 +1256,30 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result
.
EnableMetadataPassthrough
=
settings
[
SettingKeyEnableMetadataPassthrough
]
==
"true"
result
.
EnableCCHSigning
=
settings
[
SettingKeyEnableCCHSigning
]
==
"true"
// Web search emulation: quick enabled check from the JSON config
if
raw
:=
settings
[
SettingKeyWebSearchEmulationConfig
];
raw
!=
""
{
var
wsCfg
WebSearchEmulationConfig
if
err
:=
json
.
Unmarshal
([]
byte
(
raw
),
&
wsCfg
);
err
==
nil
{
result
.
WebSearchEmulationEnabled
=
wsCfg
.
Enabled
&&
len
(
wsCfg
.
Providers
)
>
0
}
}
// Balance low notification
result
.
BalanceLowNotifyEnabled
=
settings
[
SettingKeyBalanceLowNotifyEnabled
]
==
"true"
if
v
,
err
:=
strconv
.
ParseFloat
(
settings
[
SettingKeyBalanceLowNotifyThreshold
],
64
);
err
==
nil
&&
v
>=
0
{
result
.
BalanceLowNotifyThreshold
=
v
}
result
.
BalanceLowNotifyRechargeURL
=
settings
[
SettingKeyBalanceLowNotifyRechargeURL
]
// Account quota notification
result
.
AccountQuotaNotifyEnabled
=
settings
[
SettingKeyAccountQuotaNotifyEnabled
]
==
"true"
if
raw
:=
strings
.
TrimSpace
(
settings
[
SettingKeyAccountQuotaNotifyEmails
]);
raw
!=
""
{
result
.
AccountQuotaNotifyEmails
=
ParseNotifyEmails
(
raw
)
}
if
result
.
AccountQuotaNotifyEmails
==
nil
{
result
.
AccountQuotaNotifyEmails
=
[]
NotifyEmailEntry
{}
}
return
result
}
...
...
backend/internal/service/setting_service_public_test.go
View file @
0b746501
...
...
@@ -66,7 +66,7 @@ func TestSettingService_GetPublicSettings_ExposesRegistrationEmailSuffixWhitelis
func
TestSettingService_GetPublicSettings_ExposesTablePreferences
(
t
*
testing
.
T
)
{
repo
:=
&
settingPublicRepoStub
{
values
:
map
[
string
]
string
{
SettingKeyTableDefaultPageSize
:
"50"
,
SettingKeyTableDefaultPageSize
:
"50"
,
SettingKeyTablePageSizeOptions
:
"[20,50,100]"
,
},
}
...
...
backend/internal/service/setting_service_update_test.go
View file @
0b746501
...
...
@@ -208,7 +208,7 @@ func TestSettingService_UpdateSettings_TablePreferences(t *testing.T) {
svc
:=
NewSettingService
(
repo
,
&
config
.
Config
{})
err
:=
svc
.
UpdateSettings
(
context
.
Background
(),
&
SystemSettings
{
TableDefaultPageSize
:
50
,
TableDefaultPageSize
:
50
,
TablePageSizeOptions
:
[]
int
{
20
,
50
,
100
},
})
require
.
NoError
(
t
,
err
)
...
...
@@ -216,7 +216,7 @@ func TestSettingService_UpdateSettings_TablePreferences(t *testing.T) {
require
.
Equal
(
t
,
"[20,50,100]"
,
repo
.
updates
[
SettingKeyTablePageSizeOptions
])
err
=
svc
.
UpdateSettings
(
context
.
Background
(),
&
SystemSettings
{
TableDefaultPageSize
:
1000
,
TableDefaultPageSize
:
1000
,
TablePageSizeOptions
:
[]
int
{
20
,
100
},
})
require
.
NoError
(
t
,
err
)
...
...
backend/internal/service/settings_view.go
View file @
0b746501
...
...
@@ -106,6 +106,18 @@ type SystemSettings struct {
EnableFingerprintUnification
bool
// 是否统一 OAuth 账号的指纹头(默认 true)
EnableMetadataPassthrough
bool
// 是否透传客户端原始 metadata(默认 false)
EnableCCHSigning
bool
// 是否对 billing header cch 进行签名(默认 false)
// Web Search Emulation
WebSearchEmulationEnabled
bool
// 是否启用 web search 模拟
// Balance low notification
BalanceLowNotifyEnabled
bool
BalanceLowNotifyThreshold
float64
BalanceLowNotifyRechargeURL
string
// Account quota notification
AccountQuotaNotifyEnabled
bool
AccountQuotaNotifyEmails
[]
NotifyEmailEntry
}
type
DefaultSubscriptionSetting
struct
{
...
...
@@ -141,10 +153,15 @@ type PublicSettings struct {
LinuxDoOAuthEnabled
bool
BackendModeEnabled
bool
PaymentEnabled
bool
OIDCOAuthEnabled
bool
OIDCOAuthProviderName
string
PaymentEnabled
bool
Version
string
BalanceLowNotifyEnabled
bool
AccountQuotaNotifyEnabled
bool
BalanceLowNotifyThreshold
float64
BalanceLowNotifyRechargeURL
string
}
// StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制)
...
...
backend/internal/service/usage_billing.go
View file @
0b746501
...
...
@@ -100,9 +100,22 @@ func valueOrZero(v *int64) int64 {
return
*
v
}
// AccountQuotaState holds the post-increment quota state returned by the DB transaction.
// All values are post-update (i.e., already include the increment).
type
AccountQuotaState
struct
{
TotalUsed
float64
TotalLimit
float64
DailyUsed
float64
DailyLimit
float64
WeeklyUsed
float64
WeeklyLimit
float64
}
type
UsageBillingApplyResult
struct
{
Applied
bool
APIKeyQuotaExhausted
bool
NewBalance
*
float64
// post-deduction balance (nil = no balance deduction)
QuotaState
*
AccountQuotaState
// post-increment quota state (nil = no quota increment)
}
type
UsageBillingRepository
interface
{
...
...
backend/internal/service/usage_log.go
View file @
0b746501
...
...
@@ -146,6 +146,8 @@ type UsageLog struct {
RateMultiplier
float64
// AccountRateMultiplier 账号计费倍率快照(nil 表示历史数据,按 1.0 处理)
AccountRateMultiplier
*
float64
// AccountStatsCost 账号统计定价预计算费用(nil = 使用默认公式 total_cost × account_rate_multiplier)
AccountStatsCost
*
float64
BillingType
int8
RequestType
RequestType
...
...
backend/internal/service/user.go
View file @
0b746501
...
...
@@ -30,6 +30,13 @@ type User struct {
TotpEnabled
bool
// 是否启用 TOTP
TotpEnabledAt
*
time
.
Time
// TOTP 启用时间
// 余额不足通知
BalanceNotifyEnabled
bool
BalanceNotifyThresholdType
string
// "fixed" (default) | "percentage"
BalanceNotifyThreshold
*
float64
BalanceNotifyExtraEmails
[]
NotifyEmailEntry
TotalRecharged
float64
APIKeys
[]
APIKey
Subscriptions
[]
UserSubscription
}
...
...
backend/internal/service/user_service.go
View file @
0b746501
...
...
@@ -2,8 +2,10 @@ package service
import
(
"context"
"crypto/subtle"
"fmt"
"log"
"log/slog"
"strings"
"time"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
...
...
@@ -11,9 +13,18 @@ import (
)
var
(
ErrUserNotFound
=
infraerrors
.
NotFound
(
"USER_NOT_FOUND"
,
"user not found"
)
ErrPasswordIncorrect
=
infraerrors
.
BadRequest
(
"PASSWORD_INCORRECT"
,
"current password is incorrect"
)
ErrInsufficientPerms
=
infraerrors
.
Forbidden
(
"INSUFFICIENT_PERMISSIONS"
,
"insufficient permissions"
)
ErrUserNotFound
=
infraerrors
.
NotFound
(
"USER_NOT_FOUND"
,
"user not found"
)
ErrPasswordIncorrect
=
infraerrors
.
BadRequest
(
"PASSWORD_INCORRECT"
,
"current password is incorrect"
)
ErrInsufficientPerms
=
infraerrors
.
Forbidden
(
"INSUFFICIENT_PERMISSIONS"
,
"insufficient permissions"
)
ErrNotifyCodeUserRateLimit
=
infraerrors
.
TooManyRequests
(
"NOTIFY_CODE_USER_RATE_LIMIT"
,
"too many verification codes requested, please try again later"
)
)
const
(
maxNotifyEmails
=
3
// Maximum number of notification emails per user
// User-level rate limiting for notify email verification codes
notifyCodeUserRateLimit
=
5
notifyCodeUserRateWindow
=
10
*
time
.
Minute
)
// UserListFilters contains all filter options for listing users
...
...
@@ -58,9 +69,11 @@ type UserRepository interface {
// UpdateProfileRequest 更新用户资料请求
type
UpdateProfileRequest
struct
{
Email
*
string
`json:"email"`
Username
*
string
`json:"username"`
Concurrency
*
int
`json:"concurrency"`
Email
*
string
`json:"email"`
Username
*
string
`json:"username"`
Concurrency
*
int
`json:"concurrency"`
BalanceNotifyEnabled
*
bool
`json:"balance_notify_enabled"`
BalanceNotifyThreshold
*
float64
`json:"balance_notify_threshold"`
}
// ChangePasswordRequest 修改密码请求
...
...
@@ -72,14 +85,16 @@ type ChangePasswordRequest struct {
// UserService 用户服务
type
UserService
struct
{
userRepo
UserRepository
settingRepo
SettingRepository
authCacheInvalidator
APIKeyAuthCacheInvalidator
billingCache
BillingCache
}
// NewUserService 创建用户服务实例
func
NewUserService
(
userRepo
UserRepository
,
authCacheInvalidator
APIKeyAuthCacheInvalidator
,
billingCache
BillingCache
)
*
UserService
{
func
NewUserService
(
userRepo
UserRepository
,
settingRepo
SettingRepository
,
authCacheInvalidator
APIKeyAuthCacheInvalidator
,
billingCache
BillingCache
)
*
UserService
{
return
&
UserService
{
userRepo
:
userRepo
,
settingRepo
:
settingRepo
,
authCacheInvalidator
:
authCacheInvalidator
,
billingCache
:
billingCache
,
}
...
...
@@ -132,6 +147,17 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
user
.
Concurrency
=
*
req
.
Concurrency
}
if
req
.
BalanceNotifyEnabled
!=
nil
{
user
.
BalanceNotifyEnabled
=
*
req
.
BalanceNotifyEnabled
}
if
req
.
BalanceNotifyThreshold
!=
nil
{
if
*
req
.
BalanceNotifyThreshold
<=
0
{
user
.
BalanceNotifyThreshold
=
nil
// clear to system default
}
else
{
user
.
BalanceNotifyThreshold
=
req
.
BalanceNotifyThreshold
}
}
if
err
:=
s
.
userRepo
.
Update
(
ctx
,
user
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"update user: %w"
,
err
)
}
...
...
@@ -198,10 +224,15 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl
}
if
s
.
billingCache
!=
nil
{
go
func
()
{
defer
func
()
{
if
r
:=
recover
();
r
!=
nil
{
slog
.
Error
(
"panic in balance cache invalidation"
,
"user_id"
,
userID
,
"recover"
,
r
)
}
}()
cacheCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
5
*
time
.
Second
)
defer
cancel
()
if
err
:=
s
.
billingCache
.
InvalidateUserBalance
(
cacheCtx
,
userID
);
err
!=
nil
{
log
.
Printf
(
"invalidate user balance cache failed
:
user_id
=%d err=%v
"
,
userID
,
err
)
s
log
.
Error
(
"invalidate user balance cache failed
"
,
"
user_id"
,
userID
,
"error"
,
err
)
}
}()
}
...
...
@@ -248,3 +279,229 @@ func (s *UserService) Delete(ctx context.Context, userID int64) error {
}
return
nil
}
// SendNotifyEmailCode sends a verification code to the extra notification email.
func
(
s
*
UserService
)
SendNotifyEmailCode
(
ctx
context
.
Context
,
userID
int64
,
email
string
,
emailService
*
EmailService
,
cache
EmailCache
)
error
{
if
err
:=
checkNotifyCodeRateLimit
(
ctx
,
cache
,
userID
,
email
);
err
!=
nil
{
return
err
}
code
,
err
:=
emailService
.
GenerateVerifyCode
()
if
err
!=
nil
{
return
fmt
.
Errorf
(
"generate code: %w"
,
err
)
}
// Send email first — if SMTP fails, don't write cache or increment counters,
// so the user is not locked out by cooldown/rate-limit for a code they never received.
if
err
:=
s
.
sendNotifyVerifyEmail
(
ctx
,
emailService
,
email
,
code
);
err
!=
nil
{
return
err
}
if
err
:=
saveNotifyVerifyCode
(
ctx
,
cache
,
email
,
code
);
err
!=
nil
{
return
err
}
// Increment user-level counter after successful save
if
_
,
err
:=
cache
.
IncrNotifyCodeUserRate
(
ctx
,
userID
,
notifyCodeUserRateWindow
);
err
!=
nil
{
slog
.
Error
(
"failed to increment notify code user rate"
,
"user_id"
,
userID
,
"error"
,
err
)
}
return
nil
}
// checkNotifyCodeRateLimit checks both email cooldown and user-level rate limit.
func
checkNotifyCodeRateLimit
(
ctx
context
.
Context
,
cache
EmailCache
,
userID
int64
,
email
string
)
error
{
existing
,
err
:=
cache
.
GetNotifyVerifyCode
(
ctx
,
email
)
if
err
==
nil
&&
existing
!=
nil
{
if
time
.
Since
(
existing
.
CreatedAt
)
<
verifyCodeCooldown
{
return
ErrVerifyCodeTooFrequent
}
}
count
,
err
:=
cache
.
GetNotifyCodeUserRate
(
ctx
,
userID
)
if
err
==
nil
&&
count
>=
notifyCodeUserRateLimit
{
return
ErrNotifyCodeUserRateLimit
}
return
nil
}
// saveNotifyVerifyCode saves the verification code to cache.
func
saveNotifyVerifyCode
(
ctx
context
.
Context
,
cache
EmailCache
,
email
,
code
string
)
error
{
data
:=
&
VerificationCodeData
{
Code
:
code
,
Attempts
:
0
,
CreatedAt
:
time
.
Now
(),
ExpiresAt
:
time
.
Now
()
.
Add
(
verifyCodeTTL
),
}
if
err
:=
cache
.
SetNotifyVerifyCode
(
ctx
,
email
,
data
,
verifyCodeTTL
);
err
!=
nil
{
return
fmt
.
Errorf
(
"save verify code: %w"
,
err
)
}
return
nil
}
// sendNotifyVerifyEmail builds and sends the verification email.
func
(
s
*
UserService
)
sendNotifyVerifyEmail
(
ctx
context
.
Context
,
emailService
*
EmailService
,
email
,
code
string
)
error
{
siteName
:=
"Sub2API"
if
s
.
settingRepo
!=
nil
{
if
name
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeySiteName
);
err
==
nil
&&
name
!=
""
{
siteName
=
name
}
}
subject
:=
fmt
.
Sprintf
(
"[%s] 通知邮箱验证码 / Notification Email Verification"
,
siteName
)
body
:=
buildNotifyVerifyEmailBody
(
code
,
siteName
)
return
emailService
.
SendEmail
(
ctx
,
email
,
subject
,
body
)
}
// VerifyAndAddNotifyEmail verifies the code and adds the email to user's extra emails.
func
(
s
*
UserService
)
VerifyAndAddNotifyEmail
(
ctx
context
.
Context
,
userID
int64
,
email
,
code
string
,
cache
EmailCache
)
error
{
if
err
:=
verifyNotifyCode
(
ctx
,
cache
,
email
,
code
);
err
!=
nil
{
return
err
}
_
=
cache
.
DeleteNotifyVerifyCode
(
ctx
,
email
)
return
s
.
addOrVerifyNotifyEmail
(
ctx
,
userID
,
email
)
}
// verifyNotifyCode validates the verification code against the cached data.
func
verifyNotifyCode
(
ctx
context
.
Context
,
cache
EmailCache
,
email
,
code
string
)
error
{
data
,
err
:=
cache
.
GetNotifyVerifyCode
(
ctx
,
email
)
if
err
!=
nil
||
data
==
nil
{
return
ErrInvalidVerifyCode
}
if
data
.
Attempts
>=
maxVerifyCodeAttempts
{
return
ErrVerifyCodeMaxAttempts
}
if
subtle
.
ConstantTimeCompare
([]
byte
(
data
.
Code
),
[]
byte
(
code
))
!=
1
{
data
.
Attempts
++
remaining
:=
time
.
Until
(
data
.
ExpiresAt
)
if
remaining
<=
0
{
return
ErrInvalidVerifyCode
}
if
err
:=
cache
.
SetNotifyVerifyCode
(
ctx
,
email
,
data
,
remaining
);
err
!=
nil
{
slog
.
Error
(
"failed to update notify verify code attempts"
,
"email"
,
email
,
"error"
,
err
)
}
if
data
.
Attempts
>=
maxVerifyCodeAttempts
{
return
ErrVerifyCodeMaxAttempts
}
return
ErrInvalidVerifyCode
}
return
nil
}
// addOrVerifyNotifyEmail adds the email to user's extra notification emails or marks it as verified.
// Note: concurrent calls for the same user could race on the read-modify-write of
// BalanceNotifyExtraEmails. The window is small (requires two verify flows completing
// simultaneously), and the worst case is a duplicate entry which is harmless.
func
(
s
*
UserService
)
addOrVerifyNotifyEmail
(
ctx
context
.
Context
,
userID
int64
,
email
string
)
error
{
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
err
}
for
i
,
e
:=
range
user
.
BalanceNotifyExtraEmails
{
if
strings
.
EqualFold
(
e
.
Email
,
email
)
{
if
!
e
.
Verified
{
user
.
BalanceNotifyExtraEmails
[
i
]
.
Verified
=
true
return
s
.
userRepo
.
Update
(
ctx
,
user
)
}
return
nil
// Already verified
}
}
if
len
(
user
.
BalanceNotifyExtraEmails
)
>=
maxNotifyEmails
{
return
infraerrors
.
BadRequest
(
"TOO_MANY_NOTIFY_EMAILS"
,
fmt
.
Sprintf
(
"maximum %d notification emails allowed"
,
maxNotifyEmails
))
}
user
.
BalanceNotifyExtraEmails
=
append
(
user
.
BalanceNotifyExtraEmails
,
NotifyEmailEntry
{
Email
:
email
,
Disabled
:
false
,
Verified
:
true
,
})
return
s
.
userRepo
.
Update
(
ctx
,
user
)
}
// RemoveNotifyEmail removes an email from user's extra notification emails.
func
(
s
*
UserService
)
RemoveNotifyEmail
(
ctx
context
.
Context
,
userID
int64
,
email
string
)
error
{
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
err
}
filtered
:=
make
([]
NotifyEmailEntry
,
0
,
len
(
user
.
BalanceNotifyExtraEmails
))
found
:=
false
for
_
,
e
:=
range
user
.
BalanceNotifyExtraEmails
{
if
strings
.
EqualFold
(
e
.
Email
,
email
)
{
found
=
true
}
else
{
filtered
=
append
(
filtered
,
e
)
}
}
if
!
found
{
return
infraerrors
.
BadRequest
(
"EMAIL_NOT_FOUND"
,
"notification email not found"
)
}
user
.
BalanceNotifyExtraEmails
=
filtered
return
s
.
userRepo
.
Update
(
ctx
,
user
)
}
// ToggleNotifyEmail toggles the disabled state of a notification email entry.
func
(
s
*
UserService
)
ToggleNotifyEmail
(
ctx
context
.
Context
,
userID
int64
,
email
string
,
disabled
bool
)
error
{
user
,
err
:=
s
.
userRepo
.
GetByID
(
ctx
,
userID
)
if
err
!=
nil
{
return
err
}
found
:=
false
for
i
,
e
:=
range
user
.
BalanceNotifyExtraEmails
{
if
strings
.
EqualFold
(
e
.
Email
,
email
)
{
user
.
BalanceNotifyExtraEmails
[
i
]
.
Disabled
=
disabled
found
=
true
break
}
}
if
!
found
{
return
infraerrors
.
BadRequest
(
"EMAIL_NOT_FOUND"
,
"notification email not found"
)
}
return
s
.
userRepo
.
Update
(
ctx
,
user
)
}
// notifyVerifyEmailTemplate is the HTML template for notify email verification.
// Format args: siteName, code.
const
notifyVerifyEmailTemplate
=
`<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 600px; margin: 0 auto; background-color: #ffffff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
.header { background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 30px; text-align: center; }
.header h1 { margin: 0; font-size: 24px; }
.content { padding: 40px 30px; text-align: center; }
.code { font-size: 36px; font-weight: bold; letter-spacing: 8px; color: #333; background-color: #f8f9fa; padding: 20px 30px; border-radius: 8px; display: inline-block; margin: 20px 0; font-family: monospace; }
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; }
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>%s</h1>
</div>
<div class="content">
<p style="font-size: 18px; color: #333;">通知邮箱验证码 / Notification Email Verification</p>
<div class="code">%s</div>
<div class="info">
<p>您正在添加额外的通知邮箱,请输入此验证码完成验证。</p>
<p>You are adding an extra notification email. Please enter this code to verify.</p>
<p>此验证码将在 <strong>15 分钟</strong>后失效。</p>
<p>This code will expire in <strong>15 minutes</strong>.</p>
<p>如果您没有请求此验证码,请忽略此邮件。</p>
<p>If you did not request this code, please ignore this email.</p>
</div>
</div>
<div class="footer">
<p>此邮件由系统自动发送,请勿回复。/ This is an automated message, please do not reply.</p>
</div>
</div>
</body>
</html>`
// buildNotifyVerifyEmailBody builds the HTML email body for notify email verification.
func
buildNotifyVerifyEmailBody
(
code
,
siteName
string
)
string
{
return
fmt
.
Sprintf
(
notifyVerifyEmailTemplate
,
siteName
,
code
)
}
backend/internal/service/user_service_test.go
View file @
0b746501
...
...
@@ -46,12 +46,12 @@ func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int
return
0
,
nil
}
func
(
m
*
mockUserRepo
)
AddGroupToAllowedGroups
(
context
.
Context
,
int64
,
int64
)
error
{
return
nil
}
func
(
m
*
mockUserRepo
)
UpdateTotpSecret
(
context
.
Context
,
int64
,
*
string
)
error
{
return
nil
}
func
(
m
*
mockUserRepo
)
EnableTotp
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
m
*
mockUserRepo
)
DisableTotp
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
m
*
mockUserRepo
)
RemoveGroupFromUserAllowedGroups
(
context
.
Context
,
int64
,
int64
)
error
{
return
nil
}
func
(
m
*
mockUserRepo
)
UpdateTotpSecret
(
context
.
Context
,
int64
,
*
string
)
error
{
return
nil
}
func
(
m
*
mockUserRepo
)
EnableTotp
(
context
.
Context
,
int64
)
error
{
return
nil
}
func
(
m
*
mockUserRepo
)
DisableTotp
(
context
.
Context
,
int64
)
error
{
return
nil
}
// --- mock: APIKeyAuthCacheInvalidator ---
...
...
@@ -117,7 +117,7 @@ func (m *mockBillingCache) InvalidateAPIKeyRateLimit(context.Context, int64) err
func
TestUpdateBalance_Success
(
t
*
testing
.
T
)
{
repo
:=
&
mockUserRepo
{}
cache
:=
&
mockBillingCache
{}
svc
:=
NewUserService
(
repo
,
nil
,
cache
)
svc
:=
NewUserService
(
repo
,
nil
,
nil
,
cache
)
err
:=
svc
.
UpdateBalance
(
context
.
Background
(),
42
,
100.0
)
require
.
NoError
(
t
,
err
)
...
...
@@ -134,7 +134,7 @@ func TestUpdateBalance_Success(t *testing.T) {
func
TestUpdateBalance_NilBillingCache_NoPanic
(
t
*
testing
.
T
)
{
repo
:=
&
mockUserRepo
{}
svc
:=
NewUserService
(
repo
,
nil
,
nil
)
// billingCache = nil
svc
:=
NewUserService
(
repo
,
nil
,
nil
,
nil
)
// billingCache = nil
err
:=
svc
.
UpdateBalance
(
context
.
Background
(),
1
,
50.0
)
require
.
NoError
(
t
,
err
,
"billingCache 为 nil 时不应 panic"
)
...
...
@@ -143,7 +143,7 @@ func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) {
func
TestUpdateBalance_CacheFailure_DoesNotAffectReturn
(
t
*
testing
.
T
)
{
repo
:=
&
mockUserRepo
{}
cache
:=
&
mockBillingCache
{
invalidateErr
:
errors
.
New
(
"redis connection refused"
)}
svc
:=
NewUserService
(
repo
,
nil
,
cache
)
svc
:=
NewUserService
(
repo
,
nil
,
nil
,
cache
)
err
:=
svc
.
UpdateBalance
(
context
.
Background
(),
99
,
200.0
)
require
.
NoError
(
t
,
err
,
"缓存失效失败不应影响主流程返回值"
)
...
...
@@ -157,7 +157,7 @@ func TestUpdateBalance_CacheFailure_DoesNotAffectReturn(t *testing.T) {
func
TestUpdateBalance_RepoError_ReturnsError
(
t
*
testing
.
T
)
{
repo
:=
&
mockUserRepo
{
updateBalanceErr
:
errors
.
New
(
"database error"
)}
cache
:=
&
mockBillingCache
{}
svc
:=
NewUserService
(
repo
,
nil
,
cache
)
svc
:=
NewUserService
(
repo
,
nil
,
nil
,
cache
)
err
:=
svc
.
UpdateBalance
(
context
.
Background
(),
1
,
100.0
)
require
.
Error
(
t
,
err
,
"repo 失败时应返回错误"
)
...
...
@@ -173,7 +173,7 @@ func TestUpdateBalance_WithAuthCacheInvalidator(t *testing.T) {
repo
:=
&
mockUserRepo
{}
auth
:=
&
mockAuthCacheInvalidator
{}
cache
:=
&
mockBillingCache
{}
svc
:=
NewUserService
(
repo
,
auth
,
cache
)
svc
:=
NewUserService
(
repo
,
nil
,
auth
,
cache
)
err
:=
svc
.
UpdateBalance
(
context
.
Background
(),
77
,
300.0
)
require
.
NoError
(
t
,
err
)
...
...
@@ -194,7 +194,7 @@ func TestNewUserService_FieldsAssignment(t *testing.T) {
auth
:=
&
mockAuthCacheInvalidator
{}
cache
:=
&
mockBillingCache
{}
svc
:=
NewUserService
(
repo
,
auth
,
cache
)
svc
:=
NewUserService
(
repo
,
nil
,
auth
,
cache
)
require
.
NotNil
(
t
,
svc
)
require
.
Equal
(
t
,
repo
,
svc
.
userRepo
)
require
.
Equal
(
t
,
auth
,
svc
.
authCacheInvalidator
)
...
...
backend/internal/service/websearch_config.go
0 → 100644
View file @
0b746501
package
service
import
(
"context"
"encoding/json"
"fmt"
"log/slog"
"sync/atomic"
"time"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
"golang.org/x/sync/singleflight"
)
// WebSearchEmulationConfig holds the global web search emulation configuration.
type
WebSearchEmulationConfig
struct
{
Enabled
bool
`json:"enabled"`
Providers
[]
WebSearchProviderConfig
`json:"providers"`
}
// WebSearchProviderConfig describes a single search provider (Brave or Tavily).
type
WebSearchProviderConfig
struct
{
Type
string
`json:"type"`
// websearch.ProviderTypeBrave | Tavily
APIKey
string
`json:"api_key,omitempty"`
// secret — omitted in API responses
APIKeyConfigured
bool
`json:"api_key_configured"`
// read-only mask
QuotaLimit
*
int64
`json:"quota_limit"`
// nil = unlimited, >0 = limited
SubscribedAt
*
int64
`json:"subscribed_at,omitempty"`
// subscription start (unix seconds); quota resets monthly
QuotaUsed
int64
`json:"quota_used,omitempty"`
// read-only: current usage from Redis
ProxyID
*
int64
`json:"proxy_id"`
// optional proxy association
ExpiresAt
*
int64
`json:"expires_at,omitempty"`
// optional expiration timestamp
}
// --- Validation ---
const
maxWebSearchProviders
=
10
var
validProviderTypes
=
map
[
string
]
bool
{
websearch
.
ProviderTypeBrave
:
true
,
websearch
.
ProviderTypeTavily
:
true
,
}
func
validateWebSearchConfig
(
cfg
*
WebSearchEmulationConfig
)
error
{
if
cfg
==
nil
{
return
nil
}
if
len
(
cfg
.
Providers
)
>
maxWebSearchProviders
{
return
fmt
.
Errorf
(
"too many providers (max %d)"
,
maxWebSearchProviders
)
}
seen
:=
make
(
map
[
string
]
bool
,
len
(
cfg
.
Providers
))
for
i
,
p
:=
range
cfg
.
Providers
{
if
!
validProviderTypes
[
p
.
Type
]
{
return
fmt
.
Errorf
(
"provider[%d]: invalid type %q"
,
i
,
p
.
Type
)
}
if
p
.
QuotaLimit
!=
nil
&&
*
p
.
QuotaLimit
<
0
{
return
fmt
.
Errorf
(
"provider[%d]: quota_limit must be > 0 or null"
,
i
)
}
if
seen
[
p
.
Type
]
{
return
fmt
.
Errorf
(
"provider[%d]: duplicate type %q"
,
i
,
p
.
Type
)
}
seen
[
p
.
Type
]
=
true
}
return
nil
}
// --- In-process cache (same pattern as gateway forwarding settings) ---
const
sfKeyWebSearchConfig
=
"web_search_emulation_config"
type
cachedWebSearchEmulationConfig
struct
{
config
*
WebSearchEmulationConfig
expiresAt
int64
// unix nano
}
var
webSearchEmulationCache
atomic
.
Value
// *cachedWebSearchEmulationConfig
var
webSearchEmulationSF
singleflight
.
Group
const
(
webSearchEmulationCacheTTL
=
60
*
time
.
Second
webSearchEmulationErrorTTL
=
5
*
time
.
Second
webSearchEmulationDBTimeout
=
5
*
time
.
Second
)
// GetWebSearchEmulationConfig returns the configuration with in-process cache + singleflight.
func
(
s
*
SettingService
)
GetWebSearchEmulationConfig
(
ctx
context
.
Context
)
(
*
WebSearchEmulationConfig
,
error
)
{
if
cached
:=
webSearchEmulationCache
.
Load
();
cached
!=
nil
{
if
c
,
ok
:=
cached
.
(
*
cachedWebSearchEmulationConfig
);
ok
&&
time
.
Now
()
.
UnixNano
()
<
c
.
expiresAt
{
return
c
.
config
,
nil
}
}
result
,
err
,
_
:=
webSearchEmulationSF
.
Do
(
sfKeyWebSearchConfig
,
func
()
(
any
,
error
)
{
return
s
.
loadWebSearchConfigFromDB
()
})
if
err
!=
nil
{
return
&
WebSearchEmulationConfig
{},
err
}
if
cfg
,
ok
:=
result
.
(
*
WebSearchEmulationConfig
);
ok
{
return
cfg
,
nil
}
return
&
WebSearchEmulationConfig
{},
nil
}
func
(
s
*
SettingService
)
loadWebSearchConfigFromDB
()
(
*
WebSearchEmulationConfig
,
error
)
{
dbCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
webSearchEmulationDBTimeout
)
defer
cancel
()
raw
,
err
:=
s
.
settingRepo
.
GetValue
(
dbCtx
,
SettingKeyWebSearchEmulationConfig
)
if
err
!=
nil
{
webSearchEmulationCache
.
Store
(
&
cachedWebSearchEmulationConfig
{
config
:
&
WebSearchEmulationConfig
{},
expiresAt
:
time
.
Now
()
.
Add
(
webSearchEmulationErrorTTL
)
.
UnixNano
(),
})
return
&
WebSearchEmulationConfig
{},
err
}
cfg
:=
parseWebSearchConfigJSON
(
raw
)
webSearchEmulationCache
.
Store
(
&
cachedWebSearchEmulationConfig
{
config
:
cfg
,
expiresAt
:
time
.
Now
()
.
Add
(
webSearchEmulationCacheTTL
)
.
UnixNano
(),
})
return
cfg
,
nil
}
func
parseWebSearchConfigJSON
(
raw
string
)
*
WebSearchEmulationConfig
{
cfg
:=
&
WebSearchEmulationConfig
{}
if
raw
==
""
{
return
cfg
}
if
err
:=
json
.
Unmarshal
([]
byte
(
raw
),
cfg
);
err
!=
nil
{
slog
.
Warn
(
"websearch: failed to parse config JSON"
,
"error"
,
err
)
return
&
WebSearchEmulationConfig
{}
}
return
cfg
}
// SaveWebSearchEmulationConfig validates and persists the configuration.
// Empty API keys in the input are preserved from the existing config.
func
(
s
*
SettingService
)
SaveWebSearchEmulationConfig
(
ctx
context
.
Context
,
cfg
*
WebSearchEmulationConfig
)
error
{
if
err
:=
validateWebSearchConfig
(
cfg
);
err
!=
nil
{
return
infraerrors
.
BadRequest
(
"INVALID_WEB_SEARCH_CONFIG"
,
err
.
Error
())
}
s
.
mergeExistingAPIKeys
(
ctx
,
cfg
)
// After merge, validate all enabled providers have API keys
if
cfg
.
Enabled
{
for
_
,
p
:=
range
cfg
.
Providers
{
if
p
.
APIKey
==
""
{
return
infraerrors
.
BadRequest
(
"MISSING_API_KEY"
,
fmt
.
Sprintf
(
"provider %s has no API key configured"
,
p
.
Type
))
}
}
}
data
,
err
:=
json
.
Marshal
(
cfg
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"websearch: marshal config: %w"
,
err
)
}
if
err
:=
s
.
settingRepo
.
Set
(
ctx
,
SettingKeyWebSearchEmulationConfig
,
string
(
data
));
err
!=
nil
{
return
fmt
.
Errorf
(
"websearch: save config: %w"
,
err
)
}
// Invalidate: forget singleflight first, then store new value
webSearchEmulationSF
.
Forget
(
sfKeyWebSearchConfig
)
webSearchEmulationCache
.
Store
(
&
cachedWebSearchEmulationConfig
{
config
:
cfg
,
expiresAt
:
time
.
Now
()
.
Add
(
webSearchEmulationCacheTTL
)
.
UnixNano
(),
})
// Hot-reload: rebuild the global Manager with new config
s
.
rebuildWebSearchManager
(
ctx
)
return
nil
}
// mergeExistingAPIKeys preserves API keys from the current config when incoming value is empty.
func
(
s
*
SettingService
)
mergeExistingAPIKeys
(
ctx
context
.
Context
,
cfg
*
WebSearchEmulationConfig
)
{
existing
,
_
:=
s
.
getWebSearchEmulationConfigRaw
(
ctx
)
if
existing
==
nil
||
cfg
==
nil
{
return
}
existingByType
:=
make
(
map
[
string
]
string
,
len
(
existing
.
Providers
))
for
_
,
p
:=
range
existing
.
Providers
{
if
p
.
APIKey
!=
""
{
existingByType
[
p
.
Type
]
=
p
.
APIKey
}
}
for
i
:=
range
cfg
.
Providers
{
if
cfg
.
Providers
[
i
]
.
APIKey
==
""
{
if
key
,
ok
:=
existingByType
[
cfg
.
Providers
[
i
]
.
Type
];
ok
{
cfg
.
Providers
[
i
]
.
APIKey
=
key
}
}
}
}
func
(
s
*
SettingService
)
getWebSearchEmulationConfigRaw
(
ctx
context
.
Context
)
(
*
WebSearchEmulationConfig
,
error
)
{
raw
,
err
:=
s
.
settingRepo
.
GetValue
(
ctx
,
SettingKeyWebSearchEmulationConfig
)
if
err
!=
nil
{
return
nil
,
err
}
return
parseWebSearchConfigJSON
(
raw
),
nil
}
// IsWebSearchEmulationEnabled is a quick check for whether the global switch is on.
func
(
s
*
SettingService
)
IsWebSearchEmulationEnabled
(
ctx
context
.
Context
)
bool
{
cfg
,
err
:=
s
.
GetWebSearchEmulationConfig
(
ctx
)
if
err
!=
nil
{
return
false
}
return
cfg
.
Enabled
&&
len
(
cfg
.
Providers
)
>
0
}
// SetWebSearchManagerBuilder injects a callback that creates and wires a websearch.Manager.
// The infra layer (main/wire) provides this builder, keeping redis out of the service layer.
// Triggers initial build.
func
(
s
*
SettingService
)
SetWebSearchManagerBuilder
(
ctx
context
.
Context
,
builder
WebSearchManagerBuilder
)
{
s
.
webSearchManagerBuilder
=
builder
s
.
rebuildWebSearchManager
(
ctx
)
}
// rebuildWebSearchManager reads the current config, resolves proxy URLs, and invokes the builder.
func
(
s
*
SettingService
)
rebuildWebSearchManager
(
ctx
context
.
Context
)
{
if
s
.
webSearchManagerBuilder
==
nil
{
return
}
cfg
,
err
:=
s
.
GetWebSearchEmulationConfig
(
ctx
)
if
err
!=
nil
{
SetWebSearchManager
(
nil
)
return
}
proxyURLs
:=
s
.
resolveProviderProxyURLs
(
ctx
,
cfg
)
s
.
webSearchManagerBuilder
(
cfg
,
proxyURLs
)
}
// resolveProviderProxyURLs collects proxy IDs from providers and resolves them to URLs.
func
(
s
*
SettingService
)
resolveProviderProxyURLs
(
ctx
context
.
Context
,
cfg
*
WebSearchEmulationConfig
)
map
[
int64
]
string
{
if
cfg
==
nil
||
s
.
proxyRepo
==
nil
{
return
nil
}
var
ids
[]
int64
for
_
,
p
:=
range
cfg
.
Providers
{
if
p
.
ProxyID
!=
nil
&&
*
p
.
ProxyID
>
0
{
ids
=
append
(
ids
,
*
p
.
ProxyID
)
}
}
if
len
(
ids
)
==
0
{
return
nil
}
proxies
,
err
:=
s
.
proxyRepo
.
ListByIDs
(
ctx
,
ids
)
if
err
!=
nil
{
slog
.
Warn
(
"websearch: failed to resolve proxy URLs"
,
"error"
,
err
)
return
nil
}
result
:=
make
(
map
[
int64
]
string
,
len
(
proxies
))
for
_
,
px
:=
range
proxies
{
result
[
px
.
ID
]
=
px
.
URL
()
}
return
result
}
// WebSearchTestResult holds the result of a search test.
type
WebSearchTestResult
struct
{
Provider
string
`json:"provider"`
Results
[]
websearch
.
SearchResult
`json:"results"`
Query
string
`json:"query"`
}
// TestWebSearch executes a test search using the currently configured Manager.
// Uses Manager.TestSearch which bypasses quota tracking.
const
testSearchTimeout
=
15
*
time
.
Second
func
TestWebSearch
(
ctx
context
.
Context
,
query
string
)
(
*
WebSearchTestResult
,
error
)
{
mgr
:=
getWebSearchManager
()
if
mgr
==
nil
{
return
nil
,
fmt
.
Errorf
(
"web search: manager not initialized, save config first"
)
}
testCtx
,
cancel
:=
context
.
WithTimeout
(
ctx
,
testSearchTimeout
)
defer
cancel
()
resp
,
providerName
,
err
:=
mgr
.
TestSearch
(
testCtx
,
websearch
.
SearchRequest
{
Query
:
query
,
MaxResults
:
webSearchDefaultMaxResults
,
})
if
err
!=
nil
{
return
nil
,
err
}
return
&
WebSearchTestResult
{
Provider
:
providerName
,
Results
:
resp
.
Results
,
Query
:
resp
.
Query
,
},
nil
}
// PopulateWebSearchUsage returns a copy with quota usage populated from Redis (api_key kept as-is).
func
PopulateWebSearchUsage
(
ctx
context
.
Context
,
cfg
*
WebSearchEmulationConfig
)
*
WebSearchEmulationConfig
{
if
cfg
==
nil
{
return
nil
}
out
:=
*
cfg
out
.
Providers
=
make
([]
WebSearchProviderConfig
,
len
(
cfg
.
Providers
))
mgr
:=
getWebSearchManager
()
for
i
,
p
:=
range
cfg
.
Providers
{
out
.
Providers
[
i
]
=
p
out
.
Providers
[
i
]
.
APIKeyConfigured
=
p
.
APIKey
!=
""
if
mgr
!=
nil
{
used
,
_
:=
mgr
.
GetUsage
(
ctx
,
p
.
Type
)
out
.
Providers
[
i
]
.
QuotaUsed
=
used
}
}
return
&
out
}
// ResetWebSearchUsage deletes the Redis quota key for the given provider type.
func
ResetWebSearchUsage
(
ctx
context
.
Context
,
providerType
string
)
error
{
mgr
:=
getWebSearchManager
()
if
mgr
==
nil
{
return
fmt
.
Errorf
(
"web search manager not initialized"
)
}
return
mgr
.
ResetUsage
(
ctx
,
providerType
)
}
// SanitizeWebSearchConfig returns a copy with api_key fields masked and quota usage populated.
func
SanitizeWebSearchConfig
(
ctx
context
.
Context
,
cfg
*
WebSearchEmulationConfig
)
*
WebSearchEmulationConfig
{
if
cfg
==
nil
{
return
nil
}
out
:=
*
cfg
out
.
Providers
=
make
([]
WebSearchProviderConfig
,
len
(
cfg
.
Providers
))
// Load usage from the global Manager (reads from Redis)
mgr
:=
getWebSearchManager
()
for
i
,
p
:=
range
cfg
.
Providers
{
out
.
Providers
[
i
]
=
p
out
.
Providers
[
i
]
.
APIKeyConfigured
=
p
.
APIKey
!=
""
out
.
Providers
[
i
]
.
APIKey
=
""
// never return the secret
// Populate quota usage from Redis
if
mgr
!=
nil
{
used
,
_
:=
mgr
.
GetUsage
(
ctx
,
p
.
Type
)
out
.
Providers
[
i
]
.
QuotaUsed
=
used
}
}
return
&
out
}
backend/internal/service/websearch_config_test.go
0 → 100644
View file @
0b746501
//go:build unit
package
service
import
(
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
"github.com/stretchr/testify/require"
)
// --- validateWebSearchConfig ---
func
TestValidateWebSearchConfig_Nil
(
t
*
testing
.
T
)
{
require
.
NoError
(
t
,
validateWebSearchConfig
(
nil
))
}
func
TestValidateWebSearchConfig_Valid
(
t
*
testing
.
T
)
{
cfg
:=
&
WebSearchEmulationConfig
{
Enabled
:
true
,
Providers
:
[]
WebSearchProviderConfig
{
{
Type
:
"brave"
,
QuotaLimit
:
int64Ptr
(
1000
)},
{
Type
:
"tavily"
,
QuotaLimit
:
int64Ptr
(
500
)},
},
}
require
.
NoError
(
t
,
validateWebSearchConfig
(
cfg
))
}
func
TestValidateWebSearchConfig_TooManyProviders
(
t
*
testing
.
T
)
{
cfg
:=
&
WebSearchEmulationConfig
{
Providers
:
make
([]
WebSearchProviderConfig
,
11
)}
for
i
:=
range
cfg
.
Providers
{
cfg
.
Providers
[
i
]
=
WebSearchProviderConfig
{
Type
:
"brave"
}
}
err
:=
validateWebSearchConfig
(
cfg
)
require
.
ErrorContains
(
t
,
err
,
"too many providers"
)
}
func
TestValidateWebSearchConfig_InvalidType
(
t
*
testing
.
T
)
{
cfg
:=
&
WebSearchEmulationConfig
{
Providers
:
[]
WebSearchProviderConfig
{{
Type
:
"bing"
}},
}
require
.
ErrorContains
(
t
,
validateWebSearchConfig
(
cfg
),
"invalid type"
)
}
func
TestValidateWebSearchConfig_NegativeQuotaLimit
(
t
*
testing
.
T
)
{
cfg
:=
&
WebSearchEmulationConfig
{
Providers
:
[]
WebSearchProviderConfig
{{
Type
:
"brave"
,
QuotaLimit
:
int64Ptr
(
-
1
)}},
}
require
.
ErrorContains
(
t
,
validateWebSearchConfig
(
cfg
),
"quota_limit must be > 0 or null"
)
}
func
TestValidateWebSearchConfig_DuplicateType
(
t
*
testing
.
T
)
{
cfg
:=
&
WebSearchEmulationConfig
{
Providers
:
[]
WebSearchProviderConfig
{
{
Type
:
"brave"
},
{
Type
:
"brave"
},
},
}
require
.
ErrorContains
(
t
,
validateWebSearchConfig
(
cfg
),
"duplicate type"
)
}
func
TestValidateWebSearchConfig_NilQuotaLimit
(
t
*
testing
.
T
)
{
cfg
:=
&
WebSearchEmulationConfig
{
Providers
:
[]
WebSearchProviderConfig
{{
Type
:
"brave"
,
QuotaLimit
:
nil
}},
}
require
.
NoError
(
t
,
validateWebSearchConfig
(
cfg
))
}
// --- parseWebSearchConfigJSON ---
func
TestParseWebSearchConfigJSON_ValidJSON
(
t
*
testing
.
T
)
{
raw
:=
`{"enabled":true,"providers":[{"type":"brave","api_key":"sk-xxx"}]}`
cfg
:=
parseWebSearchConfigJSON
(
raw
)
require
.
True
(
t
,
cfg
.
Enabled
)
require
.
Len
(
t
,
cfg
.
Providers
,
1
)
require
.
Equal
(
t
,
"brave"
,
cfg
.
Providers
[
0
]
.
Type
)
}
func
TestParseWebSearchConfigJSON_EmptyString
(
t
*
testing
.
T
)
{
cfg
:=
parseWebSearchConfigJSON
(
""
)
require
.
False
(
t
,
cfg
.
Enabled
)
require
.
Empty
(
t
,
cfg
.
Providers
)
}
func
TestParseWebSearchConfigJSON_InvalidJSON
(
t
*
testing
.
T
)
{
cfg
:=
parseWebSearchConfigJSON
(
"not{json"
)
require
.
False
(
t
,
cfg
.
Enabled
)
require
.
Empty
(
t
,
cfg
.
Providers
)
}
func
TestParseWebSearchConfigJSON_BackwardCompatibility
(
t
*
testing
.
T
)
{
// Old config with priority and quota_refresh_interval should parse without error
raw
:=
`{"enabled":true,"providers":[{"type":"brave","priority":1,"quota_refresh_interval":"monthly","quota_limit":1000}]}`
cfg
:=
parseWebSearchConfigJSON
(
raw
)
require
.
True
(
t
,
cfg
.
Enabled
)
require
.
Len
(
t
,
cfg
.
Providers
,
1
)
require
.
Equal
(
t
,
int64
(
1000
),
*
cfg
.
Providers
[
0
]
.
QuotaLimit
)
}
// --- SanitizeWebSearchConfig ---
func
TestSanitizeWebSearchConfig_MaskAPIKey
(
t
*
testing
.
T
)
{
cfg
:=
&
WebSearchEmulationConfig
{
Enabled
:
true
,
Providers
:
[]
WebSearchProviderConfig
{
{
Type
:
"brave"
,
APIKey
:
"sk-secret-xxx"
},
},
}
out
:=
SanitizeWebSearchConfig
(
context
.
Background
(),
cfg
)
require
.
Equal
(
t
,
""
,
out
.
Providers
[
0
]
.
APIKey
)
require
.
True
(
t
,
out
.
Providers
[
0
]
.
APIKeyConfigured
)
}
func
TestSanitizeWebSearchConfig_NoAPIKey
(
t
*
testing
.
T
)
{
cfg
:=
&
WebSearchEmulationConfig
{
Providers
:
[]
WebSearchProviderConfig
{{
Type
:
"brave"
,
APIKey
:
""
}},
}
out
:=
SanitizeWebSearchConfig
(
context
.
Background
(),
cfg
)
require
.
Equal
(
t
,
""
,
out
.
Providers
[
0
]
.
APIKey
)
require
.
False
(
t
,
out
.
Providers
[
0
]
.
APIKeyConfigured
)
}
func
TestSanitizeWebSearchConfig_Nil
(
t
*
testing
.
T
)
{
require
.
Nil
(
t
,
SanitizeWebSearchConfig
(
context
.
Background
(),
nil
))
}
func
TestSanitizeWebSearchConfig_PreservesOtherFields
(
t
*
testing
.
T
)
{
cfg
:=
&
WebSearchEmulationConfig
{
Enabled
:
true
,
Providers
:
[]
WebSearchProviderConfig
{
{
Type
:
"brave"
,
APIKey
:
"secret"
,
QuotaLimit
:
int64Ptr
(
1000
)},
},
}
out
:=
SanitizeWebSearchConfig
(
context
.
Background
(),
cfg
)
require
.
True
(
t
,
out
.
Enabled
)
require
.
Equal
(
t
,
int64
(
1000
),
*
out
.
Providers
[
0
]
.
QuotaLimit
)
}
func
TestSanitizeWebSearchConfig_DoesNotMutateOriginal
(
t
*
testing
.
T
)
{
cfg
:=
&
WebSearchEmulationConfig
{
Providers
:
[]
WebSearchProviderConfig
{{
Type
:
"brave"
,
APIKey
:
"secret"
}},
}
_
=
SanitizeWebSearchConfig
(
context
.
Background
(),
cfg
)
require
.
Equal
(
t
,
"secret"
,
cfg
.
Providers
[
0
]
.
APIKey
)
}
// --- PopulateWebSearchUsage ---
func
TestPopulateWebSearchUsage_NilInput
(
t
*
testing
.
T
)
{
require
.
Nil
(
t
,
PopulateWebSearchUsage
(
context
.
Background
(),
nil
))
}
func
TestPopulateWebSearchUsage_NoManager_QuotaUsedZero
(
t
*
testing
.
T
)
{
// Ensure no global manager is set
SetWebSearchManager
(
nil
)
defer
SetWebSearchManager
(
nil
)
cfg
:=
&
WebSearchEmulationConfig
{
Enabled
:
true
,
Providers
:
[]
WebSearchProviderConfig
{
{
Type
:
"brave"
,
APIKey
:
"sk-key"
,
QuotaLimit
:
int64Ptr
(
1000
)},
},
}
out
:=
PopulateWebSearchUsage
(
context
.
Background
(),
cfg
)
require
.
NotNil
(
t
,
out
)
require
.
Len
(
t
,
out
.
Providers
,
1
)
require
.
Equal
(
t
,
int64
(
0
),
out
.
Providers
[
0
]
.
QuotaUsed
)
}
func
TestPopulateWebSearchUsage_APIKeyConfigured_True
(
t
*
testing
.
T
)
{
SetWebSearchManager
(
nil
)
defer
SetWebSearchManager
(
nil
)
cfg
:=
&
WebSearchEmulationConfig
{
Providers
:
[]
WebSearchProviderConfig
{
{
Type
:
"brave"
,
APIKey
:
"sk-key"
},
},
}
out
:=
PopulateWebSearchUsage
(
context
.
Background
(),
cfg
)
require
.
True
(
t
,
out
.
Providers
[
0
]
.
APIKeyConfigured
)
}
func
TestPopulateWebSearchUsage_APIKeyConfigured_False
(
t
*
testing
.
T
)
{
SetWebSearchManager
(
nil
)
defer
SetWebSearchManager
(
nil
)
cfg
:=
&
WebSearchEmulationConfig
{
Providers
:
[]
WebSearchProviderConfig
{
{
Type
:
"brave"
,
APIKey
:
""
},
},
}
out
:=
PopulateWebSearchUsage
(
context
.
Background
(),
cfg
)
require
.
False
(
t
,
out
.
Providers
[
0
]
.
APIKeyConfigured
)
}
func
TestPopulateWebSearchUsage_NilQuotaLimit
(
t
*
testing
.
T
)
{
SetWebSearchManager
(
nil
)
defer
SetWebSearchManager
(
nil
)
cfg
:=
&
WebSearchEmulationConfig
{
Providers
:
[]
WebSearchProviderConfig
{
{
Type
:
"brave"
,
APIKey
:
"sk-key"
,
QuotaLimit
:
nil
},
},
}
out
:=
PopulateWebSearchUsage
(
context
.
Background
(),
cfg
)
require
.
Nil
(
t
,
out
.
Providers
[
0
]
.
QuotaLimit
)
}
func
TestPopulateWebSearchUsage_NonNilQuotaLimit
(
t
*
testing
.
T
)
{
SetWebSearchManager
(
nil
)
defer
SetWebSearchManager
(
nil
)
cfg
:=
&
WebSearchEmulationConfig
{
Providers
:
[]
WebSearchProviderConfig
{
{
Type
:
"brave"
,
APIKey
:
"sk-key"
,
QuotaLimit
:
int64Ptr
(
500
)},
},
}
out
:=
PopulateWebSearchUsage
(
context
.
Background
(),
cfg
)
require
.
NotNil
(
t
,
out
.
Providers
[
0
]
.
QuotaLimit
)
require
.
Equal
(
t
,
int64
(
500
),
*
out
.
Providers
[
0
]
.
QuotaLimit
)
}
func
TestPopulateWebSearchUsage_WithManager_NilRedis
(
t
*
testing
.
T
)
{
// Manager with nil Redis returns 0 usage without error
mgr
:=
websearch
.
NewManager
([]
websearch
.
ProviderConfig
{
{
Type
:
"brave"
,
APIKey
:
"k"
},
},
nil
)
SetWebSearchManager
(
mgr
)
defer
SetWebSearchManager
(
nil
)
cfg
:=
&
WebSearchEmulationConfig
{
Providers
:
[]
WebSearchProviderConfig
{
{
Type
:
"brave"
,
APIKey
:
"sk-key"
,
QuotaLimit
:
int64Ptr
(
1000
)},
},
}
out
:=
PopulateWebSearchUsage
(
context
.
Background
(),
cfg
)
require
.
Equal
(
t
,
int64
(
0
),
out
.
Providers
[
0
]
.
QuotaUsed
)
require
.
True
(
t
,
out
.
Providers
[
0
]
.
APIKeyConfigured
)
}
func
TestPopulateWebSearchUsage_DoesNotMutateOriginal
(
t
*
testing
.
T
)
{
SetWebSearchManager
(
nil
)
defer
SetWebSearchManager
(
nil
)
cfg
:=
&
WebSearchEmulationConfig
{
Providers
:
[]
WebSearchProviderConfig
{
{
Type
:
"brave"
,
APIKey
:
"secret"
,
QuotaLimit
:
int64Ptr
(
100
)},
},
}
_
=
PopulateWebSearchUsage
(
context
.
Background
(),
cfg
)
// Original should be unchanged
require
.
Equal
(
t
,
"secret"
,
cfg
.
Providers
[
0
]
.
APIKey
)
require
.
Equal
(
t
,
int64
(
0
),
cfg
.
Providers
[
0
]
.
QuotaUsed
)
}
// --- ResetWebSearchUsage ---
func
TestResetWebSearchUsage_NilManager
(
t
*
testing
.
T
)
{
SetWebSearchManager
(
nil
)
defer
SetWebSearchManager
(
nil
)
err
:=
ResetWebSearchUsage
(
context
.
Background
(),
"brave"
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"not initialized"
)
}
backend/internal/service/wire.go
View file @
0b746501
...
...
@@ -373,10 +373,11 @@ func ProvideBackupService(
return
svc
}
// ProvideSettingService wires SettingService with group reader
for default subscription validation
.
func
ProvideSettingService
(
settingRepo
SettingRepository
,
groupRepo
GroupRepository
,
cfg
*
config
.
Config
)
*
SettingService
{
// ProvideSettingService wires SettingService with group reader
and proxy repo
.
func
ProvideSettingService
(
settingRepo
SettingRepository
,
groupRepo
GroupRepository
,
proxyRepo
ProxyRepository
,
cfg
*
config
.
Config
)
*
SettingService
{
svc
:=
NewSettingService
(
settingRepo
,
cfg
)
svc
.
SetDefaultSubscriptionGroupReader
(
groupRepo
)
svc
.
SetProxyRepository
(
proxyRepo
)
return
svc
}
...
...
@@ -465,6 +466,7 @@ var ProviderSet = wire.NewSet(
ProvidePaymentConfigService
,
NewPaymentService
,
ProvidePaymentOrderExpiryService
,
ProvideBalanceNotifyService
,
)
// ProvidePaymentConfigService wraps NewPaymentConfigService to accept the named
...
...
@@ -473,6 +475,11 @@ func ProvidePaymentConfigService(entClient *dbent.Client, settingRepo SettingRep
return
NewPaymentConfigService
(
entClient
,
settingRepo
,
[]
byte
(
key
))
}
// ProvideBalanceNotifyService creates BalanceNotifyService
func
ProvideBalanceNotifyService
(
emailService
*
EmailService
,
settingRepo
SettingRepository
,
accountRepo
AccountRepository
)
*
BalanceNotifyService
{
return
NewBalanceNotifyService
(
emailService
,
settingRepo
,
accountRepo
)
}
// ProvidePaymentOrderExpiryService creates and starts PaymentOrderExpiryService.
func
ProvidePaymentOrderExpiryService
(
paymentSvc
*
PaymentService
)
*
PaymentOrderExpiryService
{
svc
:=
NewPaymentOrderExpiryService
(
paymentSvc
,
60
*
time
.
Second
)
...
...
backend/internal/web/embed_on.go
View file @
0b746501
...
...
@@ -10,6 +10,8 @@ import (
"io"
"io/fs"
"net/http"
"os"
"path/filepath"
"strings"
"time"
...
...
@@ -32,11 +34,12 @@ type PublicSettingsProvider interface {
// FrontendServer serves the embedded frontend with settings injection
type
FrontendServer
struct
{
distFS
fs
.
FS
fileServer
http
.
Handler
baseHTML
[]
byte
cache
*
HTMLCache
settings
PublicSettingsProvider
distFS
fs
.
FS
fileServer
http
.
Handler
baseHTML
[]
byte
cache
*
HTMLCache
settings
PublicSettingsProvider
overrideDir
string
// local file override directory
}
// NewFrontendServer creates a new frontend server with settings injection
...
...
@@ -62,11 +65,12 @@ func NewFrontendServer(settingsProvider PublicSettingsProvider) (*FrontendServer
cache
.
SetBaseHTML
(
baseHTML
)
return
&
FrontendServer
{
distFS
:
distFS
,
fileServer
:
http
.
FileServer
(
http
.
FS
(
distFS
)),
baseHTML
:
baseHTML
,
cache
:
cache
,
settings
:
settingsProvider
,
distFS
:
distFS
,
fileServer
:
http
.
FileServer
(
http
.
FS
(
distFS
)),
baseHTML
:
baseHTML
,
cache
:
cache
,
settings
:
settingsProvider
,
overrideDir
:
filepath
.
Join
(
"data"
,
"public"
),
},
nil
}
...
...
@@ -99,6 +103,11 @@ func (s *FrontendServer) Middleware() gin.HandlerFunc {
return
}
// Try local override first
if
s
.
tryServeOverride
(
c
,
cleanPath
)
{
return
}
// Serve static files normally
s
.
fileServer
.
ServeHTTP
(
c
.
Writer
,
c
.
Request
)
c
.
Abort
()
...
...
@@ -114,6 +123,22 @@ func (s *FrontendServer) fileExists(path string) bool {
return
true
}
// tryServeOverride checks if a local override file exists and serves it.
// Files in overrideDir take precedence over embedded files.
func
(
s
*
FrontendServer
)
tryServeOverride
(
c
*
gin
.
Context
,
cleanPath
string
)
bool
{
if
s
.
overrideDir
==
""
{
return
false
}
filePath
:=
filepath
.
Join
(
s
.
overrideDir
,
filepath
.
Clean
(
"/"
+
cleanPath
))
info
,
err
:=
os
.
Stat
(
filePath
)
if
err
!=
nil
||
info
.
IsDir
()
{
return
false
}
c
.
File
(
filePath
)
c
.
Abort
()
return
true
}
func
(
s
*
FrontendServer
)
serveIndexHTML
(
c
*
gin
.
Context
)
{
// Get nonce from context (generated by SecurityHeaders middleware)
nonce
:=
middleware
.
GetNonceFromContext
(
c
)
...
...
@@ -226,6 +251,7 @@ func ServeEmbeddedFrontend() gin.HandlerFunc {
panic
(
"failed to get dist subdirectory: "
+
err
.
Error
())
}
fileServer
:=
http
.
FileServer
(
http
.
FS
(
distFS
))
overrideDir
:=
filepath
.
Join
(
"data"
,
"public"
)
return
func
(
c
*
gin
.
Context
)
{
path
:=
c
.
Request
.
URL
.
Path
...
...
@@ -242,6 +268,10 @@ func ServeEmbeddedFrontend() gin.HandlerFunc {
if
file
,
err
:=
distFS
.
Open
(
cleanPath
);
err
==
nil
{
_
=
file
.
Close
()
// Try local override first
if
tryServeOverrideFile
(
c
,
overrideDir
,
cleanPath
)
{
return
}
fileServer
.
ServeHTTP
(
c
.
Writer
,
c
.
Request
)
c
.
Abort
()
return
...
...
@@ -251,6 +281,21 @@ func ServeEmbeddedFrontend() gin.HandlerFunc {
}
}
// tryServeOverrideFile is a standalone version of tryServeOverride for legacy usage.
func
tryServeOverrideFile
(
c
*
gin
.
Context
,
overrideDir
,
cleanPath
string
)
bool
{
if
overrideDir
==
""
{
return
false
}
filePath
:=
filepath
.
Join
(
overrideDir
,
filepath
.
Clean
(
"/"
+
cleanPath
))
info
,
err
:=
os
.
Stat
(
filePath
)
if
err
!=
nil
||
info
.
IsDir
()
{
return
false
}
c
.
File
(
filePath
)
c
.
Abort
()
return
true
}
func
shouldBypassEmbeddedFrontend
(
path
string
)
bool
{
trimmed
:=
strings
.
TrimSpace
(
path
)
return
strings
.
HasPrefix
(
trimmed
,
"/api/"
)
||
...
...
backend/migration_refuse_oauth/refuse_oauth.sql
0 → 100644
View file @
0b746501
-- 在测试库执行完migration_release之后,执行这个语句,阻止token刷新
-- 切记不可在生产库执行
UPDATE
accounts
SET
schedulable
=
false
,
credentials
=
NUll
WHERE
type
=
'oauth'
;
\ No newline at end of file
Prev
1
…
3
4
5
6
7
8
9
10
11
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment