Unverified Commit d402e722 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #1637 from touwaeriol/feat/websearch-notify-pricing

feat: web search emulation, balance/quota notify, account stats pricing, per-provider refund control, Stripe fix / Web 搜索模拟、余额配额通知、渠道统计计费、按服务商退款控制、Stripe 修复
parents e534e9ba 8548a130
......@@ -617,6 +617,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil,
nil,
nil,
nil,
)
decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)
......
......@@ -64,12 +64,9 @@ func (s *OpsService) getAccountsLoadMapBestEffort(ctx context.Context, accounts
if acc.ID <= 0 {
continue
}
c := acc.Concurrency
if c <= 0 {
c = 1
}
if prev, ok := unique[acc.ID]; !ok || c > prev {
unique[acc.ID] = c
lf := acc.EffectiveLoadFactor()
if prev, ok := unique[acc.ID]; !ok || lf > prev {
unique[acc.ID] = lf
}
}
......
......@@ -391,7 +391,7 @@ func (c *OpsMetricsCollector) collectConcurrencyQueueDepth(parentCtx context.Con
}
batch = append(batch, AccountWithConcurrency{
ID: acc.ID,
MaxConcurrency: acc.Concurrency,
MaxConcurrency: acc.EffectiveLoadFactor(),
})
}
if len(batch) == 0 {
......
......@@ -183,6 +183,15 @@ func TestOpsSystemLogSink_StartStopAndFlushSuccess(t *testing.T) {
if strings.TrimSpace(item.Message) == "" {
t.Fatalf("message should not be empty")
}
// writtenCount is incremented after BatchInsertSystemLogsFn returns,
// so poll briefly to avoid a race between the done signal and the atomic add.
deadline := time.Now().Add(time.Second)
for time.Now().Before(deadline) {
if sink.Health().WrittenCount > 0 {
break
}
time.Sleep(time.Millisecond)
}
health := sink.Health()
if health.WrittenCount == 0 {
t.Fatalf("written_count should be >0")
......
......@@ -3,6 +3,7 @@ package service
import (
"context"
"fmt"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/group"
......@@ -10,6 +11,52 @@ import (
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// validatePlanRequired checks that all required fields for a plan are provided.
func validatePlanRequired(name string, groupID int64, price float64, validityDays int, validityUnit string, originalPrice *float64) error {
if strings.TrimSpace(name) == "" {
return infraerrors.BadRequest("PLAN_NAME_REQUIRED", "plan name is required")
}
if groupID <= 0 {
return infraerrors.BadRequest("PLAN_GROUP_REQUIRED", "group is required")
}
if price <= 0 {
return infraerrors.BadRequest("PLAN_PRICE_INVALID", "price must be > 0")
}
if validityDays <= 0 {
return infraerrors.BadRequest("PLAN_VALIDITY_REQUIRED", "validity days must be > 0")
}
if strings.TrimSpace(validityUnit) == "" {
return infraerrors.BadRequest("PLAN_VALIDITY_UNIT_REQUIRED", "validity unit is required")
}
if originalPrice != nil && *originalPrice < 0 {
return infraerrors.BadRequest("PLAN_ORIGINAL_PRICE_INVALID", "original price must be >= 0")
}
return nil
}
// validatePlanPatch validates only the non-nil fields in a patch update.
func validatePlanPatch(req UpdatePlanRequest) error {
if req.Name != nil && strings.TrimSpace(*req.Name) == "" {
return infraerrors.BadRequest("PLAN_NAME_REQUIRED", "plan name is required")
}
if req.GroupID != nil && *req.GroupID <= 0 {
return infraerrors.BadRequest("PLAN_GROUP_REQUIRED", "group is required")
}
if req.Price != nil && *req.Price <= 0 {
return infraerrors.BadRequest("PLAN_PRICE_INVALID", "price must be > 0")
}
if req.ValidityDays != nil && *req.ValidityDays <= 0 {
return infraerrors.BadRequest("PLAN_VALIDITY_REQUIRED", "validity days must be > 0")
}
if req.ValidityUnit != nil && strings.TrimSpace(*req.ValidityUnit) == "" {
return infraerrors.BadRequest("PLAN_VALIDITY_UNIT_REQUIRED", "validity unit is required")
}
if req.OriginalPrice != nil && *req.OriginalPrice < 0 {
return infraerrors.BadRequest("PLAN_ORIGINAL_PRICE_INVALID", "original price must be >= 0")
}
return nil
}
// --- Plan CRUD ---
// PlanGroupInfo holds the group details needed for subscription plan display.
......@@ -74,6 +121,9 @@ func (s *PaymentConfigService) ListPlansForSale(ctx context.Context) ([]*dbent.S
}
func (s *PaymentConfigService) CreatePlan(ctx context.Context, req CreatePlanRequest) (*dbent.SubscriptionPlan, error) {
if err := validatePlanRequired(req.Name, req.GroupID, req.Price, req.ValidityDays, req.ValidityUnit, req.OriginalPrice); err != nil {
return nil, err
}
b := s.entClient.SubscriptionPlan.Create().
SetGroupID(req.GroupID).SetName(req.Name).SetDescription(req.Description).
SetPrice(req.Price).SetValidityDays(req.ValidityDays).SetValidityUnit(req.ValidityUnit).
......@@ -86,8 +136,12 @@ func (s *PaymentConfigService) CreatePlan(ctx context.Context, req CreatePlanReq
}
// UpdatePlan updates a subscription plan by ID (patch semantics).
// NOTE: This function exceeds 30 lines due to per-field nil-check patch update boilerplate.
// NOTE: This function exceeds 30 lines due to per-field nil-check patch update boilerplate
// plus a validation guard for non-nil fields.
func (s *PaymentConfigService) UpdatePlan(ctx context.Context, id int64, req UpdatePlanRequest) (*dbent.SubscriptionPlan, error) {
if err := validatePlanPatch(req); err != nil {
return nil, err
}
u := s.entClient.SubscriptionPlan.UpdateOneID(id)
if req.GroupID != nil {
u.SetGroupID(*req.GroupID)
......
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestValidatePlanRequired_AllValid(t *testing.T) {
err := validatePlanRequired("Pro", 1, 9.99, 30, "days", nil)
require.NoError(t, err)
}
func TestValidatePlanRequired_EmptyName(t *testing.T) {
err := validatePlanRequired("", 1, 9.99, 30, "days", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "plan name")
}
func TestValidatePlanRequired_WhitespaceName(t *testing.T) {
err := validatePlanRequired(" ", 1, 9.99, 30, "days", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "plan name")
}
func TestValidatePlanRequired_ZeroGroupID(t *testing.T) {
err := validatePlanRequired("Pro", 0, 9.99, 30, "days", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "group")
}
func TestValidatePlanRequired_NegativeGroupID(t *testing.T) {
err := validatePlanRequired("Pro", -1, 9.99, 30, "days", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "group")
}
func TestValidatePlanRequired_ZeroPrice(t *testing.T) {
err := validatePlanRequired("Pro", 1, 0, 30, "days", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "price")
}
func TestValidatePlanRequired_NegativePrice(t *testing.T) {
err := validatePlanRequired("Pro", 1, -5, 30, "days", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "price")
}
func TestValidatePlanRequired_ZeroValidityDays(t *testing.T) {
err := validatePlanRequired("Pro", 1, 9.99, 0, "days", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "validity days")
}
func TestValidatePlanRequired_NegativeValidityDays(t *testing.T) {
err := validatePlanRequired("Pro", 1, 9.99, -7, "days", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "validity days")
}
func TestValidatePlanRequired_EmptyValidityUnit(t *testing.T) {
err := validatePlanRequired("Pro", 1, 9.99, 30, "", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "validity unit")
}
func TestValidatePlanRequired_WhitespaceValidityUnit(t *testing.T) {
err := validatePlanRequired("Pro", 1, 9.99, 30, " ", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "validity unit")
}
func TestValidatePlanRequired_NameValidatedFirst(t *testing.T) {
err := validatePlanRequired("", 0, 0, 0, "", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "plan name")
}
func TestValidatePlanRequired_TrimmedValidName(t *testing.T) {
err := validatePlanRequired(" Pro ", 1, 9.99, 30, "days", nil)
require.NoError(t, err)
}
func TestValidatePlanRequired_NegativeOriginalPrice(t *testing.T) {
neg := -10.0
err := validatePlanRequired("Pro", 1, 9.99, 30, "days", &neg)
require.Error(t, err)
require.Contains(t, err.Error(), "original price")
}
func TestValidatePlanRequired_ZeroOriginalPrice(t *testing.T) {
zero := 0.0
err := validatePlanRequired("Pro", 1, 9.99, 30, "days", &zero)
require.NoError(t, err)
}
func TestValidatePlanRequired_ValidOriginalPrice(t *testing.T) {
op := 19.99
err := validatePlanRequired("Pro", 1, 9.99, 30, "days", &op)
require.NoError(t, err)
}
// --- validatePlanPatch tests ---
func TestValidatePlanPatch_NegativeOriginalPrice(t *testing.T) {
neg := -5.0
err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: &neg})
require.Error(t, err)
require.Contains(t, err.Error(), "original price")
}
func TestValidatePlanPatch_ZeroOriginalPrice(t *testing.T) {
zero := 0.0
err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: &zero})
require.NoError(t, err)
}
func TestValidatePlanPatch_ValidOriginalPrice(t *testing.T) {
op := 29.99
err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: &op})
require.NoError(t, err)
}
func TestValidatePlanPatch_NilOriginalPrice(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: nil})
require.NoError(t, err)
}
// --- validatePlanPatch: other fields ---
func ptrStr(s string) *string { return &s }
func ptrInt(i int) *int { return &i }
func ptrInt64(i int64) *int64 { return &i }
func ptrFloat(f float64) *float64 { return &f }
func TestValidatePlanPatch_EmptyName(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{Name: ptrStr("")})
require.Error(t, err)
require.Contains(t, err.Error(), "plan name")
}
func TestValidatePlanPatch_ValidName(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{Name: ptrStr("Basic")})
require.NoError(t, err)
}
func TestValidatePlanPatch_ZeroGroupID(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{GroupID: ptrInt64(0)})
require.Error(t, err)
require.Contains(t, err.Error(), "group")
}
func TestValidatePlanPatch_NegativePrice(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{Price: ptrFloat(-1)})
require.Error(t, err)
require.Contains(t, err.Error(), "price")
}
func TestValidatePlanPatch_ZeroPrice(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{Price: ptrFloat(0)})
require.Error(t, err)
require.Contains(t, err.Error(), "price")
}
func TestValidatePlanPatch_ValidPrice(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{Price: ptrFloat(9.99)})
require.NoError(t, err)
}
func TestValidatePlanPatch_ZeroValidityDays(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{ValidityDays: ptrInt(0)})
require.Error(t, err)
require.Contains(t, err.Error(), "validity days")
}
func TestValidatePlanPatch_EmptyValidityUnit(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{ValidityUnit: ptrStr("")})
require.Error(t, err)
require.Contains(t, err.Error(), "validity unit")
}
func TestValidatePlanPatch_ValidValidityUnit(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{ValidityUnit: ptrStr("days")})
require.NoError(t, err)
}
func TestValidatePlanPatch_AllNil(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{})
require.NoError(t, err)
}
......@@ -22,16 +22,17 @@ func (s *PaymentConfigService) ListProviderInstances(ctx context.Context) ([]*db
// ProviderInstanceResponse is the API response for a provider instance.
type ProviderInstanceResponse struct {
ID int64 `json:"id"`
ProviderKey string `json:"provider_key"`
Name string `json:"name"`
Config map[string]string `json:"config"`
SupportedTypes []string `json:"supported_types"`
Limits string `json:"limits"`
Enabled bool `json:"enabled"`
RefundEnabled bool `json:"refund_enabled"`
SortOrder int `json:"sort_order"`
PaymentMode string `json:"payment_mode"`
ID int64 `json:"id"`
ProviderKey string `json:"provider_key"`
Name string `json:"name"`
Config map[string]string `json:"config"`
SupportedTypes []string `json:"supported_types"`
Limits string `json:"limits"`
Enabled bool `json:"enabled"`
RefundEnabled bool `json:"refund_enabled"`
AllowUserRefund bool `json:"allow_user_refund"`
SortOrder int `json:"sort_order"`
PaymentMode string `json:"payment_mode"`
}
// ListProviderInstancesWithConfig returns provider instances with decrypted config.
......@@ -46,8 +47,9 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte
resp := ProviderInstanceResponse{
ID: int64(inst.ID), ProviderKey: inst.ProviderKey, Name: inst.Name,
SupportedTypes: splitTypes(inst.SupportedTypes), Limits: inst.Limits,
Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled, SortOrder: inst.SortOrder,
PaymentMode: inst.PaymentMode,
Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled,
AllowUserRefund: inst.AllowUserRefund,
SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode,
}
resp.Config, err = s.decryptAndMaskConfig(inst.Config)
if err != nil {
......@@ -110,10 +112,12 @@ func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req C
if err != nil {
return nil, err
}
allowUserRefund := req.AllowUserRefund && req.RefundEnabled
return s.entClient.PaymentProviderInstance.Create().
SetProviderKey(req.ProviderKey).SetName(req.Name).SetConfig(enc).
SetSupportedTypes(typesStr).SetEnabled(req.Enabled).SetPaymentMode(req.PaymentMode).
SetSortOrder(req.SortOrder).SetLimits(req.Limits).SetRefundEnabled(req.RefundEnabled).
SetAllowUserRefund(allowUserRefund).
Save(ctx)
}
......@@ -221,6 +225,29 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
}
if req.RefundEnabled != nil {
u.SetRefundEnabled(*req.RefundEnabled)
// Cascade: turning off refund_enabled also disables allow_user_refund
if !*req.RefundEnabled {
u.SetAllowUserRefund(false)
}
}
if req.AllowUserRefund != nil {
// Only allow enabling when refund_enabled is (or will be) true
if *req.AllowUserRefund {
refundEnabled := false
if req.RefundEnabled != nil {
refundEnabled = *req.RefundEnabled
} else {
inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
if err == nil {
refundEnabled = inst.RefundEnabled
}
}
if refundEnabled {
u.SetAllowUserRefund(true)
}
} else {
u.SetAllowUserRefund(false)
}
}
if req.PaymentMode != nil {
u.SetPaymentMode(*req.PaymentMode)
......@@ -228,6 +255,23 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
return u.Save(ctx)
}
// GetUserRefundEligibleInstanceIDs returns provider instance IDs that allow user refund.
func (s *PaymentConfigService) GetUserRefundEligibleInstanceIDs(ctx context.Context) ([]string, error) {
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(
paymentproviderinstance.RefundEnabledEQ(true),
paymentproviderinstance.AllowUserRefundEQ(true),
).Select(paymentproviderinstance.FieldID).All(ctx)
if err != nil {
return nil, err
}
ids := make([]string, 0, len(instances))
for _, inst := range instances {
ids = append(ids, strconv.FormatInt(int64(inst.ID), 10))
}
return ids, nil
}
func (s *PaymentConfigService) mergeConfig(ctx context.Context, id int64, newConfig map[string]string) (map[string]string, error) {
inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
if err != nil {
......
......@@ -101,7 +101,7 @@ func TestIsSensitiveConfigField(t *testing.T) {
t.Parallel()
tests := []struct {
field string
field string
wantSen bool
}{
// Sensitive fields (contain key/secret/private/password/pkey patterns)
......
......@@ -105,26 +105,28 @@ type MethodLimitsResponse struct {
}
type CreateProviderInstanceRequest struct {
ProviderKey string `json:"provider_key"`
Name string `json:"name"`
Config map[string]string `json:"config"`
SupportedTypes []string `json:"supported_types"`
Enabled bool `json:"enabled"`
PaymentMode string `json:"payment_mode"`
SortOrder int `json:"sort_order"`
Limits string `json:"limits"`
RefundEnabled bool `json:"refund_enabled"`
ProviderKey string `json:"provider_key"`
Name string `json:"name"`
Config map[string]string `json:"config"`
SupportedTypes []string `json:"supported_types"`
Enabled bool `json:"enabled"`
PaymentMode string `json:"payment_mode"`
SortOrder int `json:"sort_order"`
Limits string `json:"limits"`
RefundEnabled bool `json:"refund_enabled"`
AllowUserRefund bool `json:"allow_user_refund"`
}
type UpdateProviderInstanceRequest struct {
Name *string `json:"name"`
Config map[string]string `json:"config"`
SupportedTypes []string `json:"supported_types"`
Enabled *bool `json:"enabled"`
PaymentMode *string `json:"payment_mode"`
SortOrder *int `json:"sort_order"`
Limits *string `json:"limits"`
RefundEnabled *bool `json:"refund_enabled"`
Name *string `json:"name"`
Config map[string]string `json:"config"`
SupportedTypes []string `json:"supported_types"`
Enabled *bool `json:"enabled"`
PaymentMode *string `json:"payment_mode"`
SortOrder *int `json:"sort_order"`
Limits *string `json:"limits"`
RefundEnabled *bool `json:"refund_enabled"`
AllowUserRefund *bool `json:"allow_user_refund"`
}
type CreatePlanRequest struct {
GroupID int64 `json:"group_id"`
......
......@@ -2,6 +2,7 @@ package service
import (
"context"
"errors"
"fmt"
"log/slog"
"math"
......@@ -17,6 +18,19 @@ import (
// --- Refund Flow ---
// getOrderProviderInstance looks up the provider instance that processed this order.
// Returns nil, nil for legacy orders without provider_instance_id.
func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) {
if o.ProviderInstanceID == nil || *o.ProviderInstanceID == "" {
return nil, nil
}
instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64)
if err != nil {
return nil, nil
}
return s.entClient.PaymentProviderInstance.Get(ctx, instID)
}
func (s *PaymentService) RequestRefund(ctx context.Context, oid, uid int64, reason string) error {
o, err := s.validateRefundRequest(ctx, oid, uid)
if err != nil {
......@@ -57,6 +71,14 @@ func (s *PaymentService) validateRefundRequest(ctx context.Context, oid, uid int
if o.Status != OrderStatusCompleted {
return nil, infraerrors.BadRequest("INVALID_STATUS", "only completed orders can request refund")
}
// Check provider instance allows user refund
inst, err := s.getOrderProviderInstance(ctx, o)
if err != nil || inst == nil {
return nil, infraerrors.Forbidden("USER_REFUND_DISABLED", "refund is not available for this order")
}
if !inst.AllowUserRefund {
return nil, infraerrors.Forbidden("USER_REFUND_DISABLED", "user refund is not enabled for this provider")
}
return o, nil
}
......@@ -69,6 +91,19 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float
if !psSliceContains(ok, o.Status) {
return nil, nil, infraerrors.BadRequest("INVALID_STATUS", "order status does not allow refund")
}
// Check provider instance allows admin refund
inst, instErr := s.getOrderProviderInstance(ctx, o)
if instErr != nil {
slog.Warn("refund: provider instance lookup failed", "orderID", oid, "error", instErr)
return nil, nil, infraerrors.InternalServer("PROVIDER_LOOKUP_FAILED", "failed to look up payment provider for this order")
}
if inst == nil {
// Legacy order without provider_instance_id — block refund
return nil, nil, infraerrors.Forbidden("REFUND_DISABLED", "refund is not available for this order")
}
if !inst.RefundEnabled {
return nil, nil, infraerrors.Forbidden("REFUND_DISABLED", "refund is not enabled for this provider")
}
if math.IsNaN(amt) || math.IsInf(amt, 0) {
return nil, nil, infraerrors.BadRequest("INVALID_AMOUNT", "invalid refund amount")
}
......@@ -150,11 +185,17 @@ func (s *PaymentService) ExecuteRefund(ctx context.Context, p *RefundPlan) (*Ref
if !s.hasAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED") {
_, err := s.subscriptionSvc.ExtendSubscription(ctx, p.SubscriptionID, -p.SubDaysToDeduct)
if err != nil {
// If deducting would expire the subscription, revoke it entirely
slog.Info("subscription deduction would expire, revoking", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct)
if revokeErr := s.subscriptionSvc.RevokeSubscription(ctx, p.SubscriptionID); revokeErr != nil {
if errors.Is(err, ErrAdjustWouldExpire) {
// Deduction would expire the subscription — revoke it entirely
slog.Info("subscription deduction would expire, revoking", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct)
if revokeErr := s.subscriptionSvc.RevokeSubscription(ctx, p.SubscriptionID); revokeErr != nil {
s.restoreStatus(ctx, p)
return nil, fmt.Errorf("revoke subscription: %w", revokeErr)
}
} else {
// Other errors (DB failure, not found) — abort refund
s.restoreStatus(ctx, p)
return nil, fmt.Errorf("revoke subscription: %w", revokeErr)
return nil, fmt.Errorf("deduct subscription days: %w", err)
}
}
} else {
......
......@@ -102,6 +102,8 @@ func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *t
})
}
// TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError
// OpenAI OAuth 401 缓存失效出错时仍走 temp_unschedulable
func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{err: errors.New("boom")}
......@@ -109,7 +111,7 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 101,
Platform: PlatformGemini,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
}
......
......@@ -99,13 +99,19 @@ type DefaultSubscriptionGroupReader interface {
GetByID(ctx context.Context, id int64) (*Group, error)
}
// WebSearchManagerBuilder creates a websearch.Manager from config (injected by infra layer).
// proxyURLs maps proxy ID to resolved URL for provider-level proxy support.
type WebSearchManagerBuilder func(cfg *WebSearchEmulationConfig, proxyURLs map[int64]string)
// SettingService 系统设置服务
type SettingService struct {
settingRepo SettingRepository
defaultSubGroupReader DefaultSubscriptionGroupReader
cfg *config.Config
onUpdate func() // Callback when settings are updated (for cache invalidation)
version string // Application version
settingRepo SettingRepository
defaultSubGroupReader DefaultSubscriptionGroupReader
proxyRepo ProxyRepository // for resolving websearch provider proxy URLs
cfg *config.Config
onUpdate func() // Callback when settings are updated (for cache invalidation)
version string // Application version
webSearchManagerBuilder WebSearchManagerBuilder
}
// NewSettingService 创建系统设置服务实例
......@@ -121,6 +127,11 @@ func (s *SettingService) SetDefaultSubscriptionGroupReader(reader DefaultSubscri
s.defaultSubGroupReader = reader
}
// SetProxyRepository injects a proxy repo for resolving websearch provider proxy URLs.
func (s *SettingService) SetProxyRepository(repo ProxyRepository) {
s.proxyRepo = repo
}
// GetAllSettings 获取所有系统设置
func (s *SettingService) GetAllSettings(ctx context.Context) (*SystemSettings, error) {
settings, err := s.settingRepo.GetAll(ctx)
......@@ -168,9 +179,13 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyCustomEndpoints,
SettingKeyLinuxDoConnectEnabled,
SettingKeyBackendModeEnabled,
SettingPaymentEnabled,
SettingKeyOIDCConnectEnabled,
SettingKeyOIDCConnectProviderName,
SettingPaymentEnabled,
SettingKeyBalanceLowNotifyEnabled,
SettingKeyBalanceLowNotifyThreshold,
SettingKeyBalanceLowNotifyRechargeURL,
SettingKeyAccountQuotaNotifyEnabled,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
......@@ -209,6 +224,11 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
settings[SettingKeyTablePageSizeOptions],
)
var balanceLowNotifyThreshold float64
if v, err := strconv.ParseFloat(settings[SettingKeyBalanceLowNotifyThreshold], 64); err == nil && v >= 0 {
balanceLowNotifyThreshold = v
}
return &PublicSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: emailVerifyEnabled,
......@@ -235,9 +255,13 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
CustomEndpoints: settings[SettingKeyCustomEndpoints],
LinuxDoOAuthEnabled: linuxDoEnabled,
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
PaymentEnabled: settings[SettingPaymentEnabled] == "true",
OIDCOAuthEnabled: oidcEnabled,
OIDCOAuthProviderName: oidcProviderName,
PaymentEnabled: settings[SettingPaymentEnabled] == "true",
BalanceLowNotifyEnabled: settings[SettingKeyBalanceLowNotifyEnabled] == "true",
AccountQuotaNotifyEnabled: settings[SettingKeyAccountQuotaNotifyEnabled] == "true",
BalanceLowNotifyThreshold: balanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: settings[SettingKeyBalanceLowNotifyRechargeURL],
}, nil
}
......@@ -287,10 +311,14 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"`
PaymentEnabled bool `json:"payment_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
PaymentEnabled bool `json:"payment_enabled"`
Version string `json:"version,omitempty"`
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
}{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
......@@ -317,10 +345,14 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
BackendModeEnabled: settings.BackendModeEnabled,
PaymentEnabled: settings.PaymentEnabled,
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
PaymentEnabled: settings.PaymentEnabled,
Version: s.version,
BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
}, nil
}
......@@ -595,6 +627,13 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyEnableMetadataPassthrough] = strconv.FormatBool(settings.EnableMetadataPassthrough)
updates[SettingKeyEnableCCHSigning] = strconv.FormatBool(settings.EnableCCHSigning)
// Balance low notification
updates[SettingKeyBalanceLowNotifyEnabled] = strconv.FormatBool(settings.BalanceLowNotifyEnabled)
updates[SettingKeyBalanceLowNotifyThreshold] = strconv.FormatFloat(settings.BalanceLowNotifyThreshold, 'f', 8, 64)
updates[SettingKeyBalanceLowNotifyRechargeURL] = settings.BalanceLowNotifyRechargeURL
updates[SettingKeyAccountQuotaNotifyEnabled] = strconv.FormatBool(settings.AccountQuotaNotifyEnabled)
updates[SettingKeyAccountQuotaNotifyEmails] = MarshalNotifyEmails(settings.AccountQuotaNotifyEmails)
err = s.settingRepo.SetMultiple(ctx, updates)
if err == nil {
// 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口
......@@ -1217,6 +1256,30 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.EnableMetadataPassthrough = settings[SettingKeyEnableMetadataPassthrough] == "true"
result.EnableCCHSigning = settings[SettingKeyEnableCCHSigning] == "true"
// Web search emulation: quick enabled check from the JSON config
if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" {
var wsCfg WebSearchEmulationConfig
if err := json.Unmarshal([]byte(raw), &wsCfg); err == nil {
result.WebSearchEmulationEnabled = wsCfg.Enabled && len(wsCfg.Providers) > 0
}
}
// Balance low notification
result.BalanceLowNotifyEnabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true"
if v, err := strconv.ParseFloat(settings[SettingKeyBalanceLowNotifyThreshold], 64); err == nil && v >= 0 {
result.BalanceLowNotifyThreshold = v
}
result.BalanceLowNotifyRechargeURL = settings[SettingKeyBalanceLowNotifyRechargeURL]
// Account quota notification
result.AccountQuotaNotifyEnabled = settings[SettingKeyAccountQuotaNotifyEnabled] == "true"
if raw := strings.TrimSpace(settings[SettingKeyAccountQuotaNotifyEmails]); raw != "" {
result.AccountQuotaNotifyEmails = ParseNotifyEmails(raw)
}
if result.AccountQuotaNotifyEmails == nil {
result.AccountQuotaNotifyEmails = []NotifyEmailEntry{}
}
return result
}
......
......@@ -66,7 +66,7 @@ func TestSettingService_GetPublicSettings_ExposesRegistrationEmailSuffixWhitelis
func TestSettingService_GetPublicSettings_ExposesTablePreferences(t *testing.T) {
repo := &settingPublicRepoStub{
values: map[string]string{
SettingKeyTableDefaultPageSize: "50",
SettingKeyTableDefaultPageSize: "50",
SettingKeyTablePageSizeOptions: "[20,50,100]",
},
}
......
......@@ -208,7 +208,7 @@ func TestSettingService_UpdateSettings_TablePreferences(t *testing.T) {
svc := NewSettingService(repo, &config.Config{})
err := svc.UpdateSettings(context.Background(), &SystemSettings{
TableDefaultPageSize: 50,
TableDefaultPageSize: 50,
TablePageSizeOptions: []int{20, 50, 100},
})
require.NoError(t, err)
......@@ -216,7 +216,7 @@ func TestSettingService_UpdateSettings_TablePreferences(t *testing.T) {
require.Equal(t, "[20,50,100]", repo.updates[SettingKeyTablePageSizeOptions])
err = svc.UpdateSettings(context.Background(), &SystemSettings{
TableDefaultPageSize: 1000,
TableDefaultPageSize: 1000,
TablePageSizeOptions: []int{20, 100},
})
require.NoError(t, err)
......
......@@ -106,6 +106,18 @@ type SystemSettings struct {
EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true)
EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false)
EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false)
// Web Search Emulation
WebSearchEmulationEnabled bool // 是否启用 web search 模拟
// Balance low notification
BalanceLowNotifyEnabled bool
BalanceLowNotifyThreshold float64
BalanceLowNotifyRechargeURL string
// Account quota notification
AccountQuotaNotifyEnabled bool
AccountQuotaNotifyEmails []NotifyEmailEntry
}
type DefaultSubscriptionSetting struct {
......@@ -141,10 +153,15 @@ type PublicSettings struct {
LinuxDoOAuthEnabled bool
BackendModeEnabled bool
PaymentEnabled bool
OIDCOAuthEnabled bool
OIDCOAuthProviderName string
PaymentEnabled bool
Version string
BalanceLowNotifyEnabled bool
AccountQuotaNotifyEnabled bool
BalanceLowNotifyThreshold float64
BalanceLowNotifyRechargeURL string
}
// StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制)
......
......@@ -100,9 +100,22 @@ func valueOrZero(v *int64) int64 {
return *v
}
// AccountQuotaState holds the post-increment quota state returned by the DB transaction.
// All values are post-update (i.e., already include the increment).
type AccountQuotaState struct {
TotalUsed float64
TotalLimit float64
DailyUsed float64
DailyLimit float64
WeeklyUsed float64
WeeklyLimit float64
}
type UsageBillingApplyResult struct {
Applied bool
APIKeyQuotaExhausted bool
NewBalance *float64 // post-deduction balance (nil = no balance deduction)
QuotaState *AccountQuotaState // post-increment quota state (nil = no quota increment)
}
type UsageBillingRepository interface {
......
......@@ -146,6 +146,8 @@ type UsageLog struct {
RateMultiplier float64
// AccountRateMultiplier 账号计费倍率快照(nil 表示历史数据,按 1.0 处理)
AccountRateMultiplier *float64
// AccountStatsCost 账号统计定价预计算费用(nil = 使用默认公式 total_cost × account_rate_multiplier)
AccountStatsCost *float64
BillingType int8
RequestType RequestType
......
......@@ -30,6 +30,13 @@ type User struct {
TotpEnabled bool // 是否启用 TOTP
TotpEnabledAt *time.Time // TOTP 启用时间
// 余额不足通知
BalanceNotifyEnabled bool
BalanceNotifyThresholdType string // "fixed" (default) | "percentage"
BalanceNotifyThreshold *float64
BalanceNotifyExtraEmails []NotifyEmailEntry
TotalRecharged float64
APIKeys []APIKey
Subscriptions []UserSubscription
}
......
......@@ -2,8 +2,10 @@ package service
import (
"context"
"crypto/subtle"
"fmt"
"log"
"log/slog"
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
......@@ -11,9 +13,18 @@ import (
)
var (
ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found")
ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect")
ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found")
ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect")
ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
ErrNotifyCodeUserRateLimit = infraerrors.TooManyRequests("NOTIFY_CODE_USER_RATE_LIMIT", "too many verification codes requested, please try again later")
)
const (
maxNotifyEmails = 3 // Maximum number of notification emails per user
// User-level rate limiting for notify email verification codes
notifyCodeUserRateLimit = 5
notifyCodeUserRateWindow = 10 * time.Minute
)
// UserListFilters contains all filter options for listing users
......@@ -58,9 +69,11 @@ type UserRepository interface {
// UpdateProfileRequest 更新用户资料请求
type UpdateProfileRequest struct {
Email *string `json:"email"`
Username *string `json:"username"`
Concurrency *int `json:"concurrency"`
Email *string `json:"email"`
Username *string `json:"username"`
Concurrency *int `json:"concurrency"`
BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
}
// ChangePasswordRequest 修改密码请求
......@@ -72,14 +85,16 @@ type ChangePasswordRequest struct {
// UserService 用户服务
type UserService struct {
userRepo UserRepository
settingRepo SettingRepository
authCacheInvalidator APIKeyAuthCacheInvalidator
billingCache BillingCache
}
// NewUserService 创建用户服务实例
func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCache BillingCache) *UserService {
func NewUserService(userRepo UserRepository, settingRepo SettingRepository, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCache BillingCache) *UserService {
return &UserService{
userRepo: userRepo,
settingRepo: settingRepo,
authCacheInvalidator: authCacheInvalidator,
billingCache: billingCache,
}
......@@ -132,6 +147,17 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
user.Concurrency = *req.Concurrency
}
if req.BalanceNotifyEnabled != nil {
user.BalanceNotifyEnabled = *req.BalanceNotifyEnabled
}
if req.BalanceNotifyThreshold != nil {
if *req.BalanceNotifyThreshold <= 0 {
user.BalanceNotifyThreshold = nil // clear to system default
} else {
user.BalanceNotifyThreshold = req.BalanceNotifyThreshold
}
}
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, fmt.Errorf("update user: %w", err)
}
......@@ -198,10 +224,15 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl
}
if s.billingCache != nil {
go func() {
defer func() {
if r := recover(); r != nil {
slog.Error("panic in balance cache invalidation", "user_id", userID, "recover", r)
}
}()
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.billingCache.InvalidateUserBalance(cacheCtx, userID); err != nil {
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err)
slog.Error("invalidate user balance cache failed", "user_id", userID, "error", err)
}
}()
}
......@@ -248,3 +279,229 @@ func (s *UserService) Delete(ctx context.Context, userID int64) error {
}
return nil
}
// SendNotifyEmailCode sends a verification code to the extra notification email.
func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, email string, emailService *EmailService, cache EmailCache) error {
if err := checkNotifyCodeRateLimit(ctx, cache, userID, email); err != nil {
return err
}
code, err := emailService.GenerateVerifyCode()
if err != nil {
return fmt.Errorf("generate code: %w", err)
}
// Send email first — if SMTP fails, don't write cache or increment counters,
// so the user is not locked out by cooldown/rate-limit for a code they never received.
if err := s.sendNotifyVerifyEmail(ctx, emailService, email, code); err != nil {
return err
}
if err := saveNotifyVerifyCode(ctx, cache, email, code); err != nil {
return err
}
// Increment user-level counter after successful save
if _, err := cache.IncrNotifyCodeUserRate(ctx, userID, notifyCodeUserRateWindow); err != nil {
slog.Error("failed to increment notify code user rate", "user_id", userID, "error", err)
}
return nil
}
// checkNotifyCodeRateLimit checks both email cooldown and user-level rate limit.
func checkNotifyCodeRateLimit(ctx context.Context, cache EmailCache, userID int64, email string) error {
existing, err := cache.GetNotifyVerifyCode(ctx, email)
if err == nil && existing != nil {
if time.Since(existing.CreatedAt) < verifyCodeCooldown {
return ErrVerifyCodeTooFrequent
}
}
count, err := cache.GetNotifyCodeUserRate(ctx, userID)
if err == nil && count >= notifyCodeUserRateLimit {
return ErrNotifyCodeUserRateLimit
}
return nil
}
// saveNotifyVerifyCode saves the verification code to cache.
func saveNotifyVerifyCode(ctx context.Context, cache EmailCache, email, code string) error {
data := &VerificationCodeData{
Code: code,
Attempts: 0,
CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(verifyCodeTTL),
}
if err := cache.SetNotifyVerifyCode(ctx, email, data, verifyCodeTTL); err != nil {
return fmt.Errorf("save verify code: %w", err)
}
return nil
}
// sendNotifyVerifyEmail builds and sends the verification email.
func (s *UserService) sendNotifyVerifyEmail(ctx context.Context, emailService *EmailService, email, code string) error {
siteName := "Sub2API"
if s.settingRepo != nil {
if name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName); err == nil && name != "" {
siteName = name
}
}
subject := fmt.Sprintf("[%s] 通知邮箱验证码 / Notification Email Verification", siteName)
body := buildNotifyVerifyEmailBody(code, siteName)
return emailService.SendEmail(ctx, email, subject, body)
}
// VerifyAndAddNotifyEmail verifies the code and adds the email to user's extra emails.
func (s *UserService) VerifyAndAddNotifyEmail(ctx context.Context, userID int64, email, code string, cache EmailCache) error {
if err := verifyNotifyCode(ctx, cache, email, code); err != nil {
return err
}
_ = cache.DeleteNotifyVerifyCode(ctx, email)
return s.addOrVerifyNotifyEmail(ctx, userID, email)
}
// verifyNotifyCode validates the verification code against the cached data.
func verifyNotifyCode(ctx context.Context, cache EmailCache, email, code string) error {
data, err := cache.GetNotifyVerifyCode(ctx, email)
if err != nil || data == nil {
return ErrInvalidVerifyCode
}
if data.Attempts >= maxVerifyCodeAttempts {
return ErrVerifyCodeMaxAttempts
}
if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 {
data.Attempts++
remaining := time.Until(data.ExpiresAt)
if remaining <= 0 {
return ErrInvalidVerifyCode
}
if err := cache.SetNotifyVerifyCode(ctx, email, data, remaining); err != nil {
slog.Error("failed to update notify verify code attempts", "email", email, "error", err)
}
if data.Attempts >= maxVerifyCodeAttempts {
return ErrVerifyCodeMaxAttempts
}
return ErrInvalidVerifyCode
}
return nil
}
// addOrVerifyNotifyEmail adds the email to user's extra notification emails or marks it as verified.
// Note: concurrent calls for the same user could race on the read-modify-write of
// BalanceNotifyExtraEmails. The window is small (requires two verify flows completing
// simultaneously), and the worst case is a duplicate entry which is harmless.
func (s *UserService) addOrVerifyNotifyEmail(ctx context.Context, userID int64, email string) error {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return err
}
for i, e := range user.BalanceNotifyExtraEmails {
if strings.EqualFold(e.Email, email) {
if !e.Verified {
user.BalanceNotifyExtraEmails[i].Verified = true
return s.userRepo.Update(ctx, user)
}
return nil // Already verified
}
}
if len(user.BalanceNotifyExtraEmails) >= maxNotifyEmails {
return infraerrors.BadRequest("TOO_MANY_NOTIFY_EMAILS", fmt.Sprintf("maximum %d notification emails allowed", maxNotifyEmails))
}
user.BalanceNotifyExtraEmails = append(user.BalanceNotifyExtraEmails, NotifyEmailEntry{
Email: email,
Disabled: false,
Verified: true,
})
return s.userRepo.Update(ctx, user)
}
// RemoveNotifyEmail removes an email from user's extra notification emails.
func (s *UserService) RemoveNotifyEmail(ctx context.Context, userID int64, email string) error {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return err
}
filtered := make([]NotifyEmailEntry, 0, len(user.BalanceNotifyExtraEmails))
found := false
for _, e := range user.BalanceNotifyExtraEmails {
if strings.EqualFold(e.Email, email) {
found = true
} else {
filtered = append(filtered, e)
}
}
if !found {
return infraerrors.BadRequest("EMAIL_NOT_FOUND", "notification email not found")
}
user.BalanceNotifyExtraEmails = filtered
return s.userRepo.Update(ctx, user)
}
// ToggleNotifyEmail toggles the disabled state of a notification email entry.
func (s *UserService) ToggleNotifyEmail(ctx context.Context, userID int64, email string, disabled bool) error {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return err
}
found := false
for i, e := range user.BalanceNotifyExtraEmails {
if strings.EqualFold(e.Email, email) {
user.BalanceNotifyExtraEmails[i].Disabled = disabled
found = true
break
}
}
if !found {
return infraerrors.BadRequest("EMAIL_NOT_FOUND", "notification email not found")
}
return s.userRepo.Update(ctx, user)
}
// notifyVerifyEmailTemplate is the HTML template for notify email verification.
// Format args: siteName, code.
const notifyVerifyEmailTemplate = `<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 600px; margin: 0 auto; background-color: #ffffff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
.header { background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 30px; text-align: center; }
.header h1 { margin: 0; font-size: 24px; }
.content { padding: 40px 30px; text-align: center; }
.code { font-size: 36px; font-weight: bold; letter-spacing: 8px; color: #333; background-color: #f8f9fa; padding: 20px 30px; border-radius: 8px; display: inline-block; margin: 20px 0; font-family: monospace; }
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; }
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>%s</h1>
</div>
<div class="content">
<p style="font-size: 18px; color: #333;">通知邮箱验证码 / Notification Email Verification</p>
<div class="code">%s</div>
<div class="info">
<p>您正在添加额外的通知邮箱,请输入此验证码完成验证。</p>
<p>You are adding an extra notification email. Please enter this code to verify.</p>
<p>此验证码将在 <strong>15 分钟</strong>后失效。</p>
<p>This code will expire in <strong>15 minutes</strong>.</p>
<p>如果您没有请求此验证码,请忽略此邮件。</p>
<p>If you did not request this code, please ignore this email.</p>
</div>
</div>
<div class="footer">
<p>此邮件由系统自动发送,请勿回复。/ This is an automated message, please do not reply.</p>
</div>
</div>
</body>
</html>`
// buildNotifyVerifyEmailBody builds the HTML email body for notify email verification.
func buildNotifyVerifyEmailBody(code, siteName string) string {
return fmt.Sprintf(notifyVerifyEmailTemplate, siteName, code)
}
......@@ -46,12 +46,12 @@ func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int
return 0, nil
}
func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil }
func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil }
func (m *mockUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
return nil
}
func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil }
func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil }
// --- mock: APIKeyAuthCacheInvalidator ---
......@@ -117,7 +117,7 @@ func (m *mockBillingCache) InvalidateAPIKeyRateLimit(context.Context, int64) err
func TestUpdateBalance_Success(t *testing.T) {
repo := &mockUserRepo{}
cache := &mockBillingCache{}
svc := NewUserService(repo, nil, cache)
svc := NewUserService(repo, nil, nil, cache)
err := svc.UpdateBalance(context.Background(), 42, 100.0)
require.NoError(t, err)
......@@ -134,7 +134,7 @@ func TestUpdateBalance_Success(t *testing.T) {
func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) {
repo := &mockUserRepo{}
svc := NewUserService(repo, nil, nil) // billingCache = nil
svc := NewUserService(repo, nil, nil, nil) // billingCache = nil
err := svc.UpdateBalance(context.Background(), 1, 50.0)
require.NoError(t, err, "billingCache 为 nil 时不应 panic")
......@@ -143,7 +143,7 @@ func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) {
func TestUpdateBalance_CacheFailure_DoesNotAffectReturn(t *testing.T) {
repo := &mockUserRepo{}
cache := &mockBillingCache{invalidateErr: errors.New("redis connection refused")}
svc := NewUserService(repo, nil, cache)
svc := NewUserService(repo, nil, nil, cache)
err := svc.UpdateBalance(context.Background(), 99, 200.0)
require.NoError(t, err, "缓存失效失败不应影响主流程返回值")
......@@ -157,7 +157,7 @@ func TestUpdateBalance_CacheFailure_DoesNotAffectReturn(t *testing.T) {
func TestUpdateBalance_RepoError_ReturnsError(t *testing.T) {
repo := &mockUserRepo{updateBalanceErr: errors.New("database error")}
cache := &mockBillingCache{}
svc := NewUserService(repo, nil, cache)
svc := NewUserService(repo, nil, nil, cache)
err := svc.UpdateBalance(context.Background(), 1, 100.0)
require.Error(t, err, "repo 失败时应返回错误")
......@@ -173,7 +173,7 @@ func TestUpdateBalance_WithAuthCacheInvalidator(t *testing.T) {
repo := &mockUserRepo{}
auth := &mockAuthCacheInvalidator{}
cache := &mockBillingCache{}
svc := NewUserService(repo, auth, cache)
svc := NewUserService(repo, nil, auth, cache)
err := svc.UpdateBalance(context.Background(), 77, 300.0)
require.NoError(t, err)
......@@ -194,7 +194,7 @@ func TestNewUserService_FieldsAssignment(t *testing.T) {
auth := &mockAuthCacheInvalidator{}
cache := &mockBillingCache{}
svc := NewUserService(repo, auth, cache)
svc := NewUserService(repo, nil, auth, cache)
require.NotNil(t, svc)
require.Equal(t, repo, svc.userRepo)
require.Equal(t, auth, svc.authCacheInvalidator)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment