Unverified Commit d402e722 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #1637 from touwaeriol/feat/websearch-notify-pricing

feat: web search emulation, balance/quota notify, account stats pricing, per-provider refund control, Stripe fix / Web 搜索模拟、余额配额通知、渠道统计计费、按服务商退款控制、Stripe 修复
parents e534e9ba 8548a130
...@@ -617,6 +617,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) { ...@@ -617,6 +617,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) {
nil, nil,
nil, nil,
nil, nil,
nil,
) )
decision := svc.getOpenAIWSProtocolResolver().Resolve(nil) decision := svc.getOpenAIWSProtocolResolver().Resolve(nil)
......
...@@ -64,12 +64,9 @@ func (s *OpsService) getAccountsLoadMapBestEffort(ctx context.Context, accounts ...@@ -64,12 +64,9 @@ func (s *OpsService) getAccountsLoadMapBestEffort(ctx context.Context, accounts
if acc.ID <= 0 { if acc.ID <= 0 {
continue continue
} }
c := acc.Concurrency lf := acc.EffectiveLoadFactor()
if c <= 0 { if prev, ok := unique[acc.ID]; !ok || lf > prev {
c = 1 unique[acc.ID] = lf
}
if prev, ok := unique[acc.ID]; !ok || c > prev {
unique[acc.ID] = c
} }
} }
......
...@@ -391,7 +391,7 @@ func (c *OpsMetricsCollector) collectConcurrencyQueueDepth(parentCtx context.Con ...@@ -391,7 +391,7 @@ func (c *OpsMetricsCollector) collectConcurrencyQueueDepth(parentCtx context.Con
} }
batch = append(batch, AccountWithConcurrency{ batch = append(batch, AccountWithConcurrency{
ID: acc.ID, ID: acc.ID,
MaxConcurrency: acc.Concurrency, MaxConcurrency: acc.EffectiveLoadFactor(),
}) })
} }
if len(batch) == 0 { if len(batch) == 0 {
......
...@@ -183,6 +183,15 @@ func TestOpsSystemLogSink_StartStopAndFlushSuccess(t *testing.T) { ...@@ -183,6 +183,15 @@ func TestOpsSystemLogSink_StartStopAndFlushSuccess(t *testing.T) {
if strings.TrimSpace(item.Message) == "" { if strings.TrimSpace(item.Message) == "" {
t.Fatalf("message should not be empty") t.Fatalf("message should not be empty")
} }
// writtenCount is incremented after BatchInsertSystemLogsFn returns,
// so poll briefly to avoid a race between the done signal and the atomic add.
deadline := time.Now().Add(time.Second)
for time.Now().Before(deadline) {
if sink.Health().WrittenCount > 0 {
break
}
time.Sleep(time.Millisecond)
}
health := sink.Health() health := sink.Health()
if health.WrittenCount == 0 { if health.WrittenCount == 0 {
t.Fatalf("written_count should be >0") t.Fatalf("written_count should be >0")
......
...@@ -3,6 +3,7 @@ package service ...@@ -3,6 +3,7 @@ package service
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/group"
...@@ -10,6 +11,52 @@ import ( ...@@ -10,6 +11,52 @@ import (
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
) )
// validatePlanRequired checks that all required fields for a plan are provided.
func validatePlanRequired(name string, groupID int64, price float64, validityDays int, validityUnit string, originalPrice *float64) error {
if strings.TrimSpace(name) == "" {
return infraerrors.BadRequest("PLAN_NAME_REQUIRED", "plan name is required")
}
if groupID <= 0 {
return infraerrors.BadRequest("PLAN_GROUP_REQUIRED", "group is required")
}
if price <= 0 {
return infraerrors.BadRequest("PLAN_PRICE_INVALID", "price must be > 0")
}
if validityDays <= 0 {
return infraerrors.BadRequest("PLAN_VALIDITY_REQUIRED", "validity days must be > 0")
}
if strings.TrimSpace(validityUnit) == "" {
return infraerrors.BadRequest("PLAN_VALIDITY_UNIT_REQUIRED", "validity unit is required")
}
if originalPrice != nil && *originalPrice < 0 {
return infraerrors.BadRequest("PLAN_ORIGINAL_PRICE_INVALID", "original price must be >= 0")
}
return nil
}
// validatePlanPatch validates only the non-nil fields in a patch update.
func validatePlanPatch(req UpdatePlanRequest) error {
if req.Name != nil && strings.TrimSpace(*req.Name) == "" {
return infraerrors.BadRequest("PLAN_NAME_REQUIRED", "plan name is required")
}
if req.GroupID != nil && *req.GroupID <= 0 {
return infraerrors.BadRequest("PLAN_GROUP_REQUIRED", "group is required")
}
if req.Price != nil && *req.Price <= 0 {
return infraerrors.BadRequest("PLAN_PRICE_INVALID", "price must be > 0")
}
if req.ValidityDays != nil && *req.ValidityDays <= 0 {
return infraerrors.BadRequest("PLAN_VALIDITY_REQUIRED", "validity days must be > 0")
}
if req.ValidityUnit != nil && strings.TrimSpace(*req.ValidityUnit) == "" {
return infraerrors.BadRequest("PLAN_VALIDITY_UNIT_REQUIRED", "validity unit is required")
}
if req.OriginalPrice != nil && *req.OriginalPrice < 0 {
return infraerrors.BadRequest("PLAN_ORIGINAL_PRICE_INVALID", "original price must be >= 0")
}
return nil
}
// --- Plan CRUD --- // --- Plan CRUD ---
// PlanGroupInfo holds the group details needed for subscription plan display. // PlanGroupInfo holds the group details needed for subscription plan display.
...@@ -74,6 +121,9 @@ func (s *PaymentConfigService) ListPlansForSale(ctx context.Context) ([]*dbent.S ...@@ -74,6 +121,9 @@ func (s *PaymentConfigService) ListPlansForSale(ctx context.Context) ([]*dbent.S
} }
func (s *PaymentConfigService) CreatePlan(ctx context.Context, req CreatePlanRequest) (*dbent.SubscriptionPlan, error) { func (s *PaymentConfigService) CreatePlan(ctx context.Context, req CreatePlanRequest) (*dbent.SubscriptionPlan, error) {
if err := validatePlanRequired(req.Name, req.GroupID, req.Price, req.ValidityDays, req.ValidityUnit, req.OriginalPrice); err != nil {
return nil, err
}
b := s.entClient.SubscriptionPlan.Create(). b := s.entClient.SubscriptionPlan.Create().
SetGroupID(req.GroupID).SetName(req.Name).SetDescription(req.Description). SetGroupID(req.GroupID).SetName(req.Name).SetDescription(req.Description).
SetPrice(req.Price).SetValidityDays(req.ValidityDays).SetValidityUnit(req.ValidityUnit). SetPrice(req.Price).SetValidityDays(req.ValidityDays).SetValidityUnit(req.ValidityUnit).
...@@ -86,8 +136,12 @@ func (s *PaymentConfigService) CreatePlan(ctx context.Context, req CreatePlanReq ...@@ -86,8 +136,12 @@ func (s *PaymentConfigService) CreatePlan(ctx context.Context, req CreatePlanReq
} }
// UpdatePlan updates a subscription plan by ID (patch semantics). // UpdatePlan updates a subscription plan by ID (patch semantics).
// NOTE: This function exceeds 30 lines due to per-field nil-check patch update boilerplate. // NOTE: This function exceeds 30 lines due to per-field nil-check patch update boilerplate
// plus a validation guard for non-nil fields.
func (s *PaymentConfigService) UpdatePlan(ctx context.Context, id int64, req UpdatePlanRequest) (*dbent.SubscriptionPlan, error) { func (s *PaymentConfigService) UpdatePlan(ctx context.Context, id int64, req UpdatePlanRequest) (*dbent.SubscriptionPlan, error) {
if err := validatePlanPatch(req); err != nil {
return nil, err
}
u := s.entClient.SubscriptionPlan.UpdateOneID(id) u := s.entClient.SubscriptionPlan.UpdateOneID(id)
if req.GroupID != nil { if req.GroupID != nil {
u.SetGroupID(*req.GroupID) u.SetGroupID(*req.GroupID)
......
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestValidatePlanRequired_AllValid(t *testing.T) {
err := validatePlanRequired("Pro", 1, 9.99, 30, "days", nil)
require.NoError(t, err)
}
func TestValidatePlanRequired_EmptyName(t *testing.T) {
err := validatePlanRequired("", 1, 9.99, 30, "days", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "plan name")
}
func TestValidatePlanRequired_WhitespaceName(t *testing.T) {
err := validatePlanRequired(" ", 1, 9.99, 30, "days", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "plan name")
}
func TestValidatePlanRequired_ZeroGroupID(t *testing.T) {
err := validatePlanRequired("Pro", 0, 9.99, 30, "days", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "group")
}
func TestValidatePlanRequired_NegativeGroupID(t *testing.T) {
err := validatePlanRequired("Pro", -1, 9.99, 30, "days", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "group")
}
func TestValidatePlanRequired_ZeroPrice(t *testing.T) {
err := validatePlanRequired("Pro", 1, 0, 30, "days", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "price")
}
func TestValidatePlanRequired_NegativePrice(t *testing.T) {
err := validatePlanRequired("Pro", 1, -5, 30, "days", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "price")
}
func TestValidatePlanRequired_ZeroValidityDays(t *testing.T) {
err := validatePlanRequired("Pro", 1, 9.99, 0, "days", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "validity days")
}
func TestValidatePlanRequired_NegativeValidityDays(t *testing.T) {
err := validatePlanRequired("Pro", 1, 9.99, -7, "days", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "validity days")
}
func TestValidatePlanRequired_EmptyValidityUnit(t *testing.T) {
err := validatePlanRequired("Pro", 1, 9.99, 30, "", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "validity unit")
}
func TestValidatePlanRequired_WhitespaceValidityUnit(t *testing.T) {
err := validatePlanRequired("Pro", 1, 9.99, 30, " ", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "validity unit")
}
func TestValidatePlanRequired_NameValidatedFirst(t *testing.T) {
err := validatePlanRequired("", 0, 0, 0, "", nil)
require.Error(t, err)
require.Contains(t, err.Error(), "plan name")
}
func TestValidatePlanRequired_TrimmedValidName(t *testing.T) {
err := validatePlanRequired(" Pro ", 1, 9.99, 30, "days", nil)
require.NoError(t, err)
}
func TestValidatePlanRequired_NegativeOriginalPrice(t *testing.T) {
neg := -10.0
err := validatePlanRequired("Pro", 1, 9.99, 30, "days", &neg)
require.Error(t, err)
require.Contains(t, err.Error(), "original price")
}
func TestValidatePlanRequired_ZeroOriginalPrice(t *testing.T) {
zero := 0.0
err := validatePlanRequired("Pro", 1, 9.99, 30, "days", &zero)
require.NoError(t, err)
}
func TestValidatePlanRequired_ValidOriginalPrice(t *testing.T) {
op := 19.99
err := validatePlanRequired("Pro", 1, 9.99, 30, "days", &op)
require.NoError(t, err)
}
// --- validatePlanPatch tests ---
func TestValidatePlanPatch_NegativeOriginalPrice(t *testing.T) {
neg := -5.0
err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: &neg})
require.Error(t, err)
require.Contains(t, err.Error(), "original price")
}
func TestValidatePlanPatch_ZeroOriginalPrice(t *testing.T) {
zero := 0.0
err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: &zero})
require.NoError(t, err)
}
func TestValidatePlanPatch_ValidOriginalPrice(t *testing.T) {
op := 29.99
err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: &op})
require.NoError(t, err)
}
func TestValidatePlanPatch_NilOriginalPrice(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: nil})
require.NoError(t, err)
}
// --- validatePlanPatch: other fields ---
func ptrStr(s string) *string { return &s }
func ptrInt(i int) *int { return &i }
func ptrInt64(i int64) *int64 { return &i }
func ptrFloat(f float64) *float64 { return &f }
func TestValidatePlanPatch_EmptyName(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{Name: ptrStr("")})
require.Error(t, err)
require.Contains(t, err.Error(), "plan name")
}
func TestValidatePlanPatch_ValidName(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{Name: ptrStr("Basic")})
require.NoError(t, err)
}
func TestValidatePlanPatch_ZeroGroupID(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{GroupID: ptrInt64(0)})
require.Error(t, err)
require.Contains(t, err.Error(), "group")
}
func TestValidatePlanPatch_NegativePrice(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{Price: ptrFloat(-1)})
require.Error(t, err)
require.Contains(t, err.Error(), "price")
}
func TestValidatePlanPatch_ZeroPrice(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{Price: ptrFloat(0)})
require.Error(t, err)
require.Contains(t, err.Error(), "price")
}
func TestValidatePlanPatch_ValidPrice(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{Price: ptrFloat(9.99)})
require.NoError(t, err)
}
func TestValidatePlanPatch_ZeroValidityDays(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{ValidityDays: ptrInt(0)})
require.Error(t, err)
require.Contains(t, err.Error(), "validity days")
}
func TestValidatePlanPatch_EmptyValidityUnit(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{ValidityUnit: ptrStr("")})
require.Error(t, err)
require.Contains(t, err.Error(), "validity unit")
}
func TestValidatePlanPatch_ValidValidityUnit(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{ValidityUnit: ptrStr("days")})
require.NoError(t, err)
}
func TestValidatePlanPatch_AllNil(t *testing.T) {
err := validatePlanPatch(UpdatePlanRequest{})
require.NoError(t, err)
}
...@@ -22,16 +22,17 @@ func (s *PaymentConfigService) ListProviderInstances(ctx context.Context) ([]*db ...@@ -22,16 +22,17 @@ func (s *PaymentConfigService) ListProviderInstances(ctx context.Context) ([]*db
// ProviderInstanceResponse is the API response for a provider instance. // ProviderInstanceResponse is the API response for a provider instance.
type ProviderInstanceResponse struct { type ProviderInstanceResponse struct {
ID int64 `json:"id"` ID int64 `json:"id"`
ProviderKey string `json:"provider_key"` ProviderKey string `json:"provider_key"`
Name string `json:"name"` Name string `json:"name"`
Config map[string]string `json:"config"` Config map[string]string `json:"config"`
SupportedTypes []string `json:"supported_types"` SupportedTypes []string `json:"supported_types"`
Limits string `json:"limits"` Limits string `json:"limits"`
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
RefundEnabled bool `json:"refund_enabled"` RefundEnabled bool `json:"refund_enabled"`
SortOrder int `json:"sort_order"` AllowUserRefund bool `json:"allow_user_refund"`
PaymentMode string `json:"payment_mode"` SortOrder int `json:"sort_order"`
PaymentMode string `json:"payment_mode"`
} }
// ListProviderInstancesWithConfig returns provider instances with decrypted config. // ListProviderInstancesWithConfig returns provider instances with decrypted config.
...@@ -46,8 +47,9 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte ...@@ -46,8 +47,9 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte
resp := ProviderInstanceResponse{ resp := ProviderInstanceResponse{
ID: int64(inst.ID), ProviderKey: inst.ProviderKey, Name: inst.Name, ID: int64(inst.ID), ProviderKey: inst.ProviderKey, Name: inst.Name,
SupportedTypes: splitTypes(inst.SupportedTypes), Limits: inst.Limits, SupportedTypes: splitTypes(inst.SupportedTypes), Limits: inst.Limits,
Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled, SortOrder: inst.SortOrder, Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled,
PaymentMode: inst.PaymentMode, AllowUserRefund: inst.AllowUserRefund,
SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode,
} }
resp.Config, err = s.decryptAndMaskConfig(inst.Config) resp.Config, err = s.decryptAndMaskConfig(inst.Config)
if err != nil { if err != nil {
...@@ -110,10 +112,12 @@ func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req C ...@@ -110,10 +112,12 @@ func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req C
if err != nil { if err != nil {
return nil, err return nil, err
} }
allowUserRefund := req.AllowUserRefund && req.RefundEnabled
return s.entClient.PaymentProviderInstance.Create(). return s.entClient.PaymentProviderInstance.Create().
SetProviderKey(req.ProviderKey).SetName(req.Name).SetConfig(enc). SetProviderKey(req.ProviderKey).SetName(req.Name).SetConfig(enc).
SetSupportedTypes(typesStr).SetEnabled(req.Enabled).SetPaymentMode(req.PaymentMode). SetSupportedTypes(typesStr).SetEnabled(req.Enabled).SetPaymentMode(req.PaymentMode).
SetSortOrder(req.SortOrder).SetLimits(req.Limits).SetRefundEnabled(req.RefundEnabled). SetSortOrder(req.SortOrder).SetLimits(req.Limits).SetRefundEnabled(req.RefundEnabled).
SetAllowUserRefund(allowUserRefund).
Save(ctx) Save(ctx)
} }
...@@ -221,6 +225,29 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in ...@@ -221,6 +225,29 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
} }
if req.RefundEnabled != nil { if req.RefundEnabled != nil {
u.SetRefundEnabled(*req.RefundEnabled) u.SetRefundEnabled(*req.RefundEnabled)
// Cascade: turning off refund_enabled also disables allow_user_refund
if !*req.RefundEnabled {
u.SetAllowUserRefund(false)
}
}
if req.AllowUserRefund != nil {
// Only allow enabling when refund_enabled is (or will be) true
if *req.AllowUserRefund {
refundEnabled := false
if req.RefundEnabled != nil {
refundEnabled = *req.RefundEnabled
} else {
inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
if err == nil {
refundEnabled = inst.RefundEnabled
}
}
if refundEnabled {
u.SetAllowUserRefund(true)
}
} else {
u.SetAllowUserRefund(false)
}
} }
if req.PaymentMode != nil { if req.PaymentMode != nil {
u.SetPaymentMode(*req.PaymentMode) u.SetPaymentMode(*req.PaymentMode)
...@@ -228,6 +255,23 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in ...@@ -228,6 +255,23 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
return u.Save(ctx) return u.Save(ctx)
} }
// GetUserRefundEligibleInstanceIDs returns provider instance IDs that allow user refund.
func (s *PaymentConfigService) GetUserRefundEligibleInstanceIDs(ctx context.Context) ([]string, error) {
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(
paymentproviderinstance.RefundEnabledEQ(true),
paymentproviderinstance.AllowUserRefundEQ(true),
).Select(paymentproviderinstance.FieldID).All(ctx)
if err != nil {
return nil, err
}
ids := make([]string, 0, len(instances))
for _, inst := range instances {
ids = append(ids, strconv.FormatInt(int64(inst.ID), 10))
}
return ids, nil
}
func (s *PaymentConfigService) mergeConfig(ctx context.Context, id int64, newConfig map[string]string) (map[string]string, error) { func (s *PaymentConfigService) mergeConfig(ctx context.Context, id int64, newConfig map[string]string) (map[string]string, error) {
inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id) inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
if err != nil { if err != nil {
......
...@@ -101,7 +101,7 @@ func TestIsSensitiveConfigField(t *testing.T) { ...@@ -101,7 +101,7 @@ func TestIsSensitiveConfigField(t *testing.T) {
t.Parallel() t.Parallel()
tests := []struct { tests := []struct {
field string field string
wantSen bool wantSen bool
}{ }{
// Sensitive fields (contain key/secret/private/password/pkey patterns) // Sensitive fields (contain key/secret/private/password/pkey patterns)
......
...@@ -105,26 +105,28 @@ type MethodLimitsResponse struct { ...@@ -105,26 +105,28 @@ type MethodLimitsResponse struct {
} }
type CreateProviderInstanceRequest struct { type CreateProviderInstanceRequest struct {
ProviderKey string `json:"provider_key"` ProviderKey string `json:"provider_key"`
Name string `json:"name"` Name string `json:"name"`
Config map[string]string `json:"config"` Config map[string]string `json:"config"`
SupportedTypes []string `json:"supported_types"` SupportedTypes []string `json:"supported_types"`
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
PaymentMode string `json:"payment_mode"` PaymentMode string `json:"payment_mode"`
SortOrder int `json:"sort_order"` SortOrder int `json:"sort_order"`
Limits string `json:"limits"` Limits string `json:"limits"`
RefundEnabled bool `json:"refund_enabled"` RefundEnabled bool `json:"refund_enabled"`
AllowUserRefund bool `json:"allow_user_refund"`
} }
type UpdateProviderInstanceRequest struct { type UpdateProviderInstanceRequest struct {
Name *string `json:"name"` Name *string `json:"name"`
Config map[string]string `json:"config"` Config map[string]string `json:"config"`
SupportedTypes []string `json:"supported_types"` SupportedTypes []string `json:"supported_types"`
Enabled *bool `json:"enabled"` Enabled *bool `json:"enabled"`
PaymentMode *string `json:"payment_mode"` PaymentMode *string `json:"payment_mode"`
SortOrder *int `json:"sort_order"` SortOrder *int `json:"sort_order"`
Limits *string `json:"limits"` Limits *string `json:"limits"`
RefundEnabled *bool `json:"refund_enabled"` RefundEnabled *bool `json:"refund_enabled"`
AllowUserRefund *bool `json:"allow_user_refund"`
} }
type CreatePlanRequest struct { type CreatePlanRequest struct {
GroupID int64 `json:"group_id"` GroupID int64 `json:"group_id"`
......
...@@ -2,6 +2,7 @@ package service ...@@ -2,6 +2,7 @@ package service
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"log/slog" "log/slog"
"math" "math"
...@@ -17,6 +18,19 @@ import ( ...@@ -17,6 +18,19 @@ import (
// --- Refund Flow --- // --- Refund Flow ---
// getOrderProviderInstance looks up the provider instance that processed this order.
// Returns nil, nil for legacy orders without provider_instance_id.
func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) {
if o.ProviderInstanceID == nil || *o.ProviderInstanceID == "" {
return nil, nil
}
instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64)
if err != nil {
return nil, nil
}
return s.entClient.PaymentProviderInstance.Get(ctx, instID)
}
func (s *PaymentService) RequestRefund(ctx context.Context, oid, uid int64, reason string) error { func (s *PaymentService) RequestRefund(ctx context.Context, oid, uid int64, reason string) error {
o, err := s.validateRefundRequest(ctx, oid, uid) o, err := s.validateRefundRequest(ctx, oid, uid)
if err != nil { if err != nil {
...@@ -57,6 +71,14 @@ func (s *PaymentService) validateRefundRequest(ctx context.Context, oid, uid int ...@@ -57,6 +71,14 @@ func (s *PaymentService) validateRefundRequest(ctx context.Context, oid, uid int
if o.Status != OrderStatusCompleted { if o.Status != OrderStatusCompleted {
return nil, infraerrors.BadRequest("INVALID_STATUS", "only completed orders can request refund") return nil, infraerrors.BadRequest("INVALID_STATUS", "only completed orders can request refund")
} }
// Check provider instance allows user refund
inst, err := s.getOrderProviderInstance(ctx, o)
if err != nil || inst == nil {
return nil, infraerrors.Forbidden("USER_REFUND_DISABLED", "refund is not available for this order")
}
if !inst.AllowUserRefund {
return nil, infraerrors.Forbidden("USER_REFUND_DISABLED", "user refund is not enabled for this provider")
}
return o, nil return o, nil
} }
...@@ -69,6 +91,19 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float ...@@ -69,6 +91,19 @@ func (s *PaymentService) PrepareRefund(ctx context.Context, oid int64, amt float
if !psSliceContains(ok, o.Status) { if !psSliceContains(ok, o.Status) {
return nil, nil, infraerrors.BadRequest("INVALID_STATUS", "order status does not allow refund") return nil, nil, infraerrors.BadRequest("INVALID_STATUS", "order status does not allow refund")
} }
// Check provider instance allows admin refund
inst, instErr := s.getOrderProviderInstance(ctx, o)
if instErr != nil {
slog.Warn("refund: provider instance lookup failed", "orderID", oid, "error", instErr)
return nil, nil, infraerrors.InternalServer("PROVIDER_LOOKUP_FAILED", "failed to look up payment provider for this order")
}
if inst == nil {
// Legacy order without provider_instance_id — block refund
return nil, nil, infraerrors.Forbidden("REFUND_DISABLED", "refund is not available for this order")
}
if !inst.RefundEnabled {
return nil, nil, infraerrors.Forbidden("REFUND_DISABLED", "refund is not enabled for this provider")
}
if math.IsNaN(amt) || math.IsInf(amt, 0) { if math.IsNaN(amt) || math.IsInf(amt, 0) {
return nil, nil, infraerrors.BadRequest("INVALID_AMOUNT", "invalid refund amount") return nil, nil, infraerrors.BadRequest("INVALID_AMOUNT", "invalid refund amount")
} }
...@@ -150,11 +185,17 @@ func (s *PaymentService) ExecuteRefund(ctx context.Context, p *RefundPlan) (*Ref ...@@ -150,11 +185,17 @@ func (s *PaymentService) ExecuteRefund(ctx context.Context, p *RefundPlan) (*Ref
if !s.hasAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED") { if !s.hasAuditLog(ctx, p.OrderID, "REFUND_ROLLBACK_FAILED") {
_, err := s.subscriptionSvc.ExtendSubscription(ctx, p.SubscriptionID, -p.SubDaysToDeduct) _, err := s.subscriptionSvc.ExtendSubscription(ctx, p.SubscriptionID, -p.SubDaysToDeduct)
if err != nil { if err != nil {
// If deducting would expire the subscription, revoke it entirely if errors.Is(err, ErrAdjustWouldExpire) {
slog.Info("subscription deduction would expire, revoking", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct) // Deduction would expire the subscription — revoke it entirely
if revokeErr := s.subscriptionSvc.RevokeSubscription(ctx, p.SubscriptionID); revokeErr != nil { slog.Info("subscription deduction would expire, revoking", "orderID", p.OrderID, "subID", p.SubscriptionID, "days", p.SubDaysToDeduct)
if revokeErr := s.subscriptionSvc.RevokeSubscription(ctx, p.SubscriptionID); revokeErr != nil {
s.restoreStatus(ctx, p)
return nil, fmt.Errorf("revoke subscription: %w", revokeErr)
}
} else {
// Other errors (DB failure, not found) — abort refund
s.restoreStatus(ctx, p) s.restoreStatus(ctx, p)
return nil, fmt.Errorf("revoke subscription: %w", revokeErr) return nil, fmt.Errorf("deduct subscription days: %w", err)
} }
} }
} else { } else {
......
...@@ -102,6 +102,8 @@ func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *t ...@@ -102,6 +102,8 @@ func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *t
}) })
} }
// TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError
// OpenAI OAuth 401 缓存失效出错时仍走 temp_unschedulable
func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) { func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) {
repo := &rateLimitAccountRepoStub{} repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{err: errors.New("boom")} invalidator := &tokenCacheInvalidatorRecorder{err: errors.New("boom")}
...@@ -109,7 +111,7 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin ...@@ -109,7 +111,7 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin
service.SetTokenCacheInvalidator(invalidator) service.SetTokenCacheInvalidator(invalidator)
account := &Account{ account := &Account{
ID: 101, ID: 101,
Platform: PlatformGemini, Platform: PlatformOpenAI,
Type: AccountTypeOAuth, Type: AccountTypeOAuth,
} }
......
...@@ -99,13 +99,19 @@ type DefaultSubscriptionGroupReader interface { ...@@ -99,13 +99,19 @@ type DefaultSubscriptionGroupReader interface {
GetByID(ctx context.Context, id int64) (*Group, error) GetByID(ctx context.Context, id int64) (*Group, error)
} }
// WebSearchManagerBuilder creates a websearch.Manager from config (injected by infra layer).
// proxyURLs maps proxy ID to resolved URL for provider-level proxy support.
type WebSearchManagerBuilder func(cfg *WebSearchEmulationConfig, proxyURLs map[int64]string)
// SettingService 系统设置服务 // SettingService 系统设置服务
type SettingService struct { type SettingService struct {
settingRepo SettingRepository settingRepo SettingRepository
defaultSubGroupReader DefaultSubscriptionGroupReader defaultSubGroupReader DefaultSubscriptionGroupReader
cfg *config.Config proxyRepo ProxyRepository // for resolving websearch provider proxy URLs
onUpdate func() // Callback when settings are updated (for cache invalidation) cfg *config.Config
version string // Application version onUpdate func() // Callback when settings are updated (for cache invalidation)
version string // Application version
webSearchManagerBuilder WebSearchManagerBuilder
} }
// NewSettingService 创建系统设置服务实例 // NewSettingService 创建系统设置服务实例
...@@ -121,6 +127,11 @@ func (s *SettingService) SetDefaultSubscriptionGroupReader(reader DefaultSubscri ...@@ -121,6 +127,11 @@ func (s *SettingService) SetDefaultSubscriptionGroupReader(reader DefaultSubscri
s.defaultSubGroupReader = reader s.defaultSubGroupReader = reader
} }
// SetProxyRepository injects a proxy repo for resolving websearch provider proxy URLs.
func (s *SettingService) SetProxyRepository(repo ProxyRepository) {
s.proxyRepo = repo
}
// GetAllSettings 获取所有系统设置 // GetAllSettings 获取所有系统设置
func (s *SettingService) GetAllSettings(ctx context.Context) (*SystemSettings, error) { func (s *SettingService) GetAllSettings(ctx context.Context) (*SystemSettings, error) {
settings, err := s.settingRepo.GetAll(ctx) settings, err := s.settingRepo.GetAll(ctx)
...@@ -168,9 +179,13 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ...@@ -168,9 +179,13 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyCustomEndpoints, SettingKeyCustomEndpoints,
SettingKeyLinuxDoConnectEnabled, SettingKeyLinuxDoConnectEnabled,
SettingKeyBackendModeEnabled, SettingKeyBackendModeEnabled,
SettingPaymentEnabled,
SettingKeyOIDCConnectEnabled, SettingKeyOIDCConnectEnabled,
SettingKeyOIDCConnectProviderName, SettingKeyOIDCConnectProviderName,
SettingPaymentEnabled, SettingKeyBalanceLowNotifyEnabled,
SettingKeyBalanceLowNotifyThreshold,
SettingKeyBalanceLowNotifyRechargeURL,
SettingKeyAccountQuotaNotifyEnabled,
} }
settings, err := s.settingRepo.GetMultiple(ctx, keys) settings, err := s.settingRepo.GetMultiple(ctx, keys)
...@@ -209,6 +224,11 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ...@@ -209,6 +224,11 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
settings[SettingKeyTablePageSizeOptions], settings[SettingKeyTablePageSizeOptions],
) )
var balanceLowNotifyThreshold float64
if v, err := strconv.ParseFloat(settings[SettingKeyBalanceLowNotifyThreshold], 64); err == nil && v >= 0 {
balanceLowNotifyThreshold = v
}
return &PublicSettings{ return &PublicSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: emailVerifyEnabled, EmailVerifyEnabled: emailVerifyEnabled,
...@@ -235,9 +255,13 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ...@@ -235,9 +255,13 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
CustomEndpoints: settings[SettingKeyCustomEndpoints], CustomEndpoints: settings[SettingKeyCustomEndpoints],
LinuxDoOAuthEnabled: linuxDoEnabled, LinuxDoOAuthEnabled: linuxDoEnabled,
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
PaymentEnabled: settings[SettingPaymentEnabled] == "true",
OIDCOAuthEnabled: oidcEnabled, OIDCOAuthEnabled: oidcEnabled,
OIDCOAuthProviderName: oidcProviderName, OIDCOAuthProviderName: oidcProviderName,
PaymentEnabled: settings[SettingPaymentEnabled] == "true", BalanceLowNotifyEnabled: settings[SettingKeyBalanceLowNotifyEnabled] == "true",
AccountQuotaNotifyEnabled: settings[SettingKeyAccountQuotaNotifyEnabled] == "true",
BalanceLowNotifyThreshold: balanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: settings[SettingKeyBalanceLowNotifyRechargeURL],
}, nil }, nil
} }
...@@ -287,10 +311,14 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any ...@@ -287,10 +311,14 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomEndpoints json.RawMessage `json:"custom_endpoints"` CustomEndpoints json.RawMessage `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"` BackendModeEnabled bool `json:"backend_mode_enabled"`
PaymentEnabled bool `json:"payment_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"` OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"` OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
PaymentEnabled bool `json:"payment_enabled"`
Version string `json:"version,omitempty"` Version string `json:"version,omitempty"`
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
}{ }{
RegistrationEnabled: settings.RegistrationEnabled, RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled,
...@@ -317,10 +345,14 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any ...@@ -317,10 +345,14 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints), CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
BackendModeEnabled: settings.BackendModeEnabled, BackendModeEnabled: settings.BackendModeEnabled,
PaymentEnabled: settings.PaymentEnabled,
OIDCOAuthEnabled: settings.OIDCOAuthEnabled, OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
OIDCOAuthProviderName: settings.OIDCOAuthProviderName, OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
PaymentEnabled: settings.PaymentEnabled,
Version: s.version, Version: s.version,
BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
}, nil }, nil
} }
...@@ -595,6 +627,13 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet ...@@ -595,6 +627,13 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyEnableMetadataPassthrough] = strconv.FormatBool(settings.EnableMetadataPassthrough) updates[SettingKeyEnableMetadataPassthrough] = strconv.FormatBool(settings.EnableMetadataPassthrough)
updates[SettingKeyEnableCCHSigning] = strconv.FormatBool(settings.EnableCCHSigning) updates[SettingKeyEnableCCHSigning] = strconv.FormatBool(settings.EnableCCHSigning)
// Balance low notification
updates[SettingKeyBalanceLowNotifyEnabled] = strconv.FormatBool(settings.BalanceLowNotifyEnabled)
updates[SettingKeyBalanceLowNotifyThreshold] = strconv.FormatFloat(settings.BalanceLowNotifyThreshold, 'f', 8, 64)
updates[SettingKeyBalanceLowNotifyRechargeURL] = settings.BalanceLowNotifyRechargeURL
updates[SettingKeyAccountQuotaNotifyEnabled] = strconv.FormatBool(settings.AccountQuotaNotifyEnabled)
updates[SettingKeyAccountQuotaNotifyEmails] = MarshalNotifyEmails(settings.AccountQuotaNotifyEmails)
err = s.settingRepo.SetMultiple(ctx, updates) err = s.settingRepo.SetMultiple(ctx, updates)
if err == nil { if err == nil {
// 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口 // 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口
...@@ -1217,6 +1256,30 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin ...@@ -1217,6 +1256,30 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.EnableMetadataPassthrough = settings[SettingKeyEnableMetadataPassthrough] == "true" result.EnableMetadataPassthrough = settings[SettingKeyEnableMetadataPassthrough] == "true"
result.EnableCCHSigning = settings[SettingKeyEnableCCHSigning] == "true" result.EnableCCHSigning = settings[SettingKeyEnableCCHSigning] == "true"
// Web search emulation: quick enabled check from the JSON config
if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" {
var wsCfg WebSearchEmulationConfig
if err := json.Unmarshal([]byte(raw), &wsCfg); err == nil {
result.WebSearchEmulationEnabled = wsCfg.Enabled && len(wsCfg.Providers) > 0
}
}
// Balance low notification
result.BalanceLowNotifyEnabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true"
if v, err := strconv.ParseFloat(settings[SettingKeyBalanceLowNotifyThreshold], 64); err == nil && v >= 0 {
result.BalanceLowNotifyThreshold = v
}
result.BalanceLowNotifyRechargeURL = settings[SettingKeyBalanceLowNotifyRechargeURL]
// Account quota notification
result.AccountQuotaNotifyEnabled = settings[SettingKeyAccountQuotaNotifyEnabled] == "true"
if raw := strings.TrimSpace(settings[SettingKeyAccountQuotaNotifyEmails]); raw != "" {
result.AccountQuotaNotifyEmails = ParseNotifyEmails(raw)
}
if result.AccountQuotaNotifyEmails == nil {
result.AccountQuotaNotifyEmails = []NotifyEmailEntry{}
}
return result return result
} }
......
...@@ -66,7 +66,7 @@ func TestSettingService_GetPublicSettings_ExposesRegistrationEmailSuffixWhitelis ...@@ -66,7 +66,7 @@ func TestSettingService_GetPublicSettings_ExposesRegistrationEmailSuffixWhitelis
func TestSettingService_GetPublicSettings_ExposesTablePreferences(t *testing.T) { func TestSettingService_GetPublicSettings_ExposesTablePreferences(t *testing.T) {
repo := &settingPublicRepoStub{ repo := &settingPublicRepoStub{
values: map[string]string{ values: map[string]string{
SettingKeyTableDefaultPageSize: "50", SettingKeyTableDefaultPageSize: "50",
SettingKeyTablePageSizeOptions: "[20,50,100]", SettingKeyTablePageSizeOptions: "[20,50,100]",
}, },
} }
......
...@@ -208,7 +208,7 @@ func TestSettingService_UpdateSettings_TablePreferences(t *testing.T) { ...@@ -208,7 +208,7 @@ func TestSettingService_UpdateSettings_TablePreferences(t *testing.T) {
svc := NewSettingService(repo, &config.Config{}) svc := NewSettingService(repo, &config.Config{})
err := svc.UpdateSettings(context.Background(), &SystemSettings{ err := svc.UpdateSettings(context.Background(), &SystemSettings{
TableDefaultPageSize: 50, TableDefaultPageSize: 50,
TablePageSizeOptions: []int{20, 50, 100}, TablePageSizeOptions: []int{20, 50, 100},
}) })
require.NoError(t, err) require.NoError(t, err)
...@@ -216,7 +216,7 @@ func TestSettingService_UpdateSettings_TablePreferences(t *testing.T) { ...@@ -216,7 +216,7 @@ func TestSettingService_UpdateSettings_TablePreferences(t *testing.T) {
require.Equal(t, "[20,50,100]", repo.updates[SettingKeyTablePageSizeOptions]) require.Equal(t, "[20,50,100]", repo.updates[SettingKeyTablePageSizeOptions])
err = svc.UpdateSettings(context.Background(), &SystemSettings{ err = svc.UpdateSettings(context.Background(), &SystemSettings{
TableDefaultPageSize: 1000, TableDefaultPageSize: 1000,
TablePageSizeOptions: []int{20, 100}, TablePageSizeOptions: []int{20, 100},
}) })
require.NoError(t, err) require.NoError(t, err)
......
...@@ -106,6 +106,18 @@ type SystemSettings struct { ...@@ -106,6 +106,18 @@ type SystemSettings struct {
EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true) EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true)
EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false) EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false)
EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false) EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false)
// Web Search Emulation
WebSearchEmulationEnabled bool // 是否启用 web search 模拟
// Balance low notification
BalanceLowNotifyEnabled bool
BalanceLowNotifyThreshold float64
BalanceLowNotifyRechargeURL string
// Account quota notification
AccountQuotaNotifyEnabled bool
AccountQuotaNotifyEmails []NotifyEmailEntry
} }
type DefaultSubscriptionSetting struct { type DefaultSubscriptionSetting struct {
...@@ -141,10 +153,15 @@ type PublicSettings struct { ...@@ -141,10 +153,15 @@ type PublicSettings struct {
LinuxDoOAuthEnabled bool LinuxDoOAuthEnabled bool
BackendModeEnabled bool BackendModeEnabled bool
PaymentEnabled bool
OIDCOAuthEnabled bool OIDCOAuthEnabled bool
OIDCOAuthProviderName string OIDCOAuthProviderName string
PaymentEnabled bool
Version string Version string
BalanceLowNotifyEnabled bool
AccountQuotaNotifyEnabled bool
BalanceLowNotifyThreshold float64
BalanceLowNotifyRechargeURL string
} }
// StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制) // StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制)
......
...@@ -100,9 +100,22 @@ func valueOrZero(v *int64) int64 { ...@@ -100,9 +100,22 @@ func valueOrZero(v *int64) int64 {
return *v return *v
} }
// AccountQuotaState holds the post-increment quota state returned by the DB transaction.
// All values are post-update (i.e., already include the increment).
type AccountQuotaState struct {
TotalUsed float64
TotalLimit float64
DailyUsed float64
DailyLimit float64
WeeklyUsed float64
WeeklyLimit float64
}
type UsageBillingApplyResult struct { type UsageBillingApplyResult struct {
Applied bool Applied bool
APIKeyQuotaExhausted bool APIKeyQuotaExhausted bool
NewBalance *float64 // post-deduction balance (nil = no balance deduction)
QuotaState *AccountQuotaState // post-increment quota state (nil = no quota increment)
} }
type UsageBillingRepository interface { type UsageBillingRepository interface {
......
...@@ -146,6 +146,8 @@ type UsageLog struct { ...@@ -146,6 +146,8 @@ type UsageLog struct {
RateMultiplier float64 RateMultiplier float64
// AccountRateMultiplier 账号计费倍率快照(nil 表示历史数据,按 1.0 处理) // AccountRateMultiplier 账号计费倍率快照(nil 表示历史数据,按 1.0 处理)
AccountRateMultiplier *float64 AccountRateMultiplier *float64
// AccountStatsCost 账号统计定价预计算费用(nil = 使用默认公式 total_cost × account_rate_multiplier)
AccountStatsCost *float64
BillingType int8 BillingType int8
RequestType RequestType RequestType RequestType
......
...@@ -30,6 +30,13 @@ type User struct { ...@@ -30,6 +30,13 @@ type User struct {
TotpEnabled bool // 是否启用 TOTP TotpEnabled bool // 是否启用 TOTP
TotpEnabledAt *time.Time // TOTP 启用时间 TotpEnabledAt *time.Time // TOTP 启用时间
// 余额不足通知
BalanceNotifyEnabled bool
BalanceNotifyThresholdType string // "fixed" (default) | "percentage"
BalanceNotifyThreshold *float64
BalanceNotifyExtraEmails []NotifyEmailEntry
TotalRecharged float64
APIKeys []APIKey APIKeys []APIKey
Subscriptions []UserSubscription Subscriptions []UserSubscription
} }
......
...@@ -2,8 +2,10 @@ package service ...@@ -2,8 +2,10 @@ package service
import ( import (
"context" "context"
"crypto/subtle"
"fmt" "fmt"
"log" "log/slog"
"strings"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
...@@ -11,9 +13,18 @@ import ( ...@@ -11,9 +13,18 @@ import (
) )
var ( var (
ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found") ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found")
ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect") ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect")
ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions") ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
ErrNotifyCodeUserRateLimit = infraerrors.TooManyRequests("NOTIFY_CODE_USER_RATE_LIMIT", "too many verification codes requested, please try again later")
)
const (
maxNotifyEmails = 3 // Maximum number of notification emails per user
// User-level rate limiting for notify email verification codes
notifyCodeUserRateLimit = 5
notifyCodeUserRateWindow = 10 * time.Minute
) )
// UserListFilters contains all filter options for listing users // UserListFilters contains all filter options for listing users
...@@ -58,9 +69,11 @@ type UserRepository interface { ...@@ -58,9 +69,11 @@ type UserRepository interface {
// UpdateProfileRequest 更新用户资料请求 // UpdateProfileRequest 更新用户资料请求
type UpdateProfileRequest struct { type UpdateProfileRequest struct {
Email *string `json:"email"` Email *string `json:"email"`
Username *string `json:"username"` Username *string `json:"username"`
Concurrency *int `json:"concurrency"` Concurrency *int `json:"concurrency"`
BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
} }
// ChangePasswordRequest 修改密码请求 // ChangePasswordRequest 修改密码请求
...@@ -72,14 +85,16 @@ type ChangePasswordRequest struct { ...@@ -72,14 +85,16 @@ type ChangePasswordRequest struct {
// UserService 用户服务 // UserService 用户服务
type UserService struct { type UserService struct {
userRepo UserRepository userRepo UserRepository
settingRepo SettingRepository
authCacheInvalidator APIKeyAuthCacheInvalidator authCacheInvalidator APIKeyAuthCacheInvalidator
billingCache BillingCache billingCache BillingCache
} }
// NewUserService 创建用户服务实例 // NewUserService 创建用户服务实例
func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCache BillingCache) *UserService { func NewUserService(userRepo UserRepository, settingRepo SettingRepository, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCache BillingCache) *UserService {
return &UserService{ return &UserService{
userRepo: userRepo, userRepo: userRepo,
settingRepo: settingRepo,
authCacheInvalidator: authCacheInvalidator, authCacheInvalidator: authCacheInvalidator,
billingCache: billingCache, billingCache: billingCache,
} }
...@@ -132,6 +147,17 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat ...@@ -132,6 +147,17 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
user.Concurrency = *req.Concurrency user.Concurrency = *req.Concurrency
} }
if req.BalanceNotifyEnabled != nil {
user.BalanceNotifyEnabled = *req.BalanceNotifyEnabled
}
if req.BalanceNotifyThreshold != nil {
if *req.BalanceNotifyThreshold <= 0 {
user.BalanceNotifyThreshold = nil // clear to system default
} else {
user.BalanceNotifyThreshold = req.BalanceNotifyThreshold
}
}
if err := s.userRepo.Update(ctx, user); err != nil { if err := s.userRepo.Update(ctx, user); err != nil {
return nil, fmt.Errorf("update user: %w", err) return nil, fmt.Errorf("update user: %w", err)
} }
...@@ -198,10 +224,15 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl ...@@ -198,10 +224,15 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl
} }
if s.billingCache != nil { if s.billingCache != nil {
go func() { go func() {
defer func() {
if r := recover(); r != nil {
slog.Error("panic in balance cache invalidation", "user_id", userID, "recover", r)
}
}()
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
if err := s.billingCache.InvalidateUserBalance(cacheCtx, userID); err != nil { if err := s.billingCache.InvalidateUserBalance(cacheCtx, userID); err != nil {
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err) slog.Error("invalidate user balance cache failed", "user_id", userID, "error", err)
} }
}() }()
} }
...@@ -248,3 +279,229 @@ func (s *UserService) Delete(ctx context.Context, userID int64) error { ...@@ -248,3 +279,229 @@ func (s *UserService) Delete(ctx context.Context, userID int64) error {
} }
return nil return nil
} }
// SendNotifyEmailCode sends a verification code to the extra notification email.
func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, email string, emailService *EmailService, cache EmailCache) error {
if err := checkNotifyCodeRateLimit(ctx, cache, userID, email); err != nil {
return err
}
code, err := emailService.GenerateVerifyCode()
if err != nil {
return fmt.Errorf("generate code: %w", err)
}
// Send email first — if SMTP fails, don't write cache or increment counters,
// so the user is not locked out by cooldown/rate-limit for a code they never received.
if err := s.sendNotifyVerifyEmail(ctx, emailService, email, code); err != nil {
return err
}
if err := saveNotifyVerifyCode(ctx, cache, email, code); err != nil {
return err
}
// Increment user-level counter after successful save
if _, err := cache.IncrNotifyCodeUserRate(ctx, userID, notifyCodeUserRateWindow); err != nil {
slog.Error("failed to increment notify code user rate", "user_id", userID, "error", err)
}
return nil
}
// checkNotifyCodeRateLimit checks both email cooldown and user-level rate limit.
func checkNotifyCodeRateLimit(ctx context.Context, cache EmailCache, userID int64, email string) error {
existing, err := cache.GetNotifyVerifyCode(ctx, email)
if err == nil && existing != nil {
if time.Since(existing.CreatedAt) < verifyCodeCooldown {
return ErrVerifyCodeTooFrequent
}
}
count, err := cache.GetNotifyCodeUserRate(ctx, userID)
if err == nil && count >= notifyCodeUserRateLimit {
return ErrNotifyCodeUserRateLimit
}
return nil
}
// saveNotifyVerifyCode saves the verification code to cache.
func saveNotifyVerifyCode(ctx context.Context, cache EmailCache, email, code string) error {
data := &VerificationCodeData{
Code: code,
Attempts: 0,
CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(verifyCodeTTL),
}
if err := cache.SetNotifyVerifyCode(ctx, email, data, verifyCodeTTL); err != nil {
return fmt.Errorf("save verify code: %w", err)
}
return nil
}
// sendNotifyVerifyEmail builds and sends the verification email.
func (s *UserService) sendNotifyVerifyEmail(ctx context.Context, emailService *EmailService, email, code string) error {
siteName := "Sub2API"
if s.settingRepo != nil {
if name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName); err == nil && name != "" {
siteName = name
}
}
subject := fmt.Sprintf("[%s] 通知邮箱验证码 / Notification Email Verification", siteName)
body := buildNotifyVerifyEmailBody(code, siteName)
return emailService.SendEmail(ctx, email, subject, body)
}
// VerifyAndAddNotifyEmail verifies the code and adds the email to user's extra emails.
func (s *UserService) VerifyAndAddNotifyEmail(ctx context.Context, userID int64, email, code string, cache EmailCache) error {
if err := verifyNotifyCode(ctx, cache, email, code); err != nil {
return err
}
_ = cache.DeleteNotifyVerifyCode(ctx, email)
return s.addOrVerifyNotifyEmail(ctx, userID, email)
}
// verifyNotifyCode validates the verification code against the cached data.
func verifyNotifyCode(ctx context.Context, cache EmailCache, email, code string) error {
data, err := cache.GetNotifyVerifyCode(ctx, email)
if err != nil || data == nil {
return ErrInvalidVerifyCode
}
if data.Attempts >= maxVerifyCodeAttempts {
return ErrVerifyCodeMaxAttempts
}
if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 {
data.Attempts++
remaining := time.Until(data.ExpiresAt)
if remaining <= 0 {
return ErrInvalidVerifyCode
}
if err := cache.SetNotifyVerifyCode(ctx, email, data, remaining); err != nil {
slog.Error("failed to update notify verify code attempts", "email", email, "error", err)
}
if data.Attempts >= maxVerifyCodeAttempts {
return ErrVerifyCodeMaxAttempts
}
return ErrInvalidVerifyCode
}
return nil
}
// addOrVerifyNotifyEmail adds the email to user's extra notification emails or marks it as verified.
// Note: concurrent calls for the same user could race on the read-modify-write of
// BalanceNotifyExtraEmails. The window is small (requires two verify flows completing
// simultaneously), and the worst case is a duplicate entry which is harmless.
func (s *UserService) addOrVerifyNotifyEmail(ctx context.Context, userID int64, email string) error {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return err
}
for i, e := range user.BalanceNotifyExtraEmails {
if strings.EqualFold(e.Email, email) {
if !e.Verified {
user.BalanceNotifyExtraEmails[i].Verified = true
return s.userRepo.Update(ctx, user)
}
return nil // Already verified
}
}
if len(user.BalanceNotifyExtraEmails) >= maxNotifyEmails {
return infraerrors.BadRequest("TOO_MANY_NOTIFY_EMAILS", fmt.Sprintf("maximum %d notification emails allowed", maxNotifyEmails))
}
user.BalanceNotifyExtraEmails = append(user.BalanceNotifyExtraEmails, NotifyEmailEntry{
Email: email,
Disabled: false,
Verified: true,
})
return s.userRepo.Update(ctx, user)
}
// RemoveNotifyEmail removes an email from user's extra notification emails.
func (s *UserService) RemoveNotifyEmail(ctx context.Context, userID int64, email string) error {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return err
}
filtered := make([]NotifyEmailEntry, 0, len(user.BalanceNotifyExtraEmails))
found := false
for _, e := range user.BalanceNotifyExtraEmails {
if strings.EqualFold(e.Email, email) {
found = true
} else {
filtered = append(filtered, e)
}
}
if !found {
return infraerrors.BadRequest("EMAIL_NOT_FOUND", "notification email not found")
}
user.BalanceNotifyExtraEmails = filtered
return s.userRepo.Update(ctx, user)
}
// ToggleNotifyEmail toggles the disabled state of a notification email entry.
func (s *UserService) ToggleNotifyEmail(ctx context.Context, userID int64, email string, disabled bool) error {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return err
}
found := false
for i, e := range user.BalanceNotifyExtraEmails {
if strings.EqualFold(e.Email, email) {
user.BalanceNotifyExtraEmails[i].Disabled = disabled
found = true
break
}
}
if !found {
return infraerrors.BadRequest("EMAIL_NOT_FOUND", "notification email not found")
}
return s.userRepo.Update(ctx, user)
}
// notifyVerifyEmailTemplate is the HTML template for notify email verification.
// Format args: siteName, code.
const notifyVerifyEmailTemplate = `<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 600px; margin: 0 auto; background-color: #ffffff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
.header { background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 30px; text-align: center; }
.header h1 { margin: 0; font-size: 24px; }
.content { padding: 40px 30px; text-align: center; }
.code { font-size: 36px; font-weight: bold; letter-spacing: 8px; color: #333; background-color: #f8f9fa; padding: 20px 30px; border-radius: 8px; display: inline-block; margin: 20px 0; font-family: monospace; }
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; }
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>%s</h1>
</div>
<div class="content">
<p style="font-size: 18px; color: #333;">通知邮箱验证码 / Notification Email Verification</p>
<div class="code">%s</div>
<div class="info">
<p>您正在添加额外的通知邮箱,请输入此验证码完成验证。</p>
<p>You are adding an extra notification email. Please enter this code to verify.</p>
<p>此验证码将在 <strong>15 分钟</strong>后失效。</p>
<p>This code will expire in <strong>15 minutes</strong>.</p>
<p>如果您没有请求此验证码,请忽略此邮件。</p>
<p>If you did not request this code, please ignore this email.</p>
</div>
</div>
<div class="footer">
<p>此邮件由系统自动发送,请勿回复。/ This is an automated message, please do not reply.</p>
</div>
</div>
</body>
</html>`
// buildNotifyVerifyEmailBody builds the HTML email body for notify email verification.
func buildNotifyVerifyEmailBody(code, siteName string) string {
return fmt.Sprintf(notifyVerifyEmailTemplate, siteName, code)
}
...@@ -46,12 +46,12 @@ func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int ...@@ -46,12 +46,12 @@ func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int
return 0, nil return 0, nil
} }
func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil } func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil }
func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil }
func (m *mockUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { func (m *mockUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
return nil return nil
} }
func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil }
func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil }
// --- mock: APIKeyAuthCacheInvalidator --- // --- mock: APIKeyAuthCacheInvalidator ---
...@@ -117,7 +117,7 @@ func (m *mockBillingCache) InvalidateAPIKeyRateLimit(context.Context, int64) err ...@@ -117,7 +117,7 @@ func (m *mockBillingCache) InvalidateAPIKeyRateLimit(context.Context, int64) err
func TestUpdateBalance_Success(t *testing.T) { func TestUpdateBalance_Success(t *testing.T) {
repo := &mockUserRepo{} repo := &mockUserRepo{}
cache := &mockBillingCache{} cache := &mockBillingCache{}
svc := NewUserService(repo, nil, cache) svc := NewUserService(repo, nil, nil, cache)
err := svc.UpdateBalance(context.Background(), 42, 100.0) err := svc.UpdateBalance(context.Background(), 42, 100.0)
require.NoError(t, err) require.NoError(t, err)
...@@ -134,7 +134,7 @@ func TestUpdateBalance_Success(t *testing.T) { ...@@ -134,7 +134,7 @@ func TestUpdateBalance_Success(t *testing.T) {
func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) { func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) {
repo := &mockUserRepo{} repo := &mockUserRepo{}
svc := NewUserService(repo, nil, nil) // billingCache = nil svc := NewUserService(repo, nil, nil, nil) // billingCache = nil
err := svc.UpdateBalance(context.Background(), 1, 50.0) err := svc.UpdateBalance(context.Background(), 1, 50.0)
require.NoError(t, err, "billingCache 为 nil 时不应 panic") require.NoError(t, err, "billingCache 为 nil 时不应 panic")
...@@ -143,7 +143,7 @@ func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) { ...@@ -143,7 +143,7 @@ func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) {
func TestUpdateBalance_CacheFailure_DoesNotAffectReturn(t *testing.T) { func TestUpdateBalance_CacheFailure_DoesNotAffectReturn(t *testing.T) {
repo := &mockUserRepo{} repo := &mockUserRepo{}
cache := &mockBillingCache{invalidateErr: errors.New("redis connection refused")} cache := &mockBillingCache{invalidateErr: errors.New("redis connection refused")}
svc := NewUserService(repo, nil, cache) svc := NewUserService(repo, nil, nil, cache)
err := svc.UpdateBalance(context.Background(), 99, 200.0) err := svc.UpdateBalance(context.Background(), 99, 200.0)
require.NoError(t, err, "缓存失效失败不应影响主流程返回值") require.NoError(t, err, "缓存失效失败不应影响主流程返回值")
...@@ -157,7 +157,7 @@ func TestUpdateBalance_CacheFailure_DoesNotAffectReturn(t *testing.T) { ...@@ -157,7 +157,7 @@ func TestUpdateBalance_CacheFailure_DoesNotAffectReturn(t *testing.T) {
func TestUpdateBalance_RepoError_ReturnsError(t *testing.T) { func TestUpdateBalance_RepoError_ReturnsError(t *testing.T) {
repo := &mockUserRepo{updateBalanceErr: errors.New("database error")} repo := &mockUserRepo{updateBalanceErr: errors.New("database error")}
cache := &mockBillingCache{} cache := &mockBillingCache{}
svc := NewUserService(repo, nil, cache) svc := NewUserService(repo, nil, nil, cache)
err := svc.UpdateBalance(context.Background(), 1, 100.0) err := svc.UpdateBalance(context.Background(), 1, 100.0)
require.Error(t, err, "repo 失败时应返回错误") require.Error(t, err, "repo 失败时应返回错误")
...@@ -173,7 +173,7 @@ func TestUpdateBalance_WithAuthCacheInvalidator(t *testing.T) { ...@@ -173,7 +173,7 @@ func TestUpdateBalance_WithAuthCacheInvalidator(t *testing.T) {
repo := &mockUserRepo{} repo := &mockUserRepo{}
auth := &mockAuthCacheInvalidator{} auth := &mockAuthCacheInvalidator{}
cache := &mockBillingCache{} cache := &mockBillingCache{}
svc := NewUserService(repo, auth, cache) svc := NewUserService(repo, nil, auth, cache)
err := svc.UpdateBalance(context.Background(), 77, 300.0) err := svc.UpdateBalance(context.Background(), 77, 300.0)
require.NoError(t, err) require.NoError(t, err)
...@@ -194,7 +194,7 @@ func TestNewUserService_FieldsAssignment(t *testing.T) { ...@@ -194,7 +194,7 @@ func TestNewUserService_FieldsAssignment(t *testing.T) {
auth := &mockAuthCacheInvalidator{} auth := &mockAuthCacheInvalidator{}
cache := &mockBillingCache{} cache := &mockBillingCache{}
svc := NewUserService(repo, auth, cache) svc := NewUserService(repo, nil, auth, cache)
require.NotNil(t, svc) require.NotNil(t, svc)
require.Equal(t, repo, svc.userRepo) require.Equal(t, repo, svc.userRepo)
require.Equal(t, auth, svc.authCacheInvalidator) require.Equal(t, auth, svc.authCacheInvalidator)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment