Unverified Commit 51af8df3 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #1731 from touwaeriol/fix/rate-billing-autofill-response-limit

fix: subscription billing, alipay redirect + H5, payment secrets, 128MB response limit
parents 6c73b621 235f7108
...@@ -52,6 +52,11 @@ const ( ...@@ -52,6 +52,11 @@ const (
ConnectionPoolIsolationAccountProxy = "account_proxy" ConnectionPoolIsolationAccountProxy = "account_proxy"
) )
// DefaultUpstreamResponseReadMaxBytes 上游非流式响应体的默认读取上限。
// 128 MB 可容纳 2-3 张 4K PNG(base64 膨胀 33%,单张 4K PNG 最坏约 67MB base64)。
// 可通过 gateway.upstream_response_read_max_bytes 配置项覆盖。
const DefaultUpstreamResponseReadMaxBytes int64 = 128 * 1024 * 1024
type Config struct { type Config struct {
Server ServerConfig `mapstructure:"server"` Server ServerConfig `mapstructure:"server"`
Log LogConfig `mapstructure:"log"` Log LogConfig `mapstructure:"log"`
...@@ -1407,7 +1412,7 @@ func setDefaults() { ...@@ -1407,7 +1412,7 @@ func setDefaults() {
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.antigravity_extra_retries", 10) viper.SetDefault("gateway.antigravity_extra_retries", 10)
viper.SetDefault("gateway.max_body_size", int64(256*1024*1024)) viper.SetDefault("gateway.max_body_size", int64(256*1024*1024))
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024)) viper.SetDefault("gateway.upstream_response_read_max_bytes", DefaultUpstreamResponseReadMaxBytes)
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024)) viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
viper.SetDefault("gateway.gemini_debug_response_headers", false) viper.SetDefault("gateway.gemini_debug_response_headers", false)
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
......
...@@ -206,6 +206,10 @@ type CreateOrderRequest struct { ...@@ -206,6 +206,10 @@ type CreateOrderRequest struct {
PaymentType string `json:"payment_type" binding:"required"` PaymentType string `json:"payment_type" binding:"required"`
OrderType string `json:"order_type"` OrderType string `json:"order_type"`
PlanID int64 `json:"plan_id"` PlanID int64 `json:"plan_id"`
// IsMobile lets the frontend declare its mobile status directly. When
// nil we fall back to User-Agent heuristics (which miss iPadOS / some
// embedded browsers that strip the "Mobile" keyword).
IsMobile *bool `json:"is_mobile,omitempty"`
} }
// CreateOrder creates a new payment order. // CreateOrder creates a new payment order.
...@@ -222,12 +226,16 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) { ...@@ -222,12 +226,16 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) {
return return
} }
mobile := isMobile(c)
if req.IsMobile != nil {
mobile = *req.IsMobile
}
result, err := h.paymentService.CreateOrder(c.Request.Context(), service.CreateOrderRequest{ result, err := h.paymentService.CreateOrder(c.Request.Context(), service.CreateOrderRequest{
UserID: subject.UserID, UserID: subject.UserID,
Amount: req.Amount, Amount: req.Amount,
PaymentType: req.PaymentType, PaymentType: req.PaymentType,
ClientIP: c.ClientIP(), ClientIP: c.ClientIP(),
IsMobile: isMobile(c), IsMobile: mobile,
SrcHost: c.Request.Host, SrcHost: c.Request.Host,
SrcURL: c.Request.Referer(), SrcURL: c.Request.Referer(),
OrderType: req.OrderType, OrderType: req.OrderType,
......
...@@ -10,12 +10,20 @@ import ( ...@@ -10,12 +10,20 @@ import (
"strings" "strings"
) )
// AES256KeySize is the required key length (in bytes) for AES-256-GCM.
const AES256KeySize = 32
// Encrypt encrypts plaintext using AES-256-GCM with the given 32-byte key. // Encrypt encrypts plaintext using AES-256-GCM with the given 32-byte key.
// The output format is "iv:authTag:ciphertext" where each component is base64-encoded, // The output format is "iv:authTag:ciphertext" where each component is base64-encoded,
// matching the Node.js crypto.ts format for cross-compatibility. // matching the Node.js crypto.ts format for cross-compatibility.
//
// Deprecated: payment provider configs are now stored as plaintext JSON.
// This function is kept only for seeding legacy ciphertext in tests and for
// the transitional Decrypt fallback. Scheduled for removal after all live
// deployments complete migration by re-saving their configs.
func Encrypt(plaintext string, key []byte) (string, error) { func Encrypt(plaintext string, key []byte) (string, error) {
if len(key) != 32 { if len(key) != AES256KeySize {
return "", fmt.Errorf("encryption key must be 32 bytes, got %d", len(key)) return "", fmt.Errorf("encryption key must be %d bytes, got %d", AES256KeySize, len(key))
} }
block, err := aes.NewCipher(key) block, err := aes.NewCipher(key)
...@@ -51,9 +59,14 @@ func Encrypt(plaintext string, key []byte) (string, error) { ...@@ -51,9 +59,14 @@ func Encrypt(plaintext string, key []byte) (string, error) {
// Decrypt decrypts a ciphertext string produced by Encrypt. // Decrypt decrypts a ciphertext string produced by Encrypt.
// The input format is "iv:authTag:ciphertext" where each component is base64-encoded. // The input format is "iv:authTag:ciphertext" where each component is base64-encoded.
//
// Deprecated: payment provider configs are now stored as plaintext JSON.
// This function remains only as a read-path fallback for pre-migration
// ciphertext records. Scheduled for removal once all deployments re-save
// their provider configs through the admin UI.
func Decrypt(ciphertext string, key []byte) (string, error) { func Decrypt(ciphertext string, key []byte) (string, error) {
if len(key) != 32 { if len(key) != AES256KeySize {
return "", fmt.Errorf("encryption key must be 32 bytes, got %d", len(key)) return "", fmt.Errorf("encryption key must be %d bytes, got %d", AES256KeySize, len(key))
} }
parts := strings.SplitN(ciphertext, ":", 3) parts := strings.SplitN(ciphertext, ":", 3)
......
...@@ -261,6 +261,9 @@ func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderIns ...@@ -261,6 +261,9 @@ func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderIns
if err != nil { if err != nil {
return nil, fmt.Errorf("decrypt instance %d config: %w", selected.ID, err) return nil, fmt.Errorf("decrypt instance %d config: %w", selected.ID, err)
} }
if config == nil {
config = map[string]string{}
}
if selected.PaymentMode != "" { if selected.PaymentMode != "" {
config["paymentMode"] = selected.PaymentMode config["paymentMode"] = selected.PaymentMode
...@@ -275,16 +278,36 @@ func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderIns ...@@ -275,16 +278,36 @@ func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderIns
}, nil }, nil
} }
func (lb *DefaultLoadBalancer) decryptConfig(encrypted string) (map[string]string, error) { // decryptConfig parses a stored provider config.
plaintext, err := Decrypt(encrypted, lb.encryptionKey) // New records are plaintext JSON; legacy records are AES-256-GCM ciphertext.
if err != nil { // Unreadable values (legacy ciphertext without a valid key, or malformed data)
return nil, err // are treated as empty so the service keeps running while the admin re-enters
// the config via the UI.
//
// TODO(deprecated-legacy-ciphertext): The AES fallback branch below is a
// transitional compatibility shim for pre-plaintext records. Remove it (and
// the encryptionKey field + the Decrypt import) after a few releases once all
// live deployments have re-saved their provider configs through the UI.
func (lb *DefaultLoadBalancer) decryptConfig(stored string) (map[string]string, error) {
if stored == "" {
return nil, nil
} }
var config map[string]string var config map[string]string
if err := json.Unmarshal([]byte(plaintext), &config); err != nil { if err := json.Unmarshal([]byte(stored), &config); err == nil {
return nil, fmt.Errorf("unmarshal config: %w", err) return config, nil
}
// Deprecated: legacy AES-256-GCM ciphertext fallback — scheduled for removal.
if len(lb.encryptionKey) == AES256KeySize {
//nolint:staticcheck // SA1019: intentional legacy fallback, scheduled for removal
if plaintext, err := Decrypt(stored, lb.encryptionKey); err == nil {
if err := json.Unmarshal([]byte(plaintext), &config); err == nil {
return config, nil
}
}
} }
return config, nil slog.Warn("payment provider config unreadable, treating as empty for re-entry",
"stored_len", len(stored))
return nil, nil
} }
// GetInstanceDailyAmount returns the total completed order amount for an instance today. // GetInstanceDailyAmount returns the total completed order amount for an instance today.
......
...@@ -452,6 +452,103 @@ func TestStartOfDay(t *testing.T) { ...@@ -452,6 +452,103 @@ func TestStartOfDay(t *testing.T) {
} }
} }
func TestDecryptConfig_PlaintextAndLegacyCompat(t *testing.T) {
t.Parallel()
key := make([]byte, AES256KeySize)
for i := range key {
key[i] = byte(i + 1)
}
wrongKey := make([]byte, AES256KeySize)
for i := range wrongKey {
wrongKey[i] = byte(0xFF - i)
}
plaintextJSON := `{"appId":"app-123","secret":"sec-xyz"}`
legacyEncrypted, err := Encrypt(plaintextJSON, key)
if err != nil {
t.Fatalf("seed Encrypt: %v", err)
}
tests := []struct {
name string
stored string
key []byte
want map[string]string
}{
{
name: "empty stored returns nil map",
stored: "",
key: key,
want: nil,
},
{
name: "plaintext JSON parses directly",
stored: plaintextJSON,
key: nil,
want: map[string]string{"appId": "app-123", "secret": "sec-xyz"},
},
{
name: "plaintext JSON works even with key present",
stored: plaintextJSON,
key: key,
want: map[string]string{"appId": "app-123", "secret": "sec-xyz"},
},
{
name: "legacy ciphertext with correct key decrypts",
stored: legacyEncrypted,
key: key,
want: map[string]string{"appId": "app-123", "secret": "sec-xyz"},
},
{
name: "legacy ciphertext with no key treated as empty",
stored: legacyEncrypted,
key: nil,
want: nil,
},
{
name: "legacy ciphertext with wrong key treated as empty",
stored: legacyEncrypted,
key: wrongKey,
want: nil,
},
{
name: "garbage data treated as empty",
stored: "not-json-and-not-ciphertext",
key: key,
want: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
lb := NewDefaultLoadBalancer(nil, tt.key)
got, err := lb.decryptConfig(tt.stored)
if err != nil {
t.Fatalf("decryptConfig unexpected error: %v", err)
}
if !stringMapEqual(got, tt.want) {
t.Fatalf("decryptConfig = %v, want %v", got, tt.want)
}
})
}
}
// stringMapEqual compares two map[string]string values; nil and empty are equal.
func stringMapEqual(a, b map[string]string) bool {
if len(a) != len(b) {
return false
}
for k, v := range a {
if bv, ok := b[k]; !ok || bv != v {
return false
}
}
return true
}
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Helpers // Helpers
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
......
...@@ -15,8 +15,8 @@ import ( ...@@ -15,8 +15,8 @@ import (
// Alipay product codes. // Alipay product codes.
const ( const (
alipayProductCodePagePay = "FAST_INSTANT_TRADE_PAY"
alipayProductCodeWapPay = "QUICK_WAP_WAY" alipayProductCodeWapPay = "QUICK_WAP_WAY"
alipayProductCodePagePay = "FAST_INSTANT_TRADE_PAY"
) )
// Alipay response constants. // Alipay response constants.
...@@ -79,7 +79,12 @@ func (a *Alipay) SupportedTypes() []payment.PaymentType { ...@@ -79,7 +79,12 @@ func (a *Alipay) SupportedTypes() []payment.PaymentType {
return []payment.PaymentType{payment.TypeAlipay} return []payment.PaymentType{payment.TypeAlipay}
} }
// CreatePayment creates an Alipay payment page URL. // CreatePayment creates an Alipay payment using redirect-only flow:
// - Mobile (H5): alipay.trade.wap.pay — returns a URL the browser jumps to.
// - PC: alipay.trade.page.pay — returns a gateway URL the browser opens in a
// new window; Alipay's own page then shows login/QR. We intentionally do
// NOT encode the URL into a QR on the client (it isn't a scannable payload
// and would produce an invalid scan result).
func (a *Alipay) CreatePayment(_ context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { func (a *Alipay) CreatePayment(_ context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
client, err := a.getClient() client, err := a.getClient()
if err != nil { if err != nil {
...@@ -96,31 +101,31 @@ func (a *Alipay) CreatePayment(_ context.Context, req payment.CreatePaymentReque ...@@ -96,31 +101,31 @@ func (a *Alipay) CreatePayment(_ context.Context, req payment.CreatePaymentReque
} }
if req.IsMobile { if req.IsMobile {
return a.createTrade(client, req, notifyURL, returnURL, true) return a.createWapTrade(client, req, notifyURL, returnURL)
} }
return a.createTrade(client, req, notifyURL, returnURL, false) return a.createPagePayTrade(client, req, notifyURL, returnURL)
} }
func (a *Alipay) createTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string, isMobile bool) (*payment.CreatePaymentResponse, error) { func (a *Alipay) createWapTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) {
if isMobile { param := alipay.TradeWapPay{}
param := alipay.TradeWapPay{} param.OutTradeNo = req.OrderID
param.OutTradeNo = req.OrderID param.TotalAmount = req.Amount
param.TotalAmount = req.Amount param.Subject = req.Subject
param.Subject = req.Subject param.ProductCode = alipayProductCodeWapPay
param.ProductCode = alipayProductCodeWapPay param.NotifyURL = notifyURL
param.NotifyURL = notifyURL param.ReturnURL = returnURL
param.ReturnURL = returnURL
payURL, err := client.TradeWapPay(param)
payURL, err := client.TradeWapPay(param) if err != nil {
if err != nil { return nil, fmt.Errorf("alipay TradeWapPay: %w", err)
return nil, fmt.Errorf("alipay TradeWapPay: %w", err)
}
return &payment.CreatePaymentResponse{
TradeNo: req.OrderID,
PayURL: payURL.String(),
}, nil
} }
return &payment.CreatePaymentResponse{
TradeNo: req.OrderID,
PayURL: payURL.String(),
}, nil
}
func (a *Alipay) createPagePayTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) {
param := alipay.TradePagePay{} param := alipay.TradePagePay{}
param.OutTradeNo = req.OrderID param.OutTradeNo = req.OrderID
param.TotalAmount = req.Amount param.TotalAmount = req.Amount
...@@ -136,7 +141,6 @@ func (a *Alipay) createTrade(client *alipay.Client, req payment.CreatePaymentReq ...@@ -136,7 +141,6 @@ func (a *Alipay) createTrade(client *alipay.Client, req payment.CreatePaymentReq
return &payment.CreatePaymentResponse{ return &payment.CreatePaymentResponse{
TradeNo: req.OrderID, TradeNo: req.OrderID,
PayURL: payURL.String(), PayURL: payURL.String(),
QRCode: payURL.String(),
}, nil }, nil
} }
......
...@@ -586,6 +586,15 @@ func (s *adminServiceImpl) assignDefaultSubscriptions(ctx context.Context, userI ...@@ -586,6 +586,15 @@ func (s *adminServiceImpl) assignDefaultSubscriptions(ctx context.Context, userI
} }
func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) { func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) {
// 校验用户专属分组倍率:必须 > 0(nil 合法,表示清除专属倍率)
if input.GroupRates != nil {
for groupID, rate := range input.GroupRates {
if rate != nil && *rate <= 0 {
return nil, fmt.Errorf("rate_multiplier must be > 0 (group_id=%d)", groupID)
}
}
}
user, err := s.userRepo.GetByID(ctx, id) user, err := s.userRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -811,6 +820,10 @@ func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, erro ...@@ -811,6 +820,10 @@ func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, erro
} }
func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) { func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) {
if input.RateMultiplier <= 0 {
return nil, errors.New("rate_multiplier must be > 0")
}
platform := input.Platform platform := input.Platform
if platform == "" { if platform == "" {
platform = PlatformAnthropic platform = PlatformAnthropic
...@@ -1050,6 +1063,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd ...@@ -1050,6 +1063,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
group.Platform = input.Platform group.Platform = input.Platform
} }
if input.RateMultiplier != nil { if input.RateMultiplier != nil {
if *input.RateMultiplier <= 0 {
return nil, errors.New("rate_multiplier must be > 0")
}
group.RateMultiplier = *input.RateMultiplier group.RateMultiplier = *input.RateMultiplier
} }
if input.IsExclusive != nil { if input.IsExclusive != nil {
...@@ -1286,6 +1302,11 @@ func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, gro ...@@ -1286,6 +1302,11 @@ func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, gro
if s.userGroupRateRepo == nil { if s.userGroupRateRepo == nil {
return nil return nil
} }
for _, e := range entries {
if e.RateMultiplier <= 0 {
return fmt.Errorf("rate_multiplier must be > 0 (user_id=%d)", e.UserID)
}
}
return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries) return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries)
} }
......
...@@ -621,6 +621,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatfo ...@@ -621,6 +621,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatfo
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1", Name: "g1",
Platform: PlatformOpenAI, Platform: PlatformOpenAI,
RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard, SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID, FallbackGroupIDOnInvalidRequest: &fallbackID,
}) })
...@@ -641,6 +642,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription(t *t ...@@ -641,6 +642,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription(t *t
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1", Name: "g1",
Platform: PlatformAnthropic, Platform: PlatformAnthropic,
RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeSubscription, SubscriptionType: SubscriptionTypeSubscription,
FallbackGroupIDOnInvalidRequest: &fallbackID, FallbackGroupIDOnInvalidRequest: &fallbackID,
}) })
...@@ -695,6 +697,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup(t * ...@@ -695,6 +697,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1", Name: "g1",
Platform: PlatformAnthropic, Platform: PlatformAnthropic,
RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard, SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID, FallbackGroupIDOnInvalidRequest: &fallbackID,
}) })
...@@ -713,6 +716,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackNotFound(t *testing.T) { ...@@ -713,6 +716,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackNotFound(t *testing.T) {
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1", Name: "g1",
Platform: PlatformAnthropic, Platform: PlatformAnthropic,
RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard, SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID, FallbackGroupIDOnInvalidRequest: &fallbackID,
}) })
...@@ -733,6 +737,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes ...@@ -733,6 +737,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1", Name: "g1",
Platform: PlatformAntigravity, Platform: PlatformAntigravity,
RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard, SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID, FallbackGroupIDOnInvalidRequest: &fallbackID,
}) })
...@@ -750,6 +755,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero(t *testing. ...@@ -750,6 +755,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero(t *testing.
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1", Name: "g1",
Platform: PlatformAnthropic, Platform: PlatformAnthropic,
RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard, SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &zero, FallbackGroupIDOnInvalidRequest: &zero,
}) })
......
...@@ -448,8 +448,9 @@ func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown, ...@@ -448,8 +448,9 @@ func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown,
}) })
} }
if input.RateMultiplier <= 0 { // 保存时强制 > 0;若仍有负数泄漏(缓存/迁移残留),按 0 处理避免按 1x 误扣。
input.RateMultiplier = 1.0 if input.RateMultiplier < 0 {
input.RateMultiplier = 0
} }
var breakdown *CostBreakdown var breakdown *CostBreakdown
...@@ -493,8 +494,9 @@ func (s *BillingService) computeTokenBreakdown( ...@@ -493,8 +494,9 @@ func (s *BillingService) computeTokenBreakdown(
rateMultiplier float64, serviceTier string, rateMultiplier float64, serviceTier string,
applyLongCtx bool, applyLongCtx bool,
) *CostBreakdown { ) *CostBreakdown {
if rateMultiplier <= 0 { // 保存时强制 > 0;若仍有负数泄漏,按 0 处理避免按 1x 误扣。
rateMultiplier = 1.0 if rateMultiplier < 0 {
rateMultiplier = 0
} }
inputPrice := pricing.InputPricePerToken inputPrice := pricing.InputPricePerToken
...@@ -831,9 +833,9 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag ...@@ -831,9 +833,9 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag
// 计算总费用 // 计算总费用
totalCost := unitPrice * float64(imageCount) totalCost := unitPrice * float64(imageCount)
// 应用倍率 // 应用倍率(保存时强制 > 0;负数按 0 处理避免按 1x 误扣)
if rateMultiplier <= 0 { if rateMultiplier < 0 {
rateMultiplier = 1.0 rateMultiplier = 0
} }
actualCost := totalCost * rateMultiplier actualCost := totalCost * rateMultiplier
......
...@@ -90,13 +90,14 @@ func TestCalculateImageCost_NegativeCount(t *testing.T) { ...@@ -90,13 +90,14 @@ func TestCalculateImageCost_NegativeCount(t *testing.T) {
require.Equal(t, 0.0, cost.ActualCost) require.Equal(t, 0.0, cost.ActualCost)
} }
// TestCalculateImageCost_ZeroRateMultiplier 测试费率倍数为 0 时默认使用 1.0 // TestCalculateImageCost_ZeroRateMultiplier 锁定新行为:倍率 0 直接按 0 计费
// (保存时已强制 > 0;若仍有 0 泄漏到计费层,零消耗比历史的 1.0 更安全)。
func TestCalculateImageCost_ZeroRateMultiplier(t *testing.T) { func TestCalculateImageCost_ZeroRateMultiplier(t *testing.T) {
svc := &BillingService{} svc := &BillingService{}
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 0) cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 0)
require.InDelta(t, 0.201, cost.TotalCost, 0.0001) require.InDelta(t, 0.201, cost.TotalCost, 0.0001)
require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // 0 倍率当作 1.0 处理 require.InDelta(t, 0.0, cost.ActualCost, 1e-10)
} }
// TestGetImageUnitPrice_GroupPriorityOverDefault 测试分组价格优先于默认价格 // TestGetImageUnitPrice_GroupPriorityOverDefault 测试分组价格优先于默认价格
......
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// TestCalculateCost_RateMultiplier_NegativeClampedToZero 锁定负数倍率被
// 钳制为 0(而非历史上的 1.0),避免配置异常导致静默按标准价扣费。
func TestCalculateCost_RateMultiplier_NegativeClampedToZero(t *testing.T) {
svc := newTestBillingService()
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
tests := []struct {
name string
multiplier float64
wantRatio float64 // ActualCost / TotalCost
}{
{"negative clamped to 0", -1.5, 0},
{"zero passes through as 0 (defense in depth)", 0, 0},
{"positive 2x applied", 2.0, 2.0},
{"positive 0.5x applied", 0.5, 0.5},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cost, err := svc.CalculateCost("claude-sonnet-4", tokens, tt.multiplier)
require.NoError(t, err)
require.Greater(t, cost.TotalCost, 0.0, "TotalCost should be non-zero")
require.InDelta(t, tt.wantRatio*cost.TotalCost, cost.ActualCost, 1e-9)
})
}
}
// TestCalculateImageCost_RateMultiplier_NegativeClampedToZero 图片按次计费路径
// 同样遵循"负数 → 0"语义。
func TestCalculateImageCost_RateMultiplier_NegativeClampedToZero(t *testing.T) {
svc := newTestBillingService()
price := 0.04
cfg := &ImagePriceConfig{Price1K: &price}
tests := []struct {
name string
multiplier float64
wantRatio float64
}{
{"negative clamped to 0", -0.5, 0},
{"zero passes through", 0, 0},
{"positive 3x applied", 3.0, 3.0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cost := svc.CalculateImageCost("imagen-3", "1K", 2, cfg, tt.multiplier)
require.NotNil(t, cost)
require.Greater(t, cost.TotalCost, 0.0)
require.InDelta(t, tt.wantRatio*cost.TotalCost, cost.ActualCost, 1e-9)
})
}
}
...@@ -71,34 +71,6 @@ func TestCalculateCost_RateMultiplier(t *testing.T) { ...@@ -71,34 +71,6 @@ func TestCalculateCost_RateMultiplier(t *testing.T) {
require.InDelta(t, cost1x.ActualCost*2, cost2x.ActualCost, 1e-10) require.InDelta(t, cost1x.ActualCost*2, cost2x.ActualCost, 1e-10)
} }
func TestCalculateCost_ZeroMultiplierDefaultsToOne(t *testing.T) {
svc := newTestBillingService()
tokens := UsageTokens{InputTokens: 1000}
costZero, err := svc.CalculateCost("claude-sonnet-4", tokens, 0)
require.NoError(t, err)
costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
require.NoError(t, err)
require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10)
}
func TestCalculateCost_NegativeMultiplierDefaultsToOne(t *testing.T) {
svc := newTestBillingService()
tokens := UsageTokens{InputTokens: 1000}
costNeg, err := svc.CalculateCost("claude-sonnet-4", tokens, -1.0)
require.NoError(t, err)
costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
require.NoError(t, err)
require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10)
}
func TestGetModelPricing_FallbackMatchesByFamily(t *testing.T) { func TestGetModelPricing_FallbackMatchesByFamily(t *testing.T) {
svc := newTestBillingService() svc := newTestBillingService()
......
...@@ -147,40 +147,35 @@ func TestCalculateCostUnified_ImageMode(t *testing.T) { ...@@ -147,40 +147,35 @@ func TestCalculateCostUnified_ImageMode(t *testing.T) {
require.Equal(t, string(BillingModeImage), cost.BillingMode) require.Equal(t, string(BillingModeImage), cost.BillingMode)
} }
func TestCalculateCostUnified_RateMultiplierZeroDefaultsToOne(t *testing.T) { // TestCalculateCostUnified_RateMultiplierZeroProducesZero 锁定新行为:
// 保存时强制 > 0;若 0 仍泄漏到计费层,按 0 计费(而非历史上的 1.0)。
func TestCalculateCostUnified_RateMultiplierZeroProducesZero(t *testing.T) {
bs := newTestBillingService() bs := newTestBillingService()
resolver := NewModelPricingResolver(nil, bs) resolver := NewModelPricingResolver(nil, bs)
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
costZero, err := bs.CalculateCostUnified(CostInput{ cost, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: tokens,
RateMultiplier: 0, // should default to 1.0
Resolver: resolver,
})
require.NoError(t, err)
costOne, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(), Ctx: context.Background(),
Model: "claude-sonnet-4", Model: "claude-sonnet-4",
Tokens: tokens, Tokens: tokens,
RateMultiplier: 1.0, RateMultiplier: 0,
Resolver: resolver, Resolver: resolver,
}) })
require.NoError(t, err) require.NoError(t, err)
require.Greater(t, cost.TotalCost, 0.0)
require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10) require.InDelta(t, 0.0, cost.ActualCost, 1e-10)
} }
func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T) { // TestCalculateCostUnified_NegativeRateMultiplierClampedToZero 锁定新行为:
// 负数倍率按 0 计费,避免历史的 <=0 → 1.0 把配置异常静默按标准价扣费。
func TestCalculateCostUnified_NegativeRateMultiplierClampedToZero(t *testing.T) {
bs := newTestBillingService() bs := newTestBillingService()
resolver := NewModelPricingResolver(nil, bs) resolver := NewModelPricingResolver(nil, bs)
tokens := UsageTokens{InputTokens: 1000} tokens := UsageTokens{InputTokens: 1000}
costNeg, err := bs.CalculateCostUnified(CostInput{ cost, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(), Ctx: context.Background(),
Model: "claude-sonnet-4", Model: "claude-sonnet-4",
Tokens: tokens, Tokens: tokens,
...@@ -188,17 +183,8 @@ func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T) ...@@ -188,17 +183,8 @@ func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T)
Resolver: resolver, Resolver: resolver,
}) })
require.NoError(t, err) require.NoError(t, err)
require.Greater(t, cost.TotalCost, 0.0)
costOne, err := bs.CalculateCostUnified(CostInput{ require.InDelta(t, 0.0, cost.ActualCost, 1e-10)
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: tokens,
RateMultiplier: 1.0,
Resolver: resolver,
})
require.NoError(t, err)
require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10)
} }
func TestCalculateCostUnified_BillingModeFieldFilled(t *testing.T) { func TestCalculateCostUnified_BillingModeFieldFilled(t *testing.T) {
......
...@@ -7317,8 +7317,10 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill ...@@ -7317,8 +7317,10 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
cost := p.Cost cost := p.Cost
if p.IsSubscriptionBill { if p.IsSubscriptionBill {
if cost.TotalCost > 0 { // Subscription usage tracked by ActualCost so group rate multiplier
if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil { // consumes the quota at the expected speed.
if cost.ActualCost > 0 {
if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.ActualCost); err != nil {
slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err) slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err)
} }
} }
...@@ -7417,9 +7419,13 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage ...@@ -7417,9 +7419,13 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
} }
} }
// Record subscription / balance cost using ActualCost so the group (and any
// user-specific) rate multiplier consumes subscription quota at the expected
// speed. TotalCost remains the raw (pre-multiplier) value; downstream guards
// on "> 0" still correctly skip free subscriptions (RateMultiplier == 0).
if p.IsSubscriptionBill && p.Subscription != nil && p.Cost.TotalCost > 0 { if p.IsSubscriptionBill && p.Subscription != nil && p.Cost.TotalCost > 0 {
cmd.SubscriptionID = &p.Subscription.ID cmd.SubscriptionID = &p.Subscription.ID
cmd.SubscriptionCost = p.Cost.TotalCost cmd.SubscriptionCost = p.Cost.ActualCost
} else if p.Cost.ActualCost > 0 { } else if p.Cost.ActualCost > 0 {
cmd.BalanceCost = p.Cost.ActualCost cmd.BalanceCost = p.Cost.ActualCost
} }
...@@ -7478,8 +7484,8 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, resu ...@@ -7478,8 +7484,8 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, resu
} }
if p.IsSubscriptionBill { if p.IsSubscriptionBill {
if p.Cost.TotalCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil { if p.Cost.ActualCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil {
deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.TotalCost) deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.ActualCost)
} }
} else if p.Cost.ActualCost > 0 && p.User != nil { } else if p.Cost.ActualCost > 0 && p.User != nil {
deps.billingCacheService.QueueDeductBalance(p.User.ID, p.Cost.ActualCost) deps.billingCacheService.QueueDeductBalance(p.User.ID, p.Cost.ActualCost)
......
//go:build unit
package service
import (
"testing"
)
// TestBuildUsageBillingCommand_SubscriptionAppliesRateMultiplier locks in the fix
// that subscription-mode billing honours the group (and any user-specific) rate
// multiplier — i.e. cmd.SubscriptionCost tracks ActualCost (= TotalCost *
// RateMultiplier), not raw TotalCost.
func TestBuildUsageBillingCommand_SubscriptionAppliesRateMultiplier(t *testing.T) {
t.Parallel()
groupID := int64(7)
subID := int64(42)
tests := []struct {
name string
totalCost float64
actualCost float64
isSubscription bool
wantSub float64
wantBalance float64
}{
{
name: "subscription with 2x multiplier consumes 2x quota",
totalCost: 1.0,
actualCost: 2.0,
isSubscription: true,
wantSub: 2.0,
wantBalance: 0,
},
{
name: "subscription with 0.5x multiplier consumes 0.5x quota",
totalCost: 1.0,
actualCost: 0.5,
isSubscription: true,
wantSub: 0.5,
wantBalance: 0,
},
{
name: "free subscription (multiplier 0) consumes no quota",
totalCost: 1.0,
actualCost: 0,
isSubscription: true,
wantSub: 0,
wantBalance: 0,
},
{
name: "balance billing keeps using ActualCost (regression)",
totalCost: 1.0,
actualCost: 2.0,
isSubscription: false,
wantSub: 0,
wantBalance: 2.0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
p := &postUsageBillingParams{
Cost: &CostBreakdown{TotalCost: tt.totalCost, ActualCost: tt.actualCost},
User: &User{ID: 1},
APIKey: &APIKey{ID: 2, GroupID: &groupID},
Account: &Account{ID: 3},
Subscription: &UserSubscription{ID: subID},
IsSubscriptionBill: tt.isSubscription,
}
cmd := buildUsageBillingCommand("req-1", nil, p)
if cmd == nil {
t.Fatal("buildUsageBillingCommand returned nil")
}
if cmd.SubscriptionCost != tt.wantSub {
t.Errorf("SubscriptionCost = %v, want %v", cmd.SubscriptionCost, tt.wantSub)
}
if cmd.BalanceCost != tt.wantBalance {
t.Errorf("BalanceCost = %v, want %v", cmd.BalanceCost, tt.wantBalance)
}
})
}
}
...@@ -76,10 +76,6 @@ func (g *Group) IsSubscriptionType() bool { ...@@ -76,10 +76,6 @@ func (g *Group) IsSubscriptionType() bool {
return g.SubscriptionType == SubscriptionTypeSubscription return g.SubscriptionType == SubscriptionTypeSubscription
} }
func (g *Group) IsFreeSubscription() bool {
return g.IsSubscriptionType() && g.RateMultiplier == 0
}
func (g *Group) HasDailyLimit() bool { func (g *Group) HasDailyLimit() bool {
return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0 return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0
} }
......
...@@ -1031,7 +1031,7 @@ func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFiel ...@@ -1031,7 +1031,7 @@ func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFiel
Model: "gpt-5.1", Model: "gpt-5.1",
Duration: time.Second, Duration: time.Second,
}, },
APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription}}, APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription, RateMultiplier: 1.0}},
User: &User{ID: 200}, User: &User{ID: 200},
Account: &Account{ID: 300}, Account: &Account{ID: 300},
Subscription: subscription, Subscription: subscription,
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log/slog"
"strconv" "strconv"
"strings" "strings"
...@@ -51,7 +52,7 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte ...@@ -51,7 +52,7 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte
AllowUserRefund: inst.AllowUserRefund, AllowUserRefund: inst.AllowUserRefund,
SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode, SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode,
} }
resp.Config, err = s.decryptAndMaskConfig(inst.Config) resp.Config, err = s.decryptAndMaskConfig(inst.ProviderKey, inst.Config)
if err != nil { if err != nil {
return nil, fmt.Errorf("decrypt config for instance %d: %w", inst.ID, err) return nil, fmt.Errorf("decrypt config for instance %d: %w", inst.ID, err)
} }
...@@ -60,8 +61,26 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte ...@@ -60,8 +61,26 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte
return result, nil return result, nil
} }
func (s *PaymentConfigService) decryptAndMaskConfig(encrypted string) (map[string]string, error) { // decryptAndMaskConfig returns the stored config with sensitive fields omitted.
return s.decryptConfig(encrypted) // Admin UIs display masked placeholders for these; the raw values never leave
// the server. Callers that need the full config (e.g. payment runtime) must
// use decryptConfig directly.
func (s *PaymentConfigService) decryptAndMaskConfig(providerKey, encrypted string) (map[string]string, error) {
cfg, err := s.decryptConfig(encrypted)
if err != nil {
return nil, err
}
if cfg == nil {
return nil, nil
}
masked := make(map[string]string, len(cfg))
for k, v := range cfg {
if isSensitiveProviderConfigField(providerKey, k) {
continue
}
masked[k] = v
}
return masked, nil
} }
// pendingOrderStatuses are order statuses considered "in progress". // pendingOrderStatuses are order statuses considered "in progress".
...@@ -71,16 +90,27 @@ var pendingOrderStatuses = []string{ ...@@ -71,16 +90,27 @@ var pendingOrderStatuses = []string{
payment.OrderStatusRecharging, payment.OrderStatusRecharging,
} }
var sensitiveConfigPatterns = []string{"key", "pkey", "secret", "private", "password"} // providerSensitiveConfigFields is the authoritative list of config keys that
// are treated as secrets per provider. Must stay in sync with the frontend
// definition at frontend/src/components/payment/providerConfig.ts
// (PROVIDER_CONFIG_FIELDS, fields with sensitive: true).
//
// Key matching is case-insensitive. Non-listed keys (e.g. appId, notifyUrl,
// stripe publishableKey) are returned in plaintext by the admin GET API.
var providerSensitiveConfigFields = map[string]map[string]struct{}{
payment.TypeEasyPay: {"pkey": {}},
payment.TypeAlipay: {"privatekey": {}, "publickey": {}, "alipaypublickey": {}},
payment.TypeWxpay: {"privatekey": {}, "apiv3key": {}, "publickey": {}},
payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}},
}
func isSensitiveConfigField(fieldName string) bool { func isSensitiveProviderConfigField(providerKey, fieldName string) bool {
lower := strings.ToLower(fieldName) fields, ok := providerSensitiveConfigFields[providerKey]
for _, p := range sensitiveConfigPatterns { if !ok {
if strings.Contains(lower, p) { return false
return true
}
} }
return false _, found := fields[strings.ToLower(fieldName)]
return found
} }
func (s *PaymentConfigService) countPendingOrders(ctx context.Context, providerInstanceID int64) (int, error) { func (s *PaymentConfigService) countPendingOrders(ctx context.Context, providerInstanceID int64) (int, error) {
...@@ -136,10 +166,26 @@ func validateProviderRequest(providerKey, name, supportedTypes string) error { ...@@ -136,10 +166,26 @@ func validateProviderRequest(providerKey, name, supportedTypes string) error {
// NOTE: This function exceeds 30 lines due to per-field nil-check patch update // NOTE: This function exceeds 30 lines due to per-field nil-check patch update
// boilerplate and pending-order safety checks. // boilerplate and pending-order safety checks.
func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id int64, req UpdateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) { func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id int64, req UpdateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) {
var cachedInst *dbent.PaymentProviderInstance
loadInst := func() (*dbent.PaymentProviderInstance, error) {
if cachedInst != nil {
return cachedInst, nil
}
inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
if err != nil {
return nil, fmt.Errorf("load provider instance: %w", err)
}
cachedInst = inst
return inst, nil
}
if req.Config != nil { if req.Config != nil {
inst, err := loadInst()
if err != nil {
return nil, err
}
hasSensitive := false hasSensitive := false
for k := range req.Config { for k, v := range req.Config {
if isSensitiveConfigField(k) && req.Config[k] != "" { if v != "" && isSensitiveProviderConfigField(inst.ProviderKey, k) {
hasSensitive = true hasSensitive = true
break break
} }
...@@ -282,27 +328,48 @@ func (s *PaymentConfigService) mergeConfig(ctx context.Context, id int64, newCon ...@@ -282,27 +328,48 @@ func (s *PaymentConfigService) mergeConfig(ctx context.Context, id int64, newCon
return nil, fmt.Errorf("decrypt existing config for instance %d: %w", id, err) return nil, fmt.Errorf("decrypt existing config for instance %d: %w", id, err)
} }
if existing == nil { if existing == nil {
return newConfig, nil existing = map[string]string{}
} }
for k, v := range newConfig { for k, v := range newConfig {
// Preserve existing secrets when the client submits an empty value
// (admin UI omits the value to indicate "leave unchanged").
if v == "" && isSensitiveProviderConfigField(inst.ProviderKey, k) {
continue
}
existing[k] = v existing[k] = v
} }
return existing, nil return existing, nil
} }
func (s *PaymentConfigService) decryptConfig(encrypted string) (map[string]string, error) { // decryptConfig parses a stored provider config.
if encrypted == "" { // New records are plaintext JSON; legacy records are AES-256-GCM ciphertext
// ("iv:authTag:ciphertext"). Values that cannot be parsed as either — including
// legacy ciphertext with no/invalid TOTP_ENCRYPTION_KEY — are treated as empty,
// letting the admin re-enter the config via the UI to complete the migration.
//
// TODO(deprecated-legacy-ciphertext): The AES fallback branch is a transitional
// shim for pre-plaintext records. Remove it (and the encryptionKey field) after
// a few releases once all live deployments have re-saved their provider configs.
func (s *PaymentConfigService) decryptConfig(stored string) (map[string]string, error) {
if stored == "" {
return nil, nil return nil, nil
} }
decrypted, err := payment.Decrypt(encrypted, s.encryptionKey) var cfg map[string]string
if err != nil { if err := json.Unmarshal([]byte(stored), &cfg); err == nil {
return nil, fmt.Errorf("decrypt config: %w", err) return cfg, nil
} }
var raw map[string]string // Deprecated: legacy AES-256-GCM ciphertext fallback — scheduled for removal.
if err := json.Unmarshal([]byte(decrypted), &raw); err != nil { if len(s.encryptionKey) == payment.AES256KeySize {
return nil, fmt.Errorf("unmarshal decrypted config: %w", err) //nolint:staticcheck // SA1019: intentional legacy fallback, scheduled for removal
if plaintext, err := payment.Decrypt(stored, s.encryptionKey); err == nil {
if err := json.Unmarshal([]byte(plaintext), &cfg); err == nil {
return cfg, nil
}
}
} }
return raw, nil slog.Warn("payment provider config unreadable, treating as empty for re-entry",
"stored_len", len(stored))
return nil, nil
} }
func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id int64) error { func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id int64) error {
...@@ -317,14 +384,13 @@ func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id in ...@@ -317,14 +384,13 @@ func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id in
return s.entClient.PaymentProviderInstance.DeleteOneID(id).Exec(ctx) return s.entClient.PaymentProviderInstance.DeleteOneID(id).Exec(ctx)
} }
// encryptConfig serialises a provider config for storage.
// New records are written as plaintext JSON; the historical AES-GCM wrapping
// has been dropped but decryptConfig still accepts old ciphertext during migration.
func (s *PaymentConfigService) encryptConfig(cfg map[string]string) (string, error) { func (s *PaymentConfigService) encryptConfig(cfg map[string]string) (string, error) {
data, err := json.Marshal(cfg) data, err := json.Marshal(cfg)
if err != nil { if err != nil {
return "", fmt.Errorf("marshal config: %w", err) return "", fmt.Errorf("marshal config: %w", err)
} }
enc, err := payment.Encrypt(string(data), s.encryptionKey) return string(data), nil
if err != nil {
return "", fmt.Errorf("encrypt config: %w", err)
}
return enc, nil
} }
...@@ -97,41 +97,52 @@ func TestValidateProviderRequest(t *testing.T) { ...@@ -97,41 +97,52 @@ func TestValidateProviderRequest(t *testing.T) {
} }
} }
func TestIsSensitiveConfigField(t *testing.T) { func TestIsSensitiveProviderConfigField(t *testing.T) {
t.Parallel() t.Parallel()
tests := []struct { tests := []struct {
field string providerKey string
wantSen bool field string
wantSen bool
}{ }{
// Sensitive fields (contain key/secret/private/password/pkey patterns) // Stripe: publishableKey is public, only secretKey/webhookSecret are secrets
{"secretKey", true}, {"stripe", "secretKey", true},
{"apiSecret", true}, {"stripe", "webhookSecret", true},
{"pkey", true}, {"stripe", "SecretKey", true}, // case-insensitive
{"privateKey", true}, {"stripe", "publishableKey", false},
{"apiPassword", true}, {"stripe", "appId", false},
{"appKey", true},
{"SECRET_TOKEN", true}, // Alipay
{"PrivateData", true}, {"alipay", "privateKey", true},
{"PASSWORD", true}, {"alipay", "publicKey", true},
{"mySecretValue", true}, {"alipay", "alipayPublicKey", true},
{"alipay", "appId", false},
// Non-sensitive fields {"alipay", "notifyUrl", false},
{"appId", false},
{"mchId", false}, // Wxpay
{"apiBase", false}, {"wxpay", "privateKey", true},
{"endpoint", false}, {"wxpay", "apiV3Key", true},
{"merchantNo", false}, {"wxpay", "publicKey", true},
{"paymentMode", false}, {"wxpay", "publicKeyId", false},
{"notifyUrl", false}, {"wxpay", "certSerial", false},
{"wxpay", "mchId", false},
// EasyPay
{"easypay", "pkey", true},
{"easypay", "pid", false},
{"easypay", "apiBase", false},
// Unknown provider: never sensitive
{"unknown", "secretKey", false},
} }
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.field, func(t *testing.T) { tc := tc
t.Run(tc.providerKey+"/"+tc.field, func(t *testing.T) {
t.Parallel() t.Parallel()
got := isSensitiveConfigField(tc.field) got := isSensitiveProviderConfigField(tc.providerKey, tc.field)
assert.Equal(t, tc.wantSen, got, "isSensitiveConfigField(%q)", tc.field) assert.Equal(t, tc.wantSen, got, "isSensitiveProviderConfigField(%q, %q)", tc.providerKey, tc.field)
}) })
} }
} }
......
...@@ -12,7 +12,9 @@ import ( ...@@ -12,7 +12,9 @@ import (
var ErrUpstreamResponseBodyTooLarge = errors.New("upstream response body too large") var ErrUpstreamResponseBodyTooLarge = errors.New("upstream response body too large")
const defaultUpstreamResponseReadMaxBytes int64 = 8 * 1024 * 1024 // defaultUpstreamResponseReadMaxBytes 源自 config.DefaultUpstreamResponseReadMaxBytes,
// 仅在 cfg 为 nil 时作为兜底(测试或极端场景)。
const defaultUpstreamResponseReadMaxBytes = config.DefaultUpstreamResponseReadMaxBytes
func resolveUpstreamResponseReadLimit(cfg *config.Config) int64 { func resolveUpstreamResponseReadLimit(cfg *config.Config) int64 {
if cfg != nil && cfg.Gateway.UpstreamResponseReadMaxBytes > 0 { if cfg != nil && cfg.Gateway.UpstreamResponseReadMaxBytes > 0 {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment