Commit da1d2600 authored by IanShaw027's avatar IanShaw027
Browse files

Merge branch 'main' into rebuild/auth-identity-foundation

parents e4cfcae6 78f691d2
...@@ -621,6 +621,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatfo ...@@ -621,6 +621,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatfo
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1", Name: "g1",
Platform: PlatformOpenAI, Platform: PlatformOpenAI,
RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard, SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID, FallbackGroupIDOnInvalidRequest: &fallbackID,
}) })
...@@ -641,6 +642,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription(t *t ...@@ -641,6 +642,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription(t *t
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1", Name: "g1",
Platform: PlatformAnthropic, Platform: PlatformAnthropic,
RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeSubscription, SubscriptionType: SubscriptionTypeSubscription,
FallbackGroupIDOnInvalidRequest: &fallbackID, FallbackGroupIDOnInvalidRequest: &fallbackID,
}) })
...@@ -695,6 +697,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup(t * ...@@ -695,6 +697,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1", Name: "g1",
Platform: PlatformAnthropic, Platform: PlatformAnthropic,
RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard, SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID, FallbackGroupIDOnInvalidRequest: &fallbackID,
}) })
...@@ -713,6 +716,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackNotFound(t *testing.T) { ...@@ -713,6 +716,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackNotFound(t *testing.T) {
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1", Name: "g1",
Platform: PlatformAnthropic, Platform: PlatformAnthropic,
RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard, SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID, FallbackGroupIDOnInvalidRequest: &fallbackID,
}) })
...@@ -733,6 +737,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes ...@@ -733,6 +737,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1", Name: "g1",
Platform: PlatformAntigravity, Platform: PlatformAntigravity,
RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard, SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID, FallbackGroupIDOnInvalidRequest: &fallbackID,
}) })
...@@ -750,6 +755,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero(t *testing. ...@@ -750,6 +755,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero(t *testing.
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1", Name: "g1",
Platform: PlatformAnthropic, Platform: PlatformAnthropic,
RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard, SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &zero, FallbackGroupIDOnInvalidRequest: &zero,
}) })
......
...@@ -203,17 +203,6 @@ func (s *BillingService) initFallbackPricing() { ...@@ -203,17 +203,6 @@ func (s *BillingService) initFallbackPricing() {
SupportsCacheBreakdown: false, SupportsCacheBreakdown: false,
} }
// OpenAI GPT-5.1(本地兜底,防止动态定价不可用时拒绝计费)
s.fallbackPrices["gpt-5.1"] = &ModelPricing{
InputPricePerToken: 1.25e-6, // $1.25 per MTok
InputPricePerTokenPriority: 2.5e-6, // $2.5 per MTok
OutputPricePerToken: 10e-6, // $10 per MTok
OutputPricePerTokenPriority: 20e-6, // $20 per MTok
CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok
CacheReadPricePerToken: 0.125e-6,
CacheReadPricePerTokenPriority: 0.25e-6,
SupportsCacheBreakdown: false,
}
// OpenAI GPT-5.4(业务指定价格) // OpenAI GPT-5.4(业务指定价格)
s.fallbackPrices["gpt-5.4"] = &ModelPricing{ s.fallbackPrices["gpt-5.4"] = &ModelPricing{
InputPricePerToken: 2.5e-6, // $2.5 per MTok InputPricePerToken: 2.5e-6, // $2.5 per MTok
...@@ -234,12 +223,6 @@ func (s *BillingService) initFallbackPricing() { ...@@ -234,12 +223,6 @@ func (s *BillingService) initFallbackPricing() {
CacheReadPricePerToken: 7.5e-8, CacheReadPricePerToken: 7.5e-8,
SupportsCacheBreakdown: false, SupportsCacheBreakdown: false,
} }
s.fallbackPrices["gpt-5.4-nano"] = &ModelPricing{
InputPricePerToken: 2e-7,
OutputPricePerToken: 1.25e-6,
CacheReadPricePerToken: 2e-8,
SupportsCacheBreakdown: false,
}
// OpenAI GPT-5.2(本地兜底) // OpenAI GPT-5.2(本地兜底)
s.fallbackPrices["gpt-5.2"] = &ModelPricing{ s.fallbackPrices["gpt-5.2"] = &ModelPricing{
InputPricePerToken: 1.75e-6, InputPricePerToken: 1.75e-6,
...@@ -251,8 +234,8 @@ func (s *BillingService) initFallbackPricing() { ...@@ -251,8 +234,8 @@ func (s *BillingService) initFallbackPricing() {
CacheReadPricePerTokenPriority: 0.35e-6, CacheReadPricePerTokenPriority: 0.35e-6,
SupportsCacheBreakdown: false, SupportsCacheBreakdown: false,
} }
// Codex 族兜底统一按 GPT-5.1 Codex 价格计费 // Codex 族兜底统一按 GPT-5.3 Codex 价格计费
s.fallbackPrices["gpt-5.1-codex"] = &ModelPricing{ s.fallbackPrices["gpt-5.3-codex"] = &ModelPricing{
InputPricePerToken: 1.5e-6, // $1.5 per MTok InputPricePerToken: 1.5e-6, // $1.5 per MTok
InputPricePerTokenPriority: 3e-6, // $3 per MTok InputPricePerTokenPriority: 3e-6, // $3 per MTok
OutputPricePerToken: 12e-6, // $12 per MTok OutputPricePerToken: 12e-6, // $12 per MTok
...@@ -262,17 +245,6 @@ func (s *BillingService) initFallbackPricing() { ...@@ -262,17 +245,6 @@ func (s *BillingService) initFallbackPricing() {
CacheReadPricePerTokenPriority: 0.3e-6, CacheReadPricePerTokenPriority: 0.3e-6,
SupportsCacheBreakdown: false, SupportsCacheBreakdown: false,
} }
s.fallbackPrices["gpt-5.2-codex"] = &ModelPricing{
InputPricePerToken: 1.75e-6,
InputPricePerTokenPriority: 3.5e-6,
OutputPricePerToken: 14e-6,
OutputPricePerTokenPriority: 28e-6,
CacheCreationPricePerToken: 1.75e-6,
CacheReadPricePerToken: 0.175e-6,
CacheReadPricePerTokenPriority: 0.35e-6,
SupportsCacheBreakdown: false,
}
s.fallbackPrices["gpt-5.3-codex"] = s.fallbackPrices["gpt-5.1-codex"]
} }
// getFallbackPricing 根据模型系列获取回退价格 // getFallbackPricing 根据模型系列获取回退价格
...@@ -318,20 +290,12 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing { ...@@ -318,20 +290,12 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
switch normalized { switch normalized {
case "gpt-5.4-mini": case "gpt-5.4-mini":
return s.fallbackPrices["gpt-5.4-mini"] return s.fallbackPrices["gpt-5.4-mini"]
case "gpt-5.4-nano":
return s.fallbackPrices["gpt-5.4-nano"]
case "gpt-5.4": case "gpt-5.4":
return s.fallbackPrices["gpt-5.4"] return s.fallbackPrices["gpt-5.4"]
case "gpt-5.2": case "gpt-5.2":
return s.fallbackPrices["gpt-5.2"] return s.fallbackPrices["gpt-5.2"]
case "gpt-5.2-codex": case "gpt-5.3-codex", "gpt-5.3-codex-spark":
return s.fallbackPrices["gpt-5.2-codex"]
case "gpt-5.3-codex":
return s.fallbackPrices["gpt-5.3-codex"] return s.fallbackPrices["gpt-5.3-codex"]
case "gpt-5.1-codex", "gpt-5.1-codex-max", "gpt-5.1-codex-mini", "codex-mini-latest":
return s.fallbackPrices["gpt-5.1-codex"]
case "gpt-5.1":
return s.fallbackPrices["gpt-5.1"]
} }
} }
...@@ -448,8 +412,9 @@ func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown, ...@@ -448,8 +412,9 @@ func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown,
}) })
} }
if input.RateMultiplier <= 0 { // 保存时强制 > 0;若仍有负数泄漏(缓存/迁移残留),按 0 处理避免按 1x 误扣。
input.RateMultiplier = 1.0 if input.RateMultiplier < 0 {
input.RateMultiplier = 0
} }
var breakdown *CostBreakdown var breakdown *CostBreakdown
...@@ -493,8 +458,9 @@ func (s *BillingService) computeTokenBreakdown( ...@@ -493,8 +458,9 @@ func (s *BillingService) computeTokenBreakdown(
rateMultiplier float64, serviceTier string, rateMultiplier float64, serviceTier string,
applyLongCtx bool, applyLongCtx bool,
) *CostBreakdown { ) *CostBreakdown {
if rateMultiplier <= 0 { // 保存时强制 > 0;若仍有负数泄漏,按 0 处理避免按 1x 误扣。
rateMultiplier = 1.0 if rateMultiplier < 0 {
rateMultiplier = 0
} }
inputPrice := pricing.InputPricePerToken inputPrice := pricing.InputPricePerToken
...@@ -665,8 +631,13 @@ func (s *BillingService) shouldApplySessionLongContextPricing(tokens UsageTokens ...@@ -665,8 +631,13 @@ func (s *BillingService) shouldApplySessionLongContextPricing(tokens UsageTokens
} }
func isOpenAIGPT54Model(model string) bool { func isOpenAIGPT54Model(model string) bool {
normalized := normalizeCodexModel(strings.TrimSpace(strings.ToLower(model))) trimmed := strings.TrimSpace(strings.ToLower(model))
return normalized == "gpt-5.4" // 仅当模型字符串实际属于 GPT-5/Codex 族时才做归一判定,避免 normalizeCodexModel
// 的默认兜底把非 OpenAI 模型(claude-*、gemini-*、gpt-4o)误识别为 gpt-5.4。
if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") {
return false
}
return normalizeCodexModel(trimmed) == "gpt-5.4"
} }
// CalculateCostWithConfig 使用配置中的默认倍率计算费用 // CalculateCostWithConfig 使用配置中的默认倍率计算费用
...@@ -831,9 +802,9 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag ...@@ -831,9 +802,9 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag
// 计算总费用 // 计算总费用
totalCost := unitPrice * float64(imageCount) totalCost := unitPrice * float64(imageCount)
// 应用倍率 // 应用倍率(保存时强制 > 0;负数按 0 处理避免按 1x 误扣)
if rateMultiplier <= 0 { if rateMultiplier < 0 {
rateMultiplier = 1.0 rateMultiplier = 0
} }
actualCost := totalCost * rateMultiplier actualCost := totalCost * rateMultiplier
......
...@@ -90,13 +90,14 @@ func TestCalculateImageCost_NegativeCount(t *testing.T) { ...@@ -90,13 +90,14 @@ func TestCalculateImageCost_NegativeCount(t *testing.T) {
require.Equal(t, 0.0, cost.ActualCost) require.Equal(t, 0.0, cost.ActualCost)
} }
// TestCalculateImageCost_ZeroRateMultiplier 测试费率倍数为 0 时默认使用 1.0 // TestCalculateImageCost_ZeroRateMultiplier 锁定新行为:倍率 0 直接按 0 计费
// (保存时已强制 > 0;若仍有 0 泄漏到计费层,零消耗比历史的 1.0 更安全)。
func TestCalculateImageCost_ZeroRateMultiplier(t *testing.T) { func TestCalculateImageCost_ZeroRateMultiplier(t *testing.T) {
svc := &BillingService{} svc := &BillingService{}
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 0) cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 0)
require.InDelta(t, 0.201, cost.TotalCost, 0.0001) require.InDelta(t, 0.201, cost.TotalCost, 0.0001)
require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // 0 倍率当作 1.0 处理 require.InDelta(t, 0.0, cost.ActualCost, 1e-10)
} }
// TestGetImageUnitPrice_GroupPriorityOverDefault 测试分组价格优先于默认价格 // TestGetImageUnitPrice_GroupPriorityOverDefault 测试分组价格优先于默认价格
......
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// TestCalculateCost_RateMultiplier_NegativeClampedToZero 锁定负数倍率被
// 钳制为 0(而非历史上的 1.0),避免配置异常导致静默按标准价扣费。
func TestCalculateCost_RateMultiplier_NegativeClampedToZero(t *testing.T) {
svc := newTestBillingService()
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
tests := []struct {
name string
multiplier float64
wantRatio float64 // ActualCost / TotalCost
}{
{"negative clamped to 0", -1.5, 0},
{"zero passes through as 0 (defense in depth)", 0, 0},
{"positive 2x applied", 2.0, 2.0},
{"positive 0.5x applied", 0.5, 0.5},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cost, err := svc.CalculateCost("claude-sonnet-4", tokens, tt.multiplier)
require.NoError(t, err)
require.Greater(t, cost.TotalCost, 0.0, "TotalCost should be non-zero")
require.InDelta(t, tt.wantRatio*cost.TotalCost, cost.ActualCost, 1e-9)
})
}
}
// TestCalculateImageCost_RateMultiplier_NegativeClampedToZero 图片按次计费路径
// 同样遵循"负数 → 0"语义。
func TestCalculateImageCost_RateMultiplier_NegativeClampedToZero(t *testing.T) {
svc := newTestBillingService()
price := 0.04
cfg := &ImagePriceConfig{Price1K: &price}
tests := []struct {
name string
multiplier float64
wantRatio float64
}{
{"negative clamped to 0", -0.5, 0},
{"zero passes through", 0, 0},
{"positive 3x applied", 3.0, 3.0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cost := svc.CalculateImageCost("imagen-3", "1K", 2, cfg, tt.multiplier)
require.NotNil(t, cost)
require.Greater(t, cost.TotalCost, 0.0)
require.InDelta(t, tt.wantRatio*cost.TotalCost, cost.ActualCost, 1e-9)
})
}
}
...@@ -71,34 +71,6 @@ func TestCalculateCost_RateMultiplier(t *testing.T) { ...@@ -71,34 +71,6 @@ func TestCalculateCost_RateMultiplier(t *testing.T) {
require.InDelta(t, cost1x.ActualCost*2, cost2x.ActualCost, 1e-10) require.InDelta(t, cost1x.ActualCost*2, cost2x.ActualCost, 1e-10)
} }
func TestCalculateCost_ZeroMultiplierDefaultsToOne(t *testing.T) {
svc := newTestBillingService()
tokens := UsageTokens{InputTokens: 1000}
costZero, err := svc.CalculateCost("claude-sonnet-4", tokens, 0)
require.NoError(t, err)
costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
require.NoError(t, err)
require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10)
}
func TestCalculateCost_NegativeMultiplierDefaultsToOne(t *testing.T) {
svc := newTestBillingService()
tokens := UsageTokens{InputTokens: 1000}
costNeg, err := svc.CalculateCost("claude-sonnet-4", tokens, -1.0)
require.NoError(t, err)
costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
require.NoError(t, err)
require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10)
}
func TestGetModelPricing_FallbackMatchesByFamily(t *testing.T) { func TestGetModelPricing_FallbackMatchesByFamily(t *testing.T) {
svc := newTestBillingService() svc := newTestBillingService()
...@@ -151,15 +123,6 @@ func TestGetModelPricing_UnknownOpenAIModelReturnsError(t *testing.T) { ...@@ -151,15 +123,6 @@ func TestGetModelPricing_UnknownOpenAIModelReturnsError(t *testing.T) {
require.Contains(t, err.Error(), "pricing not found") require.Contains(t, err.Error(), "pricing not found")
} }
func TestGetModelPricing_OpenAIGPT51Fallback(t *testing.T) {
svc := newTestBillingService()
pricing, err := svc.GetModelPricing("gpt-5.1")
require.NoError(t, err)
require.NotNil(t, pricing)
require.InDelta(t, 1.25e-6, pricing.InputPricePerToken, 1e-12)
}
func TestGetModelPricing_OpenAIGPT54Fallback(t *testing.T) { func TestGetModelPricing_OpenAIGPT54Fallback(t *testing.T) {
svc := newTestBillingService() svc := newTestBillingService()
...@@ -186,18 +149,6 @@ func TestGetModelPricing_OpenAIGPT54MiniFallback(t *testing.T) { ...@@ -186,18 +149,6 @@ func TestGetModelPricing_OpenAIGPT54MiniFallback(t *testing.T) {
require.Zero(t, pricing.LongContextInputThreshold) require.Zero(t, pricing.LongContextInputThreshold)
} }
func TestGetModelPricing_OpenAIGPT54NanoFallback(t *testing.T) {
svc := newTestBillingService()
pricing, err := svc.GetModelPricing("gpt-5.4-nano")
require.NoError(t, err)
require.NotNil(t, pricing)
require.InDelta(t, 2e-7, pricing.InputPricePerToken, 1e-12)
require.InDelta(t, 1.25e-6, pricing.OutputPricePerToken, 1e-12)
require.InDelta(t, 2e-8, pricing.CacheReadPricePerToken, 1e-12)
require.Zero(t, pricing.LongContextInputThreshold)
}
func TestCalculateCost_OpenAIGPT54LongContextAppliesWholeSessionMultipliers(t *testing.T) { func TestCalculateCost_OpenAIGPT54LongContextAppliesWholeSessionMultipliers(t *testing.T) {
svc := newTestBillingService() svc := newTestBillingService()
...@@ -232,13 +183,13 @@ func TestGetFallbackPricing_FamilyMatching(t *testing.T) { ...@@ -232,13 +183,13 @@ func TestGetFallbackPricing_FamilyMatching(t *testing.T) {
{name: "claude generic model fallback sonnet", model: "claude-foo-bar", expectedInput: 3e-6}, {name: "claude generic model fallback sonnet", model: "claude-foo-bar", expectedInput: 3e-6},
{name: "gemini explicit fallback", model: "gemini-3-1-pro", expectedInput: 2e-6}, {name: "gemini explicit fallback", model: "gemini-3-1-pro", expectedInput: 2e-6},
{name: "gemini unknown no fallback", model: "gemini-2.0-pro", expectNilPricing: true}, {name: "gemini unknown no fallback", model: "gemini-2.0-pro", expectNilPricing: true},
{name: "openai gpt5.1", model: "gpt-5.1", expectedInput: 1.25e-6},
{name: "openai gpt5.4", model: "gpt-5.4", expectedInput: 2.5e-6}, {name: "openai gpt5.4", model: "gpt-5.4", expectedInput: 2.5e-6},
{name: "openai gpt5.4 mini", model: "gpt-5.4-mini", expectedInput: 7.5e-7}, {name: "openai gpt5.4 mini", model: "gpt-5.4-mini", expectedInput: 7.5e-7},
{name: "openai gpt5.4 nano", model: "gpt-5.4-nano", expectedInput: 2e-7},
{name: "openai gpt5.3 codex", model: "gpt-5.3-codex", expectedInput: 1.5e-6}, {name: "openai gpt5.3 codex", model: "gpt-5.3-codex", expectedInput: 1.5e-6},
{name: "openai gpt5.1 codex max alias", model: "gpt-5.1-codex-max", expectedInput: 1.5e-6}, {name: "openai gpt5.3 codex spark", model: "gpt-5.3-codex-spark", expectedInput: 1.5e-6},
{name: "openai codex mini latest alias", model: "codex-mini-latest", expectedInput: 1.5e-6}, {name: "openai legacy gpt5.1 falls back to gpt5.4", model: "gpt-5.1", expectedInput: 2.5e-6},
{name: "openai legacy gpt5.1 codex falls back to gpt5.3 codex", model: "gpt-5.1-codex", expectedInput: 1.5e-6},
{name: "openai legacy codex mini latest falls back to gpt5.3 codex", model: "codex-mini-latest", expectedInput: 1.5e-6},
{name: "openai unknown no fallback", model: "gpt-unknown-model", expectNilPricing: true}, {name: "openai unknown no fallback", model: "gpt-unknown-model", expectNilPricing: true},
{name: "non supported family", model: "qwen-max", expectNilPricing: true}, {name: "non supported family", model: "qwen-max", expectNilPricing: true},
} }
......
...@@ -147,40 +147,35 @@ func TestCalculateCostUnified_ImageMode(t *testing.T) { ...@@ -147,40 +147,35 @@ func TestCalculateCostUnified_ImageMode(t *testing.T) {
require.Equal(t, string(BillingModeImage), cost.BillingMode) require.Equal(t, string(BillingModeImage), cost.BillingMode)
} }
func TestCalculateCostUnified_RateMultiplierZeroDefaultsToOne(t *testing.T) { // TestCalculateCostUnified_RateMultiplierZeroProducesZero 锁定新行为:
// 保存时强制 > 0;若 0 仍泄漏到计费层,按 0 计费(而非历史上的 1.0)。
func TestCalculateCostUnified_RateMultiplierZeroProducesZero(t *testing.T) {
bs := newTestBillingService() bs := newTestBillingService()
resolver := NewModelPricingResolver(nil, bs) resolver := NewModelPricingResolver(nil, bs)
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
costZero, err := bs.CalculateCostUnified(CostInput{ cost, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: tokens,
RateMultiplier: 0, // should default to 1.0
Resolver: resolver,
})
require.NoError(t, err)
costOne, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(), Ctx: context.Background(),
Model: "claude-sonnet-4", Model: "claude-sonnet-4",
Tokens: tokens, Tokens: tokens,
RateMultiplier: 1.0, RateMultiplier: 0,
Resolver: resolver, Resolver: resolver,
}) })
require.NoError(t, err) require.NoError(t, err)
require.Greater(t, cost.TotalCost, 0.0)
require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10) require.InDelta(t, 0.0, cost.ActualCost, 1e-10)
} }
func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T) { // TestCalculateCostUnified_NegativeRateMultiplierClampedToZero 锁定新行为:
// 负数倍率按 0 计费,避免历史的 <=0 → 1.0 把配置异常静默按标准价扣费。
func TestCalculateCostUnified_NegativeRateMultiplierClampedToZero(t *testing.T) {
bs := newTestBillingService() bs := newTestBillingService()
resolver := NewModelPricingResolver(nil, bs) resolver := NewModelPricingResolver(nil, bs)
tokens := UsageTokens{InputTokens: 1000} tokens := UsageTokens{InputTokens: 1000}
costNeg, err := bs.CalculateCostUnified(CostInput{ cost, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(), Ctx: context.Background(),
Model: "claude-sonnet-4", Model: "claude-sonnet-4",
Tokens: tokens, Tokens: tokens,
...@@ -188,17 +183,8 @@ func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T) ...@@ -188,17 +183,8 @@ func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T)
Resolver: resolver, Resolver: resolver,
}) })
require.NoError(t, err) require.NoError(t, err)
require.Greater(t, cost.TotalCost, 0.0)
costOne, err := bs.CalculateCostUnified(CostInput{ require.InDelta(t, 0.0, cost.ActualCost, 1e-10)
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: tokens,
RateMultiplier: 1.0,
Resolver: resolver,
})
require.NoError(t, err)
require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10)
} }
func TestCalculateCostUnified_BillingModeFieldFilled(t *testing.T) { func TestCalculateCostUnified_BillingModeFieldFilled(t *testing.T) {
......
...@@ -962,7 +962,7 @@ func NormalizeClaudeOutputEffort(raw string) *string { ...@@ -962,7 +962,7 @@ func NormalizeClaudeOutputEffort(raw string) *string {
return nil return nil
} }
switch value { switch value {
case "low", "medium", "high", "max": case "low", "medium", "high", "xhigh", "max":
return &value return &value
default: default:
return nil return nil
......
...@@ -1149,6 +1149,11 @@ func TestParseGatewayRequest_OutputEffort(t *testing.T) { ...@@ -1149,6 +1149,11 @@ func TestParseGatewayRequest_OutputEffort(t *testing.T) {
body: `{"model":"claude-opus-4-6","output_config":{"effort":"max"},"messages":[]}`, body: `{"model":"claude-opus-4-6","output_config":{"effort":"max"},"messages":[]}`,
wantEffort: "max", wantEffort: "max",
}, },
{
name: "output_config.effort xhigh",
body: `{"model":"claude-opus-4-7","output_config":{"effort":"xhigh"},"messages":[]}`,
wantEffort: "xhigh",
},
{ {
name: "output_config without effort", name: "output_config without effort",
body: `{"model":"claude-opus-4-6","output_config":{},"messages":[]}`, body: `{"model":"claude-opus-4-6","output_config":{},"messages":[]}`,
...@@ -1186,9 +1191,10 @@ func TestNormalizeClaudeOutputEffort(t *testing.T) { ...@@ -1186,9 +1191,10 @@ func TestNormalizeClaudeOutputEffort(t *testing.T) {
{"LOW", strPtr("low")}, {"LOW", strPtr("low")},
{"Max", strPtr("max")}, {"Max", strPtr("max")},
{" medium ", strPtr("medium")}, {" medium ", strPtr("medium")},
{"xhigh", strPtr("xhigh")},
{"XHIGH", strPtr("xhigh")},
{"", nil}, {"", nil},
{"unknown", nil}, {"unknown", nil},
{"xhigh", nil},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) { t.Run(tt.input, func(t *testing.T) {
......
...@@ -435,26 +435,19 @@ func prefetchedStickyAccountIDFromContext(ctx context.Context, groupID *int64) i ...@@ -435,26 +435,19 @@ func prefetchedStickyAccountIDFromContext(ctx context.Context, groupID *int64) i
} }
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。 // shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
// 当账号状态为错误、禁用、不可调度、处于临时不可调度期间, // 委托 IsSchedulable() 判断账号级可调度性(状态、配额、过载、限流等),
// 或请求的模型处于限流状态时,返回 true。 // 额外检查模型级限流。
// 这确保后续请求不会继续使用不可用的账号。
// //
// shouldClearStickySession checks if an account is in an unschedulable state // shouldClearStickySession checks if an account is in an unschedulable state
// and the sticky session binding should be cleared. // and the sticky session binding should be cleared.
// Returns true when account status is error/disabled, schedulable is false, // Delegates to IsSchedulable() for account-level checks, plus model-level rate limiting.
// within temporary unschedulable period, or the requested model is rate-limited.
// This ensures subsequent requests won't continue using unavailable accounts.
func shouldClearStickySession(account *Account, requestedModel string) bool { func shouldClearStickySession(account *Account, requestedModel string) bool {
if account == nil { if account == nil {
return false return false
} }
if account.Status == StatusError || account.Status == StatusDisabled || !account.Schedulable { if !account.IsSchedulable() {
return true return true
} }
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
return true
}
// 检查模型限流和 scope 限流,有限流即清除粘性会话
if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > 0 { if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > 0 {
return true return true
} }
...@@ -7317,8 +7310,10 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill ...@@ -7317,8 +7310,10 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
cost := p.Cost cost := p.Cost
if p.IsSubscriptionBill { if p.IsSubscriptionBill {
if cost.TotalCost > 0 { // Subscription usage tracked by ActualCost so group rate multiplier
if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil { // consumes the quota at the expected speed.
if cost.ActualCost > 0 {
if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.ActualCost); err != nil {
slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err) slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err)
} }
} }
...@@ -7417,9 +7412,13 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage ...@@ -7417,9 +7412,13 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
} }
} }
// Record subscription / balance cost using ActualCost so the group (and any
// user-specific) rate multiplier consumes subscription quota at the expected
// speed. TotalCost remains the raw (pre-multiplier) value; downstream guards
// on "> 0" still correctly skip free subscriptions (RateMultiplier == 0).
if p.IsSubscriptionBill && p.Subscription != nil && p.Cost.TotalCost > 0 { if p.IsSubscriptionBill && p.Subscription != nil && p.Cost.TotalCost > 0 {
cmd.SubscriptionID = &p.Subscription.ID cmd.SubscriptionID = &p.Subscription.ID
cmd.SubscriptionCost = p.Cost.TotalCost cmd.SubscriptionCost = p.Cost.ActualCost
} else if p.Cost.ActualCost > 0 { } else if p.Cost.ActualCost > 0 {
cmd.BalanceCost = p.Cost.ActualCost cmd.BalanceCost = p.Cost.ActualCost
} }
...@@ -7478,8 +7477,8 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, resu ...@@ -7478,8 +7477,8 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, resu
} }
if p.IsSubscriptionBill { if p.IsSubscriptionBill {
if p.Cost.TotalCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil { if p.Cost.ActualCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil {
deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.TotalCost) deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.ActualCost)
} }
} else if p.Cost.ActualCost > 0 && p.User != nil { } else if p.Cost.ActualCost > 0 && p.User != nil {
deps.billingCacheService.QueueDeductBalance(p.User.ID, p.Cost.ActualCost) deps.billingCacheService.QueueDeductBalance(p.User.ID, p.Cost.ActualCost)
......
//go:build unit
package service
import (
"testing"
)
// TestBuildUsageBillingCommand_SubscriptionAppliesRateMultiplier locks in the fix
// that subscription-mode billing honours the group (and any user-specific) rate
// multiplier — i.e. cmd.SubscriptionCost tracks ActualCost (= TotalCost *
// RateMultiplier), not raw TotalCost.
func TestBuildUsageBillingCommand_SubscriptionAppliesRateMultiplier(t *testing.T) {
t.Parallel()
groupID := int64(7)
subID := int64(42)
tests := []struct {
name string
totalCost float64
actualCost float64
isSubscription bool
wantSub float64
wantBalance float64
}{
{
name: "subscription with 2x multiplier consumes 2x quota",
totalCost: 1.0,
actualCost: 2.0,
isSubscription: true,
wantSub: 2.0,
wantBalance: 0,
},
{
name: "subscription with 0.5x multiplier consumes 0.5x quota",
totalCost: 1.0,
actualCost: 0.5,
isSubscription: true,
wantSub: 0.5,
wantBalance: 0,
},
{
name: "free subscription (multiplier 0) consumes no quota",
totalCost: 1.0,
actualCost: 0,
isSubscription: true,
wantSub: 0,
wantBalance: 0,
},
{
name: "balance billing keeps using ActualCost (regression)",
totalCost: 1.0,
actualCost: 2.0,
isSubscription: false,
wantSub: 0,
wantBalance: 2.0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
p := &postUsageBillingParams{
Cost: &CostBreakdown{TotalCost: tt.totalCost, ActualCost: tt.actualCost},
User: &User{ID: 1},
APIKey: &APIKey{ID: 2, GroupID: &groupID},
Account: &Account{ID: 3},
Subscription: &UserSubscription{ID: subID},
IsSubscriptionBill: tt.isSubscription,
}
cmd := buildUsageBillingCommand("req-1", nil, p)
if cmd == nil {
t.Fatal("buildUsageBillingCommand returned nil")
}
if cmd.SubscriptionCost != tt.wantSub {
t.Errorf("SubscriptionCost = %v, want %v", cmd.SubscriptionCost, tt.wantSub)
}
if cmd.BalanceCost != tt.wantBalance {
t.Errorf("BalanceCost = %v, want %v", cmd.BalanceCost, tt.wantBalance)
}
})
}
}
...@@ -76,10 +76,6 @@ func (g *Group) IsSubscriptionType() bool { ...@@ -76,10 +76,6 @@ func (g *Group) IsSubscriptionType() bool {
return g.SubscriptionType == SubscriptionTypeSubscription return g.SubscriptionType == SubscriptionTypeSubscription
} }
func (g *Group) IsFreeSubscription() bool {
return g.IsSubscriptionType() && g.RateMultiplier == 0
}
func (g *Group) HasDailyLimit() bool { func (g *Group) HasDailyLimit() bool {
return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0 return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0
} }
......
...@@ -8,7 +8,6 @@ import ( ...@@ -8,7 +8,6 @@ import (
var codexModelMap = map[string]string{ var codexModelMap = map[string]string{
"gpt-5.4": "gpt-5.4", "gpt-5.4": "gpt-5.4",
"gpt-5.4-mini": "gpt-5.4-mini", "gpt-5.4-mini": "gpt-5.4-mini",
"gpt-5.4-nano": "gpt-5.4-nano",
"gpt-5.4-none": "gpt-5.4", "gpt-5.4-none": "gpt-5.4",
"gpt-5.4-low": "gpt-5.4", "gpt-5.4-low": "gpt-5.4",
"gpt-5.4-medium": "gpt-5.4", "gpt-5.4-medium": "gpt-5.4",
...@@ -22,52 +21,21 @@ var codexModelMap = map[string]string{ ...@@ -22,52 +21,21 @@ var codexModelMap = map[string]string{
"gpt-5.3-high": "gpt-5.3-codex", "gpt-5.3-high": "gpt-5.3-codex",
"gpt-5.3-xhigh": "gpt-5.3-codex", "gpt-5.3-xhigh": "gpt-5.3-codex",
"gpt-5.3-codex": "gpt-5.3-codex", "gpt-5.3-codex": "gpt-5.3-codex",
"gpt-5.3-codex-spark": "gpt-5.3-codex", "gpt-5.3-codex-spark": "gpt-5.3-codex-spark",
"gpt-5.3-codex-spark-low": "gpt-5.3-codex", "gpt-5.3-codex-spark-low": "gpt-5.3-codex-spark",
"gpt-5.3-codex-spark-medium": "gpt-5.3-codex", "gpt-5.3-codex-spark-medium": "gpt-5.3-codex-spark",
"gpt-5.3-codex-spark-high": "gpt-5.3-codex", "gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark",
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex", "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark",
"gpt-5.3-codex-low": "gpt-5.3-codex", "gpt-5.3-codex-low": "gpt-5.3-codex",
"gpt-5.3-codex-medium": "gpt-5.3-codex", "gpt-5.3-codex-medium": "gpt-5.3-codex",
"gpt-5.3-codex-high": "gpt-5.3-codex", "gpt-5.3-codex-high": "gpt-5.3-codex",
"gpt-5.3-codex-xhigh": "gpt-5.3-codex", "gpt-5.3-codex-xhigh": "gpt-5.3-codex",
"gpt-5.1-codex": "gpt-5.1-codex",
"gpt-5.1-codex-low": "gpt-5.1-codex",
"gpt-5.1-codex-medium": "gpt-5.1-codex",
"gpt-5.1-codex-high": "gpt-5.1-codex",
"gpt-5.1-codex-max": "gpt-5.1-codex-max",
"gpt-5.1-codex-max-low": "gpt-5.1-codex-max",
"gpt-5.1-codex-max-medium": "gpt-5.1-codex-max",
"gpt-5.1-codex-max-high": "gpt-5.1-codex-max",
"gpt-5.1-codex-max-xhigh": "gpt-5.1-codex-max",
"gpt-5.2": "gpt-5.2", "gpt-5.2": "gpt-5.2",
"gpt-5.2-none": "gpt-5.2", "gpt-5.2-none": "gpt-5.2",
"gpt-5.2-low": "gpt-5.2", "gpt-5.2-low": "gpt-5.2",
"gpt-5.2-medium": "gpt-5.2", "gpt-5.2-medium": "gpt-5.2",
"gpt-5.2-high": "gpt-5.2", "gpt-5.2-high": "gpt-5.2",
"gpt-5.2-xhigh": "gpt-5.2", "gpt-5.2-xhigh": "gpt-5.2",
"gpt-5.2-codex": "gpt-5.2-codex",
"gpt-5.2-codex-low": "gpt-5.2-codex",
"gpt-5.2-codex-medium": "gpt-5.2-codex",
"gpt-5.2-codex-high": "gpt-5.2-codex",
"gpt-5.2-codex-xhigh": "gpt-5.2-codex",
"gpt-5.1-codex-mini": "gpt-5.1-codex-mini",
"gpt-5.1-codex-mini-medium": "gpt-5.1-codex-mini",
"gpt-5.1-codex-mini-high": "gpt-5.1-codex-mini",
"gpt-5.1": "gpt-5.1",
"gpt-5.1-none": "gpt-5.1",
"gpt-5.1-low": "gpt-5.1",
"gpt-5.1-medium": "gpt-5.1",
"gpt-5.1-high": "gpt-5.1",
"gpt-5.1-chat-latest": "gpt-5.1",
"gpt-5-codex": "gpt-5.1-codex",
"codex-mini-latest": "gpt-5.1-codex-mini",
"gpt-5-codex-mini": "gpt-5.1-codex-mini",
"gpt-5-codex-mini-medium": "gpt-5.1-codex-mini",
"gpt-5-codex-mini-high": "gpt-5.1-codex-mini",
"gpt-5": "gpt-5.1",
"gpt-5-mini": "gpt-5.1",
"gpt-5-nano": "gpt-5.1",
} }
type codexTransformResult struct { type codexTransformResult struct {
...@@ -220,7 +188,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact ...@@ -220,7 +188,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact
func normalizeCodexModel(model string) string { func normalizeCodexModel(model string) string {
if model == "" { if model == "" {
return "gpt-5.1" return "gpt-5.4"
} }
modelID := model modelID := model
...@@ -238,49 +206,29 @@ func normalizeCodexModel(model string) string { ...@@ -238,49 +206,29 @@ func normalizeCodexModel(model string) string {
if strings.Contains(normalized, "gpt-5.4-mini") || strings.Contains(normalized, "gpt 5.4 mini") { if strings.Contains(normalized, "gpt-5.4-mini") || strings.Contains(normalized, "gpt 5.4 mini") {
return "gpt-5.4-mini" return "gpt-5.4-mini"
} }
if strings.Contains(normalized, "gpt-5.4-nano") || strings.Contains(normalized, "gpt 5.4 nano") {
return "gpt-5.4-nano"
}
if strings.Contains(normalized, "gpt-5.4") || strings.Contains(normalized, "gpt 5.4") { if strings.Contains(normalized, "gpt-5.4") || strings.Contains(normalized, "gpt 5.4") {
return "gpt-5.4" return "gpt-5.4"
} }
if strings.Contains(normalized, "gpt-5.2-codex") || strings.Contains(normalized, "gpt 5.2 codex") {
return "gpt-5.2-codex"
}
if strings.Contains(normalized, "gpt-5.2") || strings.Contains(normalized, "gpt 5.2") { if strings.Contains(normalized, "gpt-5.2") || strings.Contains(normalized, "gpt 5.2") {
return "gpt-5.2" return "gpt-5.2"
} }
if strings.Contains(normalized, "gpt-5.3-codex-spark") || strings.Contains(normalized, "gpt 5.3 codex spark") {
return "gpt-5.3-codex-spark"
}
if strings.Contains(normalized, "gpt-5.3-codex") || strings.Contains(normalized, "gpt 5.3 codex") { if strings.Contains(normalized, "gpt-5.3-codex") || strings.Contains(normalized, "gpt 5.3 codex") {
return "gpt-5.3-codex" return "gpt-5.3-codex"
} }
if strings.Contains(normalized, "gpt-5.3") || strings.Contains(normalized, "gpt 5.3") { if strings.Contains(normalized, "gpt-5.3") || strings.Contains(normalized, "gpt 5.3") {
return "gpt-5.3-codex" return "gpt-5.3-codex"
} }
if strings.Contains(normalized, "gpt-5.1-codex-max") || strings.Contains(normalized, "gpt 5.1 codex max") {
return "gpt-5.1-codex-max"
}
if strings.Contains(normalized, "gpt-5.1-codex-mini") || strings.Contains(normalized, "gpt 5.1 codex mini") {
return "gpt-5.1-codex-mini"
}
if strings.Contains(normalized, "codex-mini-latest") ||
strings.Contains(normalized, "gpt-5-codex-mini") ||
strings.Contains(normalized, "gpt 5 codex mini") {
return "codex-mini-latest"
}
if strings.Contains(normalized, "gpt-5.1-codex") || strings.Contains(normalized, "gpt 5.1 codex") {
return "gpt-5.1-codex"
}
if strings.Contains(normalized, "gpt-5.1") || strings.Contains(normalized, "gpt 5.1") {
return "gpt-5.1"
}
if strings.Contains(normalized, "codex") { if strings.Contains(normalized, "codex") {
return "gpt-5.1-codex" return "gpt-5.3-codex"
} }
if strings.Contains(normalized, "gpt-5") || strings.Contains(normalized, "gpt 5") { if strings.Contains(normalized, "gpt-5") || strings.Contains(normalized, "gpt 5") {
return "gpt-5.1" return "gpt-5.4"
} }
return "gpt-5.1" return "gpt-5.4"
} }
func normalizeOpenAIModelForUpstream(account *Account, model string) string { func normalizeOpenAIModelForUpstream(account *Account, model string) string {
......
...@@ -240,15 +240,13 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) { ...@@ -240,15 +240,13 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) {
"gpt 5.4": "gpt-5.4", "gpt 5.4": "gpt-5.4",
"gpt-5.4-mini": "gpt-5.4-mini", "gpt-5.4-mini": "gpt-5.4-mini",
"gpt 5.4 mini": "gpt-5.4-mini", "gpt 5.4 mini": "gpt-5.4-mini",
"gpt-5.4-nano": "gpt-5.4-nano",
"gpt 5.4 nano": "gpt-5.4-nano",
"gpt-5.3": "gpt-5.3-codex", "gpt-5.3": "gpt-5.3-codex",
"gpt-5.3-codex": "gpt-5.3-codex", "gpt-5.3-codex": "gpt-5.3-codex",
"gpt-5.3-codex-xhigh": "gpt-5.3-codex", "gpt-5.3-codex-xhigh": "gpt-5.3-codex",
"gpt-5.3-codex-spark": "gpt-5.3-codex", "gpt-5.3-codex-spark": "gpt-5.3-codex-spark",
"gpt 5.3 codex spark": "gpt-5.3-codex", "gpt 5.3 codex spark": "gpt-5.3-codex-spark",
"gpt-5.3-codex-spark-high": "gpt-5.3-codex", "gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark",
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex", "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark",
"gpt 5.3 codex": "gpt-5.3-codex", "gpt 5.3 codex": "gpt-5.3-codex",
} }
...@@ -257,6 +255,26 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) { ...@@ -257,6 +255,26 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) {
} }
} }
func TestNormalizeCodexModel_RemovedModelsFallbackToSupportedTargets(t *testing.T) {
cases := map[string]string{
"": "gpt-5.4",
"gpt-5": "gpt-5.4",
"gpt-5-mini": "gpt-5.4",
"gpt-5-nano": "gpt-5.4",
"gpt-5.1": "gpt-5.4",
"gpt-5.1-codex": "gpt-5.3-codex",
"gpt-5.1-codex-max": "gpt-5.3-codex",
"gpt-5.1-codex-mini": "gpt-5.3-codex",
"gpt-5.2-codex": "gpt-5.2",
"codex-mini-latest": "gpt-5.3-codex",
"gpt-5-codex": "gpt-5.3-codex",
}
for input, expected := range cases {
require.Equal(t, expected, normalizeCodexModel(input))
}
}
func TestApplyCodexOAuthTransform_PreservesBareSparkModel(t *testing.T) { func TestApplyCodexOAuthTransform_PreservesBareSparkModel(t *testing.T) {
reqBody := map[string]any{ reqBody := map[string]any{
"model": "gpt-5.3-codex-spark", "model": "gpt-5.3-codex-spark",
......
...@@ -10,8 +10,14 @@ import ( ...@@ -10,8 +10,14 @@ import (
const compatPromptCacheKeyPrefix = "compat_cc_" const compatPromptCacheKeyPrefix = "compat_cc_"
func shouldAutoInjectPromptCacheKeyForCompat(model string) bool { func shouldAutoInjectPromptCacheKeyForCompat(model string) bool {
switch normalizeCodexModel(strings.TrimSpace(model)) { trimmed := strings.TrimSpace(strings.ToLower(model))
case "gpt-5.4", "gpt-5.3-codex": // 仅对 Codex OAuth 路径支持的 GPT-5 族开启自动注入,避免 normalizeCodexModel
// 的默认兜底把任意模型(如 gpt-4o、claude-*)误判为 gpt-5.4。
if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") {
return false
}
switch normalizeCodexModel(trimmed) {
case "gpt-5.4", "gpt-5.3-codex", "gpt-5.3-codex-spark":
return true return true
default: default:
return false return false
......
...@@ -1031,7 +1031,7 @@ func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFiel ...@@ -1031,7 +1031,7 @@ func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFiel
Model: "gpt-5.1", Model: "gpt-5.1",
Duration: time.Second, Duration: time.Second,
}, },
APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription}}, APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription, RateMultiplier: 1.0}},
User: &User{ID: 200}, User: &User{ID: 200},
Account: &Account{ID: 300}, Account: &Account{ID: 300},
Subscription: subscription, Subscription: subscription,
......
...@@ -69,14 +69,14 @@ func TestResolveOpenAIForwardModel(t *testing.T) { ...@@ -69,14 +69,14 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
} }
} }
func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t *testing.T) { func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt54(t *testing.T) {
account := &Account{ account := &Account{
Credentials: map[string]any{}, Credentials: map[string]any{},
} }
withoutDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "")) withoutDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", ""))
if withoutDefault != "gpt-5.1" { if withoutDefault != "gpt-5.4" {
t.Fatalf("normalizeCodexModel(...) = %q, want %q", withoutDefault, "gpt-5.1") t.Fatalf("normalizeCodexModel(...) = %q, want %q", withoutDefault, "gpt-5.4")
} }
withDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4")) withDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4"))
...@@ -87,9 +87,9 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t * ...@@ -87,9 +87,9 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t *
func TestNormalizeCodexModel(t *testing.T) { func TestNormalizeCodexModel(t *testing.T) {
cases := map[string]string{ cases := map[string]string{
"gpt-5.3-codex-spark": "gpt-5.3-codex", "gpt-5.3-codex-spark": "gpt-5.3-codex-spark",
"gpt-5.3-codex-spark-high": "gpt-5.3-codex", "gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark",
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex", "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark",
"gpt-5.3": "gpt-5.3-codex", "gpt-5.3": "gpt-5.3-codex",
} }
...@@ -111,7 +111,7 @@ func TestNormalizeOpenAIModelForUpstream(t *testing.T) { ...@@ -111,7 +111,7 @@ func TestNormalizeOpenAIModelForUpstream(t *testing.T) {
name: "oauth keeps codex normalization behavior", name: "oauth keeps codex normalization behavior",
account: &Account{Type: AccountTypeOAuth}, account: &Account{Type: AccountTypeOAuth},
model: "gemini-3-flash-preview", model: "gemini-3-flash-preview",
want: "gpt-5.1", want: "gpt-5.4",
}, },
{ {
name: "apikey preserves custom compatible model", name: "apikey preserves custom compatible model",
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log/slog"
"strconv" "strconv"
"strings" "strings"
...@@ -11,9 +12,22 @@ import ( ...@@ -11,9 +12,22 @@ import (
"github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
"github.com/Wei-Shaw/sub2api/internal/payment" "github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/payment/provider"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
) )
// validateProviderConfig runs the provider's constructor to surface config-level
// errors at save time (e.g. wxpay missing certSerial), instead of only failing
// when an order is created. Returns the structured ApplicationError from the
// constructor so the frontend i18n layer can localize it.
//
// Only validates enabled instances — a disabled instance may be a half-filled
// draft the admin will complete later.
func (s *PaymentConfigService) validateProviderConfig(providerKey string, config map[string]string) error {
_, err := provider.CreateProvider(providerKey, "_validate_", config)
return err
}
// --- Provider Instance CRUD --- // --- Provider Instance CRUD ---
func (s *PaymentConfigService) ListProviderInstances(ctx context.Context) ([]*dbent.PaymentProviderInstance, error) { func (s *PaymentConfigService) ListProviderInstances(ctx context.Context) ([]*dbent.PaymentProviderInstance, error) {
...@@ -47,11 +61,10 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte ...@@ -47,11 +61,10 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte
resp := ProviderInstanceResponse{ resp := ProviderInstanceResponse{
ID: int64(inst.ID), ProviderKey: inst.ProviderKey, Name: inst.Name, ID: int64(inst.ID), ProviderKey: inst.ProviderKey, Name: inst.Name,
SupportedTypes: splitTypes(inst.SupportedTypes), Limits: inst.Limits, SupportedTypes: splitTypes(inst.SupportedTypes), Limits: inst.Limits,
Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled, Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled, AllowUserRefund: inst.AllowUserRefund,
AllowUserRefund: inst.AllowUserRefund,
SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode, SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode,
} }
resp.Config, err = s.decryptAndMaskConfig(inst.Config) resp.Config, err = s.decryptAndMaskConfig(inst.ProviderKey, inst.Config)
if err != nil { if err != nil {
return nil, fmt.Errorf("decrypt config for instance %d: %w", inst.ID, err) return nil, fmt.Errorf("decrypt config for instance %d: %w", inst.ID, err)
} }
...@@ -60,8 +73,26 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte ...@@ -60,8 +73,26 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte
return result, nil return result, nil
} }
func (s *PaymentConfigService) decryptAndMaskConfig(encrypted string) (map[string]string, error) { // decryptAndMaskConfig returns the stored config with sensitive fields omitted.
return s.decryptConfig(encrypted) // Admin UIs display masked placeholders for these; the raw values never leave
// the server. Callers that need the full config (e.g. payment runtime) must
// use decryptConfig directly.
func (s *PaymentConfigService) decryptAndMaskConfig(providerKey, encrypted string) (map[string]string, error) {
cfg, err := s.decryptConfig(encrypted)
if err != nil {
return nil, err
}
if cfg == nil {
return nil, nil
}
masked := make(map[string]string, len(cfg))
for k, v := range cfg {
if isSensitiveProviderConfigField(providerKey, k) {
continue
}
masked[k] = v
}
return masked, nil
} }
// pendingOrderStatuses are order statuses considered "in progress". // pendingOrderStatuses are order statuses considered "in progress".
...@@ -71,16 +102,27 @@ var pendingOrderStatuses = []string{ ...@@ -71,16 +102,27 @@ var pendingOrderStatuses = []string{
payment.OrderStatusRecharging, payment.OrderStatusRecharging,
} }
var sensitiveConfigPatterns = []string{"key", "pkey", "secret", "private", "password"} // providerSensitiveConfigFields is the authoritative list of config keys that
// are treated as secrets per provider. Must stay in sync with the frontend
// definition at frontend/src/components/payment/providerConfig.ts
// (PROVIDER_CONFIG_FIELDS, fields with sensitive: true).
//
// Key matching is case-insensitive. Non-listed keys (e.g. appId, notifyUrl,
// stripe publishableKey) are returned in plaintext by the admin GET API.
var providerSensitiveConfigFields = map[string]map[string]struct{}{
payment.TypeEasyPay: {"pkey": {}},
payment.TypeAlipay: {"privatekey": {}, "publickey": {}, "alipaypublickey": {}},
payment.TypeWxpay: {"privatekey": {}, "apiv3key": {}, "publickey": {}},
payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}},
}
func isSensitiveConfigField(fieldName string) bool { func isSensitiveProviderConfigField(providerKey, fieldName string) bool {
lower := strings.ToLower(fieldName) fields, ok := providerSensitiveConfigFields[providerKey]
for _, p := range sensitiveConfigPatterns { if !ok {
if strings.Contains(lower, p) {
return true
}
}
return false return false
}
_, found := fields[strings.ToLower(fieldName)]
return found
} }
func (s *PaymentConfigService) countPendingOrders(ctx context.Context, providerInstanceID int64) (int, error) { func (s *PaymentConfigService) countPendingOrders(ctx context.Context, providerInstanceID int64) (int, error) {
...@@ -111,6 +153,11 @@ func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req C ...@@ -111,6 +153,11 @@ func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req C
if err := s.validateVisibleMethodEnablementConflicts(ctx, 0, req.ProviderKey, typesStr, req.Enabled); err != nil { if err := s.validateVisibleMethodEnablementConflicts(ctx, 0, req.ProviderKey, typesStr, req.Enabled); err != nil {
return nil, err return nil, err
} }
if req.Enabled {
if err := s.validateProviderConfig(req.ProviderKey, req.Config); err != nil {
return nil, err
}
}
enc, err := s.encryptConfig(req.Config) enc, err := s.encryptConfig(req.Config)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -141,7 +188,7 @@ func validateProviderRequest(providerKey, name, supportedTypes string) error { ...@@ -141,7 +188,7 @@ func validateProviderRequest(providerKey, name, supportedTypes string) error {
func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id int64, req UpdateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) { func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id int64, req UpdateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) {
current, err := s.entClient.PaymentProviderInstance.Get(ctx, id) current, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("load provider instance: %w", err)
} }
nextEnabled := current.Enabled nextEnabled := current.Enabled
if req.Enabled != nil { if req.Enabled != nil {
...@@ -156,8 +203,8 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in ...@@ -156,8 +203,8 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
} }
if req.Config != nil { if req.Config != nil {
hasSensitive := false hasSensitive := false
for k := range req.Config { for k, v := range req.Config {
if isSensitiveConfigField(k) && req.Config[k] != "" { if v != "" && isSensitiveProviderConfigField(current.ProviderKey, k) {
hasSensitive = true hasSensitive = true
break break
} }
...@@ -183,16 +230,38 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in ...@@ -183,16 +230,38 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
WithMetadata(map[string]string{"count": strconv.Itoa(count)}) WithMetadata(map[string]string{"count": strconv.Itoa(count)})
} }
} }
u := s.entClient.PaymentProviderInstance.UpdateOneID(id) // Validate merged config when the instance will end up enabled.
if req.Name != nil { // This surfaces provider-level errors (e.g. wxpay missing certSerial) at save time,
u.SetName(*req.Name) // so admins see them in the dialog instead of only when an order is created.
finalEnabled := current.Enabled
if req.Enabled != nil {
finalEnabled = *req.Enabled
} }
var mergedConfig map[string]string
if req.Config != nil { if req.Config != nil {
merged, err := s.mergeConfig(ctx, id, req.Config) mergedConfig, err = s.mergeConfig(ctx, id, req.Config)
if err != nil {
return nil, err
}
}
if finalEnabled {
configToValidate := mergedConfig
if configToValidate == nil {
configToValidate, err = s.decryptConfig(current.Config)
if err != nil { if err != nil {
return nil, fmt.Errorf("decrypt existing config: %w", err)
}
}
if err := s.validateProviderConfig(current.ProviderKey, configToValidate); err != nil {
return nil, err return nil, err
} }
enc, err := s.encryptConfig(merged) }
u := s.entClient.PaymentProviderInstance.UpdateOneID(id)
if req.Name != nil {
u.SetName(*req.Name)
}
if mergedConfig != nil {
enc, err := s.encryptConfig(mergedConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -293,27 +362,48 @@ func (s *PaymentConfigService) mergeConfig(ctx context.Context, id int64, newCon ...@@ -293,27 +362,48 @@ func (s *PaymentConfigService) mergeConfig(ctx context.Context, id int64, newCon
return nil, fmt.Errorf("decrypt existing config for instance %d: %w", id, err) return nil, fmt.Errorf("decrypt existing config for instance %d: %w", id, err)
} }
if existing == nil { if existing == nil {
return newConfig, nil existing = map[string]string{}
} }
for k, v := range newConfig { for k, v := range newConfig {
// Preserve existing secrets when the client submits an empty value
// (admin UI omits the value to indicate "leave unchanged").
if v == "" && isSensitiveProviderConfigField(inst.ProviderKey, k) {
continue
}
existing[k] = v existing[k] = v
} }
return existing, nil return existing, nil
} }
func (s *PaymentConfigService) decryptConfig(encrypted string) (map[string]string, error) { // decryptConfig parses a stored provider config.
if encrypted == "" { // New records are plaintext JSON; legacy records are AES-256-GCM ciphertext
// ("iv:authTag:ciphertext"). Values that cannot be parsed as either — including
// legacy ciphertext with no/invalid TOTP_ENCRYPTION_KEY — are treated as empty,
// letting the admin re-enter the config via the UI to complete the migration.
//
// TODO(deprecated-legacy-ciphertext): The AES fallback branch is a transitional
// shim for pre-plaintext records. Remove it (and the encryptionKey field) after
// a few releases once all live deployments have re-saved their provider configs.
func (s *PaymentConfigService) decryptConfig(stored string) (map[string]string, error) {
if stored == "" {
return nil, nil return nil, nil
} }
decrypted, err := payment.Decrypt(encrypted, s.encryptionKey) var cfg map[string]string
if err != nil { if err := json.Unmarshal([]byte(stored), &cfg); err == nil {
return nil, fmt.Errorf("decrypt config: %w", err) return cfg, nil
} }
var raw map[string]string // Deprecated: legacy AES-256-GCM ciphertext fallback — scheduled for removal.
if err := json.Unmarshal([]byte(decrypted), &raw); err != nil { if len(s.encryptionKey) == payment.AES256KeySize {
return nil, fmt.Errorf("unmarshal decrypted config: %w", err) //nolint:staticcheck // SA1019: intentional legacy fallback, scheduled for removal
if plaintext, err := payment.Decrypt(stored, s.encryptionKey); err == nil {
if err := json.Unmarshal([]byte(plaintext), &cfg); err == nil {
return cfg, nil
} }
return raw, nil }
}
slog.Warn("payment provider config unreadable, treating as empty for re-entry",
"stored_len", len(stored))
return nil, nil
} }
func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id int64) error { func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id int64) error {
...@@ -328,14 +418,13 @@ func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id in ...@@ -328,14 +418,13 @@ func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id in
return s.entClient.PaymentProviderInstance.DeleteOneID(id).Exec(ctx) return s.entClient.PaymentProviderInstance.DeleteOneID(id).Exec(ctx)
} }
// encryptConfig serialises a provider config for storage.
// New records are written as plaintext JSON; the historical AES-GCM wrapping
// has been dropped but decryptConfig still accepts old ciphertext during migration.
func (s *PaymentConfigService) encryptConfig(cfg map[string]string) (string, error) { func (s *PaymentConfigService) encryptConfig(cfg map[string]string) (string, error) {
data, err := json.Marshal(cfg) data, err := json.Marshal(cfg)
if err != nil { if err != nil {
return "", fmt.Errorf("marshal config: %w", err) return "", fmt.Errorf("marshal config: %w", err)
} }
enc, err := payment.Encrypt(string(data), s.encryptionKey) return string(data), nil
if err != nil {
return "", fmt.Errorf("encrypt config: %w", err)
}
return enc, nil
} }
...@@ -99,41 +99,52 @@ func TestValidateProviderRequest(t *testing.T) { ...@@ -99,41 +99,52 @@ func TestValidateProviderRequest(t *testing.T) {
} }
} }
func TestIsSensitiveConfigField(t *testing.T) { func TestIsSensitiveProviderConfigField(t *testing.T) {
t.Parallel() t.Parallel()
tests := []struct { tests := []struct {
providerKey string
field string field string
wantSen bool wantSen bool
}{ }{
// Sensitive fields (contain key/secret/private/password/pkey patterns) // Stripe: publishableKey is public, only secretKey/webhookSecret are secrets
{"secretKey", true}, {"stripe", "secretKey", true},
{"apiSecret", true}, {"stripe", "webhookSecret", true},
{"pkey", true}, {"stripe", "SecretKey", true}, // case-insensitive
{"privateKey", true}, {"stripe", "publishableKey", false},
{"apiPassword", true}, {"stripe", "appId", false},
{"appKey", true},
{"SECRET_TOKEN", true}, // Alipay
{"PrivateData", true}, {"alipay", "privateKey", true},
{"PASSWORD", true}, {"alipay", "publicKey", true},
{"mySecretValue", true}, {"alipay", "alipayPublicKey", true},
{"alipay", "appId", false},
// Non-sensitive fields {"alipay", "notifyUrl", false},
{"appId", false},
{"mchId", false}, // Wxpay
{"apiBase", false}, {"wxpay", "privateKey", true},
{"endpoint", false}, {"wxpay", "apiV3Key", true},
{"merchantNo", false}, {"wxpay", "publicKey", true},
{"paymentMode", false}, {"wxpay", "publicKeyId", false},
{"notifyUrl", false}, {"wxpay", "certSerial", false},
{"wxpay", "mchId", false},
// EasyPay
{"easypay", "pkey", true},
{"easypay", "pid", false},
{"easypay", "apiBase", false},
// Unknown provider: never sensitive
{"unknown", "secretKey", false},
} }
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.field, func(t *testing.T) { tc := tc
t.Run(tc.providerKey+"/"+tc.field, func(t *testing.T) {
t.Parallel() t.Parallel()
got := isSensitiveConfigField(tc.field) got := isSensitiveProviderConfigField(tc.providerKey, tc.field)
assert.Equal(t, tc.wantSen, got, "isSensitiveConfigField(%q)", tc.field) assert.Equal(t, tc.wantSen, got, "isSensitiveProviderConfigField(%q, %q)", tc.providerKey, tc.field)
}) })
} }
} }
......
...@@ -2,6 +2,7 @@ package service ...@@ -2,6 +2,7 @@ package service
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"log/slog" "log/slog"
"math" "math"
...@@ -201,7 +202,7 @@ func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, us ...@@ -201,7 +202,7 @@ func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, us
return fmt.Errorf("count pending orders: %w", err) return fmt.Errorf("count pending orders: %w", err)
} }
if c >= max { if c >= max {
return infraerrors.TooManyRequests("TOO_MANY_PENDING", fmt.Sprintf("too many pending orders (max %d)", max)). return infraerrors.TooManyRequests("TOO_MANY_PENDING", "too_many_pending").
WithMetadata(map[string]string{"max": strconv.Itoa(max)}) WithMetadata(map[string]string{"max": strconv.Itoa(max)})
} }
return nil return nil
...@@ -284,7 +285,8 @@ func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, user ...@@ -284,7 +285,8 @@ func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, user
used += o.Amount used += o.Amount
} }
if used+amount > limit { if used+amount > limit {
return infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", fmt.Sprintf("daily recharge limit reached, remaining: %.2f", math.Max(0, limit-used))) return infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily_limit_exceeded").
WithMetadata(map[string]string{"remaining": fmt.Sprintf("%.2f", math.Max(0, limit-used))})
} }
return nil return nil
} }
...@@ -296,10 +298,11 @@ func (s *PaymentService) selectCreateOrderInstance(ctx context.Context, req Crea ...@@ -296,10 +298,11 @@ func (s *PaymentService) selectCreateOrderInstance(ctx context.Context, req Crea
} }
sel, err := s.loadBalancer.SelectInstance(selectCtx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount) sel, err := s.loadBalancer.SelectInstance(selectCtx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount)
if err != nil { if err != nil {
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment method (%s) is not configured", req.PaymentType)) return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", "method_not_configured").
WithMetadata(map[string]string{"payment_type": req.PaymentType})
} }
if sel == nil { if sel == nil {
return nil, infraerrors.TooManyRequests("NO_AVAILABLE_INSTANCE", "no available payment instance") return nil, infraerrors.TooManyRequests("NO_AVAILABLE_INSTANCE", "no_available_instance")
} }
return sel, nil return sel, nil
} }
...@@ -342,7 +345,18 @@ func (s *PaymentService) usesOfficialWxpayVisibleMethod(ctx context.Context) boo ...@@ -342,7 +345,18 @@ func (s *PaymentService) usesOfficialWxpayVisibleMethod(ctx context.Context) boo
func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.PaymentOrder, req CreateOrderRequest, cfg *PaymentConfig, limitAmount float64, payAmountStr string, payAmount float64, plan *dbent.SubscriptionPlan, sel *payment.InstanceSelection) (*CreateOrderResponse, error) { func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.PaymentOrder, req CreateOrderRequest, cfg *PaymentConfig, limitAmount float64, payAmountStr string, payAmount float64, plan *dbent.SubscriptionPlan, sel *payment.InstanceSelection) (*CreateOrderResponse, error) {
prov, err := provider.CreateProvider(sel.ProviderKey, sel.InstanceID, sel.Config) prov, err := provider.CreateProvider(sel.ProviderKey, sel.InstanceID, sel.Config)
if err != nil { if err != nil {
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", "payment method is temporarily unavailable") slog.Error("[PaymentService] CreateProvider failed", "provider", sel.ProviderKey, "instance", sel.InstanceID, "error", err)
// If the provider returned a structured ApplicationError (e.g. WXPAY_CONFIG_MISSING_KEY),
// pass it through with provider context added to metadata. Otherwise wrap as PAYMENT_PROVIDER_MISCONFIGURED.
if appErr := new(infraerrors.ApplicationError); errors.As(err, &appErr) {
md := map[string]string{"provider": sel.ProviderKey, "instance_id": sel.InstanceID}
for k, v := range appErr.Metadata {
md[k] = v
}
return nil, appErr.WithMetadata(md)
}
return nil, infraerrors.ServiceUnavailable("PAYMENT_PROVIDER_MISCONFIGURED", "provider_misconfigured").
WithMetadata(map[string]string{"provider": sel.ProviderKey, "instance_id": sel.InstanceID})
} }
subject := s.buildPaymentSubject(plan, limitAmount, cfg) subject := s.buildPaymentSubject(plan, limitAmount, cfg)
outTradeNo := order.OutTradeNo outTradeNo := order.OutTradeNo
...@@ -380,6 +394,9 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen ...@@ -380,6 +394,9 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
pr, err := prov.CreatePayment(ctx, providerReq) pr, err := prov.CreatePayment(ctx, providerReq)
if err != nil { if err != nil {
slog.Error("[PaymentService] CreatePayment failed", "provider", sel.ProviderKey, "instance", sel.InstanceID, "error", err) slog.Error("[PaymentService] CreatePayment failed", "provider", sel.ProviderKey, "instance", sel.InstanceID, "error", err)
if appErr := new(infraerrors.ApplicationError); errors.As(err, &appErr) {
return nil, appErr
}
return nil, classifyCreatePaymentError(req, sel.ProviderKey, err) return nil, classifyCreatePaymentError(req, sel.ProviderKey, err)
} }
_, err = s.entClient.PaymentOrder.UpdateOneID(order.ID). _, err = s.entClient.PaymentOrder.UpdateOneID(order.ID).
......
...@@ -15,20 +15,8 @@ import ( ...@@ -15,20 +15,8 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// TestShouldClearStickySession 测试粘性会话清理判断逻辑。 // TestShouldClearStickySession tests sticky session clearing via IsSchedulable() delegation
// 验证在以下情况下是否正确判断需要清理粘性会话: // plus model-level rate limiting.
// - nil 账号:不清理(返回 false)
// - 状态为错误或禁用:清理
// - 不可调度:清理
// - 临时不可调度且未过期:清理
// - 临时不可调度已过期:不清理
// - 正常可调度状态:不清理
// - 模型限流(任意时长):清理
//
// TestShouldClearStickySession tests the sticky session clearing logic.
// Verifies correct behavior for various account states including:
// nil account, error/disabled status, unschedulable, temporary unschedulable,
// and model rate limiting scenarios.
func TestShouldClearStickySession(t *testing.T) { func TestShouldClearStickySession(t *testing.T) {
now := time.Now() now := time.Now()
future := now.Add(1 * time.Hour) future := now.Add(1 * time.Hour)
...@@ -101,6 +89,56 @@ func TestShouldClearStickySession(t *testing.T) { ...@@ -101,6 +89,56 @@ func TestShouldClearStickySession(t *testing.T) {
requestedModel: "claude-opus-4", // 请求不同模型 requestedModel: "claude-opus-4", // 请求不同模型
want: false, // 不同模型不受影响 want: false, // 不同模型不受影响
}, },
{
name: "apikey quota exceeded",
account: &Account{
Status: StatusActive,
Schedulable: true,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"quota_daily_limit": 10.0,
"quota_daily_used": 10.0,
"quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339),
},
},
requestedModel: "",
want: true,
},
{
name: "oauth quota exceeded not cleared",
account: &Account{
Status: StatusActive,
Schedulable: true,
Type: AccountTypeOAuth,
Extra: map[string]any{
"quota_daily_limit": 10.0,
"quota_daily_used": 10.0,
"quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339),
},
},
requestedModel: "",
want: false,
},
{
name: "overloaded account",
account: &Account{
Status: StatusActive,
Schedulable: true,
OverloadUntil: &future,
},
requestedModel: "",
want: true,
},
{
name: "account-level rate limited",
account: &Account{
Status: StatusActive,
Schedulable: true,
RateLimitResetAt: &future,
},
requestedModel: "",
want: true,
},
} }
for _, tt := range tests { for _, tt := range tests {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment