Commit 2555951b authored by erio's avatar erio
Browse files

feat(channel): 渠道管理全链路集成 — 模型映射、定价、限制、用量统计

- 渠道模型映射:支持精确匹配和通配符映射,按平台隔离
- 渠道模型定价:支持 token/按次/图片三种计费模式,区间分层定价
- 模型限制:渠道可限制仅允许定价列表中的模型
- 计费模型来源:支持 requested/upstream 两种计费模型选择
- 用量统计:usage_logs 新增 channel_id/model_mapping_chain/billing_tier/billing_mode 字段
- Dashboard 支持 model_source 维度(requested/upstream/mapping)查看模型统计
- 全部 gateway handler 统一接入 ResolveChannelMappingAndRestrict
- 修复测试:同步 SoraGenerationRepository 接口、SQL INSERT 参数、scan 字段
parent 669bff78
//go:build unit
package admin
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// helpers
// ---------------------------------------------------------------------------
// float64Ptr returns a pointer to a fresh copy of v; used to populate
// optional *float64 price fields in test fixtures.
func float64Ptr(v float64) *float64 {
	p := v
	return &p
}
// intPtr returns a pointer to a fresh copy of v; used to populate
// optional *int fields (e.g. interval token bounds) in test fixtures.
func intPtr(v int) *int {
	p := v
	return &p
}
// ---------------------------------------------------------------------------
// 1. channelToResponse
// ---------------------------------------------------------------------------
func TestChannelToResponse_NilInput(t *testing.T) {
require.Nil(t, channelToResponse(nil))
}
// TestChannelToResponse_FullChannel verifies that every populated field of a
// channel — scalars, timestamps, group IDs, model mapping, and a fully
// populated pricing entry — is carried through to the API response unchanged.
func TestChannelToResponse_FullChannel(t *testing.T) {
	createdAt := time.Date(2025, 6, 1, 12, 0, 0, 0, time.UTC)
	updatedAt := createdAt.Add(time.Hour)

	channel := &service.Channel{
		ID:                 42,
		Name:               "test-channel",
		Description:        "desc",
		Status:             "active",
		BillingModelSource: "upstream",
		RestrictModels:     true,
		CreatedAt:          createdAt,
		UpdatedAt:          updatedAt,
		GroupIDs:           []int64{1, 2, 3},
		ModelPricing: []service.ChannelModelPricing{{
			ID:              10,
			Platform:        "openai",
			Models:          []string{"gpt-4"},
			BillingMode:     service.BillingModeToken,
			InputPrice:      float64Ptr(0.01),
			OutputPrice:     float64Ptr(0.03),
			CacheWritePrice: float64Ptr(0.005),
			CacheReadPrice:  float64Ptr(0.002),
			PerRequestPrice: float64Ptr(0.5),
		}},
		ModelMapping: map[string]map[string]string{
			"anthropic": {"claude-3-haiku": "claude-haiku-3"},
		},
	}

	got := channelToResponse(channel)
	require.NotNil(t, got)

	// scalar fields and RFC3339-formatted timestamps
	require.Equal(t, int64(42), got.ID)
	require.Equal(t, "test-channel", got.Name)
	require.Equal(t, "desc", got.Description)
	require.Equal(t, "active", got.Status)
	require.Equal(t, "upstream", got.BillingModelSource)
	require.True(t, got.RestrictModels)
	require.Equal(t, []int64{1, 2, 3}, got.GroupIDs)
	require.Equal(t, "2025-06-01T12:00:00Z", got.CreatedAt)
	require.Equal(t, "2025-06-01T13:00:00Z", got.UpdatedAt)

	// model mapping is preserved per platform
	require.Len(t, got.ModelMapping, 1)
	require.Equal(t, "claude-haiku-3", got.ModelMapping["anthropic"]["claude-3-haiku"])

	// pricing entry is copied field by field
	require.Len(t, got.ModelPricing, 1)
	pricing := got.ModelPricing[0]
	require.Equal(t, int64(10), pricing.ID)
	require.Equal(t, "openai", pricing.Platform)
	require.Equal(t, []string{"gpt-4"}, pricing.Models)
	require.Equal(t, "token", pricing.BillingMode)
	require.Equal(t, float64Ptr(0.01), pricing.InputPrice)
	require.Equal(t, float64Ptr(0.03), pricing.OutputPrice)
	require.Equal(t, float64Ptr(0.005), pricing.CacheWritePrice)
	require.Equal(t, float64Ptr(0.002), pricing.CacheReadPrice)
	require.Equal(t, float64Ptr(0.5), pricing.PerRequestPrice)
	require.Empty(t, pricing.Intervals)
}
func TestChannelToResponse_EmptyDefaults(t *testing.T) {
now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
ch := &service.Channel{
ID: 1,
Name: "ch",
BillingModelSource: "",
CreatedAt: now,
UpdatedAt: now,
GroupIDs: nil,
ModelMapping: nil,
ModelPricing: []service.ChannelModelPricing{
{
Platform: "",
BillingMode: "",
Models: []string{"m1"},
},
},
}
resp := channelToResponse(ch)
require.Equal(t, "requested", resp.BillingModelSource)
require.NotNil(t, resp.GroupIDs)
require.Empty(t, resp.GroupIDs)
require.NotNil(t, resp.ModelMapping)
require.Empty(t, resp.ModelMapping)
require.Len(t, resp.ModelPricing, 1)
require.Equal(t, "anthropic", resp.ModelPricing[0].Platform)
require.Equal(t, "token", resp.ModelPricing[0].BillingMode)
}
func TestChannelToResponse_NilModels(t *testing.T) {
now := time.Now()
ch := &service.Channel{
ID: 1,
Name: "ch",
CreatedAt: now,
UpdatedAt: now,
ModelPricing: []service.ChannelModelPricing{
{
Models: nil,
},
},
}
resp := channelToResponse(ch)
require.Len(t, resp.ModelPricing, 1)
require.NotNil(t, resp.ModelPricing[0].Models)
require.Empty(t, resp.ModelPricing[0].Models)
}
// TestChannelToResponse_WithIntervals verifies that tiered pricing intervals
// — one bounded and fully populated, one open-ended and sparse — are carried
// through to the response in order with all fields intact.
func TestChannelToResponse_WithIntervals(t *testing.T) {
	ts := time.Now()
	channel := &service.Channel{
		ID:        1,
		Name:      "ch",
		CreatedAt: ts,
		UpdatedAt: ts,
		ModelPricing: []service.ChannelModelPricing{{
			Models:      []string{"m1"},
			BillingMode: service.BillingModePerRequest,
			Intervals: []service.PricingInterval{
				{
					ID:              100,
					MinTokens:       0,
					MaxTokens:       intPtr(1000),
					TierLabel:       "1K",
					InputPrice:      float64Ptr(0.01),
					OutputPrice:     float64Ptr(0.02),
					CacheWritePrice: float64Ptr(0.003),
					CacheReadPrice:  float64Ptr(0.001),
					PerRequestPrice: float64Ptr(0.1),
					SortOrder:       1,
				},
				{
					ID:        101,
					MinTokens: 1000,
					MaxTokens: nil, // open-ended top tier
					TierLabel: "unlimited",
					SortOrder: 2,
				},
			},
		}},
	}

	got := channelToResponse(channel)
	require.Len(t, got.ModelPricing, 1)
	intervals := got.ModelPricing[0].Intervals
	require.Len(t, intervals, 2)

	// first tier: bounded and fully populated
	first := intervals[0]
	require.Equal(t, int64(100), first.ID)
	require.Equal(t, 0, first.MinTokens)
	require.Equal(t, intPtr(1000), first.MaxTokens)
	require.Equal(t, "1K", first.TierLabel)
	require.Equal(t, float64Ptr(0.01), first.InputPrice)
	require.Equal(t, float64Ptr(0.02), first.OutputPrice)
	require.Equal(t, float64Ptr(0.003), first.CacheWritePrice)
	require.Equal(t, float64Ptr(0.001), first.CacheReadPrice)
	require.Equal(t, float64Ptr(0.1), first.PerRequestPrice)
	require.Equal(t, 1, first.SortOrder)

	// second tier: open-ended with no prices set
	second := intervals[1]
	require.Equal(t, int64(101), second.ID)
	require.Equal(t, 1000, second.MinTokens)
	require.Nil(t, second.MaxTokens)
	require.Equal(t, "unlimited", second.TierLabel)
	require.Equal(t, 2, second.SortOrder)
}
func TestChannelToResponse_MultipleEntries(t *testing.T) {
now := time.Now()
ch := &service.Channel{
ID: 1,
Name: "multi",
CreatedAt: now,
UpdatedAt: now,
ModelPricing: []service.ChannelModelPricing{
{
ID: 1,
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: service.BillingModeToken,
InputPrice: float64Ptr(0.003),
OutputPrice: float64Ptr(0.015),
},
{
ID: 2,
Platform: "openai",
Models: []string{"gpt-4", "gpt-4o"},
BillingMode: service.BillingModePerRequest,
PerRequestPrice: float64Ptr(1.0),
},
{
ID: 3,
Platform: "gemini",
Models: []string{"gemini-2.5-pro"},
BillingMode: service.BillingModeImage,
ImageOutputPrice: float64Ptr(0.05),
PerRequestPrice: float64Ptr(0.2),
},
},
}
resp := channelToResponse(ch)
require.Len(t, resp.ModelPricing, 3)
require.Equal(t, int64(1), resp.ModelPricing[0].ID)
require.Equal(t, "anthropic", resp.ModelPricing[0].Platform)
require.Equal(t, []string{"claude-sonnet-4"}, resp.ModelPricing[0].Models)
require.Equal(t, "token", resp.ModelPricing[0].BillingMode)
require.Equal(t, int64(2), resp.ModelPricing[1].ID)
require.Equal(t, "openai", resp.ModelPricing[1].Platform)
require.Equal(t, []string{"gpt-4", "gpt-4o"}, resp.ModelPricing[1].Models)
require.Equal(t, "per_request", resp.ModelPricing[1].BillingMode)
require.Equal(t, int64(3), resp.ModelPricing[2].ID)
require.Equal(t, "gemini", resp.ModelPricing[2].Platform)
require.Equal(t, []string{"gemini-2.5-pro"}, resp.ModelPricing[2].Models)
require.Equal(t, "image", resp.ModelPricing[2].BillingMode)
require.Equal(t, float64Ptr(0.05), resp.ModelPricing[2].ImageOutputPrice)
}
// ---------------------------------------------------------------------------
// 2. pricingRequestToService
// ---------------------------------------------------------------------------
func TestPricingRequestToService_Defaults(t *testing.T) {
tests := []struct {
name string
req channelModelPricingRequest
wantField string // which default field to check
wantValue string
}{
{
name: "empty billing mode defaults to token",
req: channelModelPricingRequest{
Models: []string{"m1"},
BillingMode: "",
},
wantField: "BillingMode",
wantValue: string(service.BillingModeToken),
},
{
name: "empty platform defaults to anthropic",
req: channelModelPricingRequest{
Models: []string{"m1"},
Platform: "",
},
wantField: "Platform",
wantValue: "anthropic",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := pricingRequestToService([]channelModelPricingRequest{tt.req})
require.Len(t, result, 1)
switch tt.wantField {
case "BillingMode":
require.Equal(t, service.BillingMode(tt.wantValue), result[0].BillingMode)
case "Platform":
require.Equal(t, tt.wantValue, result[0].Platform)
}
})
}
}
func TestPricingRequestToService_WithAllFields(t *testing.T) {
reqs := []channelModelPricingRequest{
{
Platform: "openai",
Models: []string{"gpt-4", "gpt-4o"},
BillingMode: "per_request",
InputPrice: float64Ptr(0.01),
OutputPrice: float64Ptr(0.03),
CacheWritePrice: float64Ptr(0.005),
CacheReadPrice: float64Ptr(0.002),
ImageOutputPrice: float64Ptr(0.04),
PerRequestPrice: float64Ptr(0.5),
},
}
result := pricingRequestToService(reqs)
require.Len(t, result, 1)
r := result[0]
require.Equal(t, "openai", r.Platform)
require.Equal(t, []string{"gpt-4", "gpt-4o"}, r.Models)
require.Equal(t, service.BillingModePerRequest, r.BillingMode)
require.Equal(t, float64Ptr(0.01), r.InputPrice)
require.Equal(t, float64Ptr(0.03), r.OutputPrice)
require.Equal(t, float64Ptr(0.005), r.CacheWritePrice)
require.Equal(t, float64Ptr(0.002), r.CacheReadPrice)
require.Equal(t, float64Ptr(0.04), r.ImageOutputPrice)
require.Equal(t, float64Ptr(0.5), r.PerRequestPrice)
}
func TestPricingRequestToService_WithIntervals(t *testing.T) {
reqs := []channelModelPricingRequest{
{
Models: []string{"m1"},
BillingMode: "per_request",
Intervals: []pricingIntervalRequest{
{
MinTokens: 0,
MaxTokens: intPtr(2000),
TierLabel: "small",
InputPrice: float64Ptr(0.01),
OutputPrice: float64Ptr(0.02),
CacheWritePrice: float64Ptr(0.003),
CacheReadPrice: float64Ptr(0.001),
PerRequestPrice: float64Ptr(0.1),
SortOrder: 1,
},
{
MinTokens: 2000,
MaxTokens: nil,
TierLabel: "large",
SortOrder: 2,
},
},
},
}
result := pricingRequestToService(reqs)
require.Len(t, result, 1)
require.Len(t, result[0].Intervals, 2)
iv0 := result[0].Intervals[0]
require.Equal(t, 0, iv0.MinTokens)
require.Equal(t, intPtr(2000), iv0.MaxTokens)
require.Equal(t, "small", iv0.TierLabel)
require.Equal(t, float64Ptr(0.01), iv0.InputPrice)
require.Equal(t, float64Ptr(0.02), iv0.OutputPrice)
require.Equal(t, float64Ptr(0.003), iv0.CacheWritePrice)
require.Equal(t, float64Ptr(0.001), iv0.CacheReadPrice)
require.Equal(t, float64Ptr(0.1), iv0.PerRequestPrice)
require.Equal(t, 1, iv0.SortOrder)
iv1 := result[0].Intervals[1]
require.Equal(t, 2000, iv1.MinTokens)
require.Nil(t, iv1.MaxTokens)
require.Equal(t, "large", iv1.TierLabel)
require.Equal(t, 2, iv1.SortOrder)
}
// TestPricingRequestToService_EmptySlice verifies that an empty request slice
// converts to an empty — but non-nil — result slice.
func TestPricingRequestToService_EmptySlice(t *testing.T) {
	out := pricingRequestToService([]channelModelPricingRequest{})
	require.NotNil(t, out)
	require.Empty(t, out)
}
func TestPricingRequestToService_NilPriceFields(t *testing.T) {
reqs := []channelModelPricingRequest{
{
Models: []string{"m1"},
BillingMode: "token",
// all price fields are nil by default
},
}
result := pricingRequestToService(reqs)
require.Len(t, result, 1)
r := result[0]
require.Nil(t, r.InputPrice)
require.Nil(t, r.OutputPrice)
require.Nil(t, r.CacheWritePrice)
require.Nil(t, r.CacheReadPrice)
require.Nil(t, r.ImageOutputPrice)
require.Nil(t, r.PerRequestPrice)
}
// ---------------------------------------------------------------------------
// 3. validatePricingBillingMode
// ---------------------------------------------------------------------------
// TestValidatePricingBillingMode verifies the billing-mode validation rules:
// token mode needs no extra configuration, while per_request and image modes
// require either a per-request price or at least one pricing interval.
// Case names are kept identical to the previous version of this test.
func TestValidatePricingBillingMode(t *testing.T) {
	cases := []struct {
		name    string
		pricing []service.ChannelModelPricing
		wantErr bool
	}{
		{
			name: "token mode - valid",
			pricing: []service.ChannelModelPricing{
				{BillingMode: service.BillingModeToken},
			},
		},
		{
			name: "per_request with price - valid",
			pricing: []service.ChannelModelPricing{{
				BillingMode:     service.BillingModePerRequest,
				PerRequestPrice: float64Ptr(0.5),
			}},
		},
		{
			name: "per_request with intervals - valid",
			pricing: []service.ChannelModelPricing{{
				BillingMode: service.BillingModePerRequest,
				Intervals: []service.PricingInterval{
					{MinTokens: 0, MaxTokens: intPtr(1000), PerRequestPrice: float64Ptr(0.1)},
				},
			}},
		},
		{
			name: "per_request no price no intervals - invalid",
			pricing: []service.ChannelModelPricing{
				{BillingMode: service.BillingModePerRequest},
			},
			wantErr: true,
		},
		{
			name: "image with price - valid",
			pricing: []service.ChannelModelPricing{{
				BillingMode:     service.BillingModeImage,
				PerRequestPrice: float64Ptr(0.2),
			}},
		},
		{
			name: "image no price no intervals - invalid",
			pricing: []service.ChannelModelPricing{
				{BillingMode: service.BillingModeImage},
			},
			wantErr: true,
		},
		{
			name:    "empty list - valid",
			pricing: []service.ChannelModelPricing{},
		},
		{
			name: "mixed modes with invalid image - invalid",
			pricing: []service.ChannelModelPricing{
				{
					BillingMode: service.BillingModeToken,
					InputPrice:  float64Ptr(0.01),
				},
				{
					BillingMode:     service.BillingModePerRequest,
					PerRequestPrice: float64Ptr(0.5),
				},
				{
					BillingMode: service.BillingModeImage,
				},
			},
			wantErr: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := validatePricingBillingMode(tc.pricing)
			if !tc.wantErr {
				require.NoError(t, err)
				return
			}
			// all invalid cases surface the same user-facing message
			require.Error(t, err)
			require.Contains(t, err.Error(), "Per-request price or intervals required")
		})
	}
}
...@@ -636,6 +636,40 @@ func (h *DashboardHandler) GetUserBreakdown(c *gin.Context) { ...@@ -636,6 +636,40 @@ func (h *DashboardHandler) GetUserBreakdown(c *gin.Context) {
dim.Endpoint = c.Query("endpoint") dim.Endpoint = c.Query("endpoint")
dim.EndpointType = c.DefaultQuery("endpoint_type", "inbound") dim.EndpointType = c.DefaultQuery("endpoint_type", "inbound")
// Additional filter conditions
if v := c.Query("user_id"); v != "" {
if id, err := strconv.ParseInt(v, 10, 64); err == nil {
dim.UserID = id
}
}
if v := c.Query("api_key_id"); v != "" {
if id, err := strconv.ParseInt(v, 10, 64); err == nil {
dim.APIKeyID = id
}
}
if v := c.Query("account_id"); v != "" {
if id, err := strconv.ParseInt(v, 10, 64); err == nil {
dim.AccountID = id
}
}
if v := c.Query("request_type"); v != "" {
if rt, err := strconv.ParseInt(v, 10, 16); err == nil {
rtVal := int16(rt)
dim.RequestType = &rtVal
}
}
if v := c.Query("stream"); v != "" {
if s, err := strconv.ParseBool(v); err == nil {
dim.Stream = &s
}
}
if v := c.Query("billing_type"); v != "" {
if bt, err := strconv.ParseInt(v, 10, 8); err == nil {
btVal := int8(bt)
dim.BillingType = &btVal
}
}
limit := 50 limit := 50
if v := c.Query("limit"); v != "" { if v := c.Query("limit"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 && n <= 200 { if n, err := strconv.Atoi(v); err == nil && n > 0 && n <= 200 {
......
...@@ -485,10 +485,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -485,10 +485,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,
ForceCacheBilling: fs.ForceCacheBilling, ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelID: channelMapping.ChannelID, ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
OriginalModel: reqModel,
BillingModelSource: channelMapping.BillingModelSource,
ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel),
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
zap.String("component", "handler.gateway.messages"), zap.String("component", "handler.gateway.messages"),
...@@ -828,10 +825,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -828,10 +825,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,
ForceCacheBilling: fs.ForceCacheBilling, ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelID: channelMapping.ChannelID, ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
OriginalModel: reqModel,
BillingModelSource: channelMapping.BillingModelSource,
ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel),
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
zap.String("component", "handler.gateway.messages"), zap.String("component", "handler.gateway.messages"),
......
...@@ -266,10 +266,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -266,10 +266,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelID: channelMapping.ChannelID, ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
OriginalModel: reqModel,
BillingModelSource: channelMapping.BillingModelSource,
ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel),
}); err != nil { }); err != nil {
reqLog.Error("gateway.cc.record_usage_failed", reqLog.Error("gateway.cc.record_usage_failed",
zap.Int64("account_id", account.ID), zap.Int64("account_id", account.ID),
......
...@@ -272,10 +272,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) { ...@@ -272,10 +272,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelID: channelMapping.ChannelID, ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
OriginalModel: reqModel,
BillingModelSource: channelMapping.BillingModelSource,
ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel),
}); err != nil { }); err != nil {
reqLog.Error("gateway.responses.record_usage_failed", reqLog.Error("gateway.responses.record_usage_failed",
zap.Int64("account_id", account.ID), zap.Int64("account_id", account.ID),
......
...@@ -534,10 +534,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -534,10 +534,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
LongContextMultiplier: 2.0, // 超出部分双倍计费 LongContextMultiplier: 2.0, // 超出部分双倍计费
ForceCacheBilling: fs.ForceCacheBilling, ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelID: channelMapping.ChannelID, ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
OriginalModel: reqModel,
BillingModelSource: channelMapping.BillingModelSource,
ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel),
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
zap.String("component", "handler.gemini_v1beta.models"), zap.String("component", "handler.gemini_v1beta.models"),
......
...@@ -278,10 +278,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { ...@@ -278,10 +278,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
UserAgent: userAgent, UserAgent: userAgent,
IPAddress: clientIP, IPAddress: clientIP,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelID: channelMapping.ChannelID, ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
OriginalModel: reqModel,
BillingModelSource: channelMapping.BillingModelSource,
ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel),
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
zap.String("component", "handler.openai_gateway.chat_completions"), zap.String("component", "handler.openai_gateway.chat_completions"),
......
...@@ -391,10 +391,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -391,10 +391,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelID: channelMapping.ChannelID, ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
OriginalModel: reqModel,
BillingModelSource: channelMapping.BillingModelSource,
ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel),
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
zap.String("component", "handler.openai_gateway.responses"), zap.String("component", "handler.openai_gateway.responses"),
...@@ -787,10 +784,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { ...@@ -787,10 +784,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash, RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelID: channelMappingMsg.ChannelID, ChannelUsageFields: channelMappingMsg.ToUsageFields(reqModel, result.UpstreamModel),
OriginalModel: reqModel,
BillingModelSource: channelMappingMsg.BillingModelSource,
ModelMappingChain: channelMappingMsg.BuildModelMappingChain(reqModel, result.UpstreamModel),
}); err != nil { }); err != nil {
logger.L().With( logger.L().With(
zap.String("component", "handler.openai_gateway.messages"), zap.String("component", "handler.openai_gateway.messages"),
...@@ -1298,10 +1292,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { ...@@ -1298,10 +1292,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
IPAddress: clientIP, IPAddress: clientIP,
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage), RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
APIKeyService: h.apiKeyService, APIKeyService: h.apiKeyService,
ChannelID: channelMappingWS.ChannelID, ChannelUsageFields: channelMappingWS.ToUsageFields(reqModel, result.UpstreamModel),
OriginalModel: reqModel,
BillingModelSource: channelMappingWS.BillingModelSource,
ModelMappingChain: channelMappingWS.BuildModelMappingChain(reqModel, result.UpstreamModel),
}); err != nil { }); err != nil {
reqLog.Error("openai.websocket_record_usage_failed", reqLog.Error("openai.websocket_record_usage_failed",
zap.Int64("account_id", account.ID), zap.Int64("account_id", account.ID),
......
...@@ -125,6 +125,13 @@ func (r *stubSoraGenRepo) CountByUserAndStatus(_ context.Context, _ int64, _ []s ...@@ -125,6 +125,13 @@ func (r *stubSoraGenRepo) CountByUserAndStatus(_ context.Context, _ int64, _ []s
return r.countValue, nil return r.countValue, nil
} }
func (r *stubSoraGenRepo) CountByStorageType(_ context.Context, _ string, _ []string) (int64, error) {
if r.countErr != nil {
return 0, r.countErr
}
return r.countValue, nil
}
// ==================== 辅助函数 ==================== // ==================== 辅助函数 ====================
func newTestSoraClientHandler(repo *stubSoraGenRepo) *SoraClientHandler { func newTestSoraClientHandler(repo *stubSoraGenRepo) *SoraClientHandler {
...@@ -1657,8 +1664,8 @@ func TestStoreMediaWithDegradation_S3SuccessSingleURL(t *testing.T) { ...@@ -1657,8 +1664,8 @@ func TestStoreMediaWithDegradation_S3SuccessSingleURL(t *testing.T) {
fakeS3 := newFakeS3Server("ok") fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close() defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL) objectStorage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage} h := &SoraClientHandler{objectStorage: objectStorage}
storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation( storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(
context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
...@@ -1679,8 +1686,8 @@ func TestStoreMediaWithDegradation_S3SuccessMultiURL(t *testing.T) { ...@@ -1679,8 +1686,8 @@ func TestStoreMediaWithDegradation_S3SuccessMultiURL(t *testing.T) {
fakeS3 := newFakeS3Server("ok") fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close() defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL) objectStorage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage} h := &SoraClientHandler{objectStorage: objectStorage}
urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"} urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"}
storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation( storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(
...@@ -1704,8 +1711,8 @@ func TestStoreMediaWithDegradation_S3DownloadFails(t *testing.T) { ...@@ -1704,8 +1711,8 @@ func TestStoreMediaWithDegradation_S3DownloadFails(t *testing.T) {
})) }))
defer badSource.Close() defer badSource.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL) objectStorage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage} h := &SoraClientHandler{objectStorage: objectStorage}
_, _, storageType, _, _ := h.storeMediaWithDegradation( _, _, storageType, _, _ := h.storeMediaWithDegradation(
context.Background(), 1, "video", badSource.URL+"/missing.mp4", nil, context.Background(), 1, "video", badSource.URL+"/missing.mp4", nil,
...@@ -1719,8 +1726,8 @@ func TestStoreMediaWithDegradation_S3FailsSingleURL(t *testing.T) { ...@@ -1719,8 +1726,8 @@ func TestStoreMediaWithDegradation_S3FailsSingleURL(t *testing.T) {
fakeS3 := newFakeS3Server("fail") fakeS3 := newFakeS3Server("fail")
defer fakeS3.Close() defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL) objectStorage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage} h := &SoraClientHandler{objectStorage: objectStorage}
_, _, storageType, s3Keys, _ := h.storeMediaWithDegradation( _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
...@@ -1736,8 +1743,8 @@ func TestStoreMediaWithDegradation_S3PartialFailureCleanup(t *testing.T) { ...@@ -1736,8 +1743,8 @@ func TestStoreMediaWithDegradation_S3PartialFailureCleanup(t *testing.T) {
fakeS3 := newFakeS3Server("fail-second") fakeS3 := newFakeS3Server("fail-second")
defer fakeS3.Close() defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL) objectStorage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage} h := &SoraClientHandler{objectStorage: objectStorage}
urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"} urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"}
_, _, storageType, s3Keys, _ := h.storeMediaWithDegradation( _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
...@@ -1808,7 +1815,7 @@ func TestStoreMediaWithDegradation_S3FailsFallbackToLocal(t *testing.T) { ...@@ -1808,7 +1815,7 @@ func TestStoreMediaWithDegradation_S3FailsFallbackToLocal(t *testing.T) {
fakeS3 := newFakeS3Server("fail") fakeS3 := newFakeS3Server("fail")
defer fakeS3.Close() defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL) objectStorage := newS3StorageForHandler(fakeS3.URL)
cfg := &config.Config{ cfg := &config.Config{
Sora: config.SoraConfig{ Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{ Storage: config.SoraStorageConfig{
...@@ -1821,7 +1828,7 @@ func TestStoreMediaWithDegradation_S3FailsFallbackToLocal(t *testing.T) { ...@@ -1821,7 +1828,7 @@ func TestStoreMediaWithDegradation_S3FailsFallbackToLocal(t *testing.T) {
} }
mediaStorage := service.NewSoraMediaStorage(cfg) mediaStorage := service.NewSoraMediaStorage(cfg)
h := &SoraClientHandler{ h := &SoraClientHandler{
s3Storage: s3Storage, objectStorage: objectStorage,
mediaStorage: mediaStorage, mediaStorage: mediaStorage,
} }
...@@ -1846,9 +1853,9 @@ func TestSaveToStorage_S3EnabledButUploadFails(t *testing.T) { ...@@ -1846,9 +1853,9 @@ func TestSaveToStorage_S3EnabledButUploadFails(t *testing.T) {
StorageType: "upstream", StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4", MediaURL: sourceServer.URL + "/v.mp4",
} }
s3Storage := newS3StorageForHandler(fakeS3.URL) objectStorage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil) genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} h := &SoraClientHandler{genService: genService, objectStorage: objectStorage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}} c.Params = gin.Params{{Key: "id", Value: "1"}}
...@@ -1872,9 +1879,9 @@ func TestSaveToStorage_UpstreamURLExpired(t *testing.T) { ...@@ -1872,9 +1879,9 @@ func TestSaveToStorage_UpstreamURLExpired(t *testing.T) {
StorageType: "upstream", StorageType: "upstream",
MediaURL: expiredServer.URL + "/v.mp4", MediaURL: expiredServer.URL + "/v.mp4",
} }
s3Storage := newS3StorageForHandler(fakeS3.URL) objectStorage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil) genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} h := &SoraClientHandler{genService: genService, objectStorage: objectStorage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}} c.Params = gin.Params{{Key: "id", Value: "1"}}
...@@ -1896,9 +1903,9 @@ func TestSaveToStorage_S3EnabledUploadSuccess(t *testing.T) { ...@@ -1896,9 +1903,9 @@ func TestSaveToStorage_S3EnabledUploadSuccess(t *testing.T) {
StorageType: "upstream", StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4", MediaURL: sourceServer.URL + "/v.mp4",
} }
s3Storage := newS3StorageForHandler(fakeS3.URL) objectStorage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil) genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} h := &SoraClientHandler{genService: genService, objectStorage: objectStorage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}} c.Params = gin.Params{{Key: "id", Value: "1"}}
...@@ -1906,7 +1913,7 @@ func TestSaveToStorage_S3EnabledUploadSuccess(t *testing.T) { ...@@ -1906,7 +1913,7 @@ func TestSaveToStorage_S3EnabledUploadSuccess(t *testing.T) {
require.Equal(t, http.StatusOK, rec.Code) require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec) resp := parseResponse(t, rec)
data := resp["data"].(map[string]any) data := resp["data"].(map[string]any)
require.Contains(t, data["message"], "S3") require.Contains(t, data["message"], "云存储")
require.NotEmpty(t, data["object_key"]) require.NotEmpty(t, data["object_key"])
// 验证记录已更新为 S3 存储 // 验证记录已更新为 S3 存储
require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType) require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
...@@ -1928,9 +1935,9 @@ func TestSaveToStorage_S3EnabledUploadSuccess_MultiMediaURLs(t *testing.T) { ...@@ -1928,9 +1935,9 @@ func TestSaveToStorage_S3EnabledUploadSuccess_MultiMediaURLs(t *testing.T) {
sourceServer.URL + "/v2.mp4", sourceServer.URL + "/v2.mp4",
}, },
} }
s3Storage := newS3StorageForHandler(fakeS3.URL) objectStorage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil) genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} h := &SoraClientHandler{genService: genService, objectStorage: objectStorage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}} c.Params = gin.Params{{Key: "id", Value: "1"}}
...@@ -1956,7 +1963,7 @@ func TestSaveToStorage_S3EnabledUploadSuccessWithQuota(t *testing.T) { ...@@ -1956,7 +1963,7 @@ func TestSaveToStorage_S3EnabledUploadSuccessWithQuota(t *testing.T) {
StorageType: "upstream", StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4", MediaURL: sourceServer.URL + "/v.mp4",
} }
s3Storage := newS3StorageForHandler(fakeS3.URL) objectStorage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil) genService := service.NewSoraGenerationService(repo, nil, nil)
userRepo := newStubUserRepoForHandler() userRepo := newStubUserRepoForHandler()
...@@ -1966,7 +1973,7 @@ func TestSaveToStorage_S3EnabledUploadSuccessWithQuota(t *testing.T) { ...@@ -1966,7 +1973,7 @@ func TestSaveToStorage_S3EnabledUploadSuccessWithQuota(t *testing.T) {
SoraStorageUsedBytes: 0, SoraStorageUsedBytes: 0,
} }
quotaService := service.NewSoraQuotaService(userRepo, nil, nil) quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} h := &SoraClientHandler{genService: genService, objectStorage: objectStorage, quotaService: quotaService}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}} c.Params = gin.Params{{Key: "id", Value: "1"}}
...@@ -1990,9 +1997,9 @@ func TestSaveToStorage_S3UploadSuccessMarkCompletedFails(t *testing.T) { ...@@ -1990,9 +1997,9 @@ func TestSaveToStorage_S3UploadSuccessMarkCompletedFails(t *testing.T) {
} }
// S3 上传成功后,MarkCompleted 会调用 repo.Update → 失败 // S3 上传成功后,MarkCompleted 会调用 repo.Update → 失败
repo.updateErr = fmt.Errorf("db error") repo.updateErr = fmt.Errorf("db error")
s3Storage := newS3StorageForHandler(fakeS3.URL) objectStorage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil) genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} h := &SoraClientHandler{genService: genService, objectStorage: objectStorage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}} c.Params = gin.Params{{Key: "id", Value: "1"}}
...@@ -2007,8 +2014,8 @@ func TestGetStorageStatus_S3EnabledNotHealthy(t *testing.T) { ...@@ -2007,8 +2014,8 @@ func TestGetStorageStatus_S3EnabledNotHealthy(t *testing.T) {
fakeS3 := newFakeS3Server("fail") fakeS3 := newFakeS3Server("fail")
defer fakeS3.Close() defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL) objectStorage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage} h := &SoraClientHandler{objectStorage: objectStorage}
c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
h.GetStorageStatus(c) h.GetStorageStatus(c)
...@@ -2023,8 +2030,8 @@ func TestGetStorageStatus_S3EnabledHealthy(t *testing.T) { ...@@ -2023,8 +2030,8 @@ func TestGetStorageStatus_S3EnabledHealthy(t *testing.T) {
fakeS3 := newFakeS3Server("ok") fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close() defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL) objectStorage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage} h := &SoraClientHandler{objectStorage: objectStorage}
c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
h.GetStorageStatus(c) h.GetStorageStatus(c)
...@@ -2453,7 +2460,7 @@ func TestProcessGeneration_FullSuccessWithS3(t *testing.T) { ...@@ -2453,7 +2460,7 @@ func TestProcessGeneration_FullSuccessWithS3(t *testing.T) {
}, },
} }
soraGatewayService := newMinimalSoraGatewayService(soraClient) soraGatewayService := newMinimalSoraGatewayService(soraClient)
s3Storage := newS3StorageForHandler(fakeS3.URL) objectStorage := newS3StorageForHandler(fakeS3.URL)
userRepo := newStubUserRepoForHandler() userRepo := newStubUserRepoForHandler()
userRepo.users[1] = &service.User{ userRepo.users[1] = &service.User{
...@@ -2465,7 +2472,7 @@ func TestProcessGeneration_FullSuccessWithS3(t *testing.T) { ...@@ -2465,7 +2472,7 @@ func TestProcessGeneration_FullSuccessWithS3(t *testing.T) {
genService: genService, genService: genService,
gatewayService: gatewayService, gatewayService: gatewayService,
soraGatewayService: soraGatewayService, soraGatewayService: soraGatewayService,
s3Storage: s3Storage, objectStorage: objectStorage,
quotaService: quotaService, quotaService: quotaService,
} }
...@@ -2515,7 +2522,7 @@ func TestProcessGeneration_MarkCompletedFails(t *testing.T) { ...@@ -2515,7 +2522,7 @@ func TestProcessGeneration_MarkCompletedFails(t *testing.T) {
// ==================== cleanupStoredMedia 直接测试 ==================== // ==================== cleanupStoredMedia 直接测试 ====================
func TestCleanupStoredMedia_S3Path(t *testing.T) { func TestCleanupStoredMedia_S3Path(t *testing.T) {
// S3 清理路径:s3Storage 为 nil 时不 panic // S3 清理路径:objectStorage 为 nil 时不 panic
h := &SoraClientHandler{} h := &SoraClientHandler{}
// 不应 panic // 不应 panic
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil) h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil)
...@@ -2962,7 +2969,7 @@ func TestSaveToStorage_QuotaExceeded(t *testing.T) { ...@@ -2962,7 +2969,7 @@ func TestSaveToStorage_QuotaExceeded(t *testing.T) {
StorageType: "upstream", StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4", MediaURL: sourceServer.URL + "/v.mp4",
} }
s3Storage := newS3StorageForHandler(fakeS3.URL) objectStorage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil) genService := service.NewSoraGenerationService(repo, nil, nil)
// 用户配额已满 // 用户配额已满
...@@ -2973,7 +2980,7 @@ func TestSaveToStorage_QuotaExceeded(t *testing.T) { ...@@ -2973,7 +2980,7 @@ func TestSaveToStorage_QuotaExceeded(t *testing.T) {
SoraStorageUsedBytes: 10, SoraStorageUsedBytes: 10,
} }
quotaService := service.NewSoraQuotaService(userRepo, nil, nil) quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} h := &SoraClientHandler{genService: genService, objectStorage: objectStorage, quotaService: quotaService}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}} c.Params = gin.Params{{Key: "id", Value: "1"}}
...@@ -2995,13 +3002,13 @@ func TestSaveToStorage_QuotaNonQuotaError(t *testing.T) { ...@@ -2995,13 +3002,13 @@ func TestSaveToStorage_QuotaNonQuotaError(t *testing.T) {
StorageType: "upstream", StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4", MediaURL: sourceServer.URL + "/v.mp4",
} }
s3Storage := newS3StorageForHandler(fakeS3.URL) objectStorage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil) genService := service.NewSoraGenerationService(repo, nil, nil)
// 用户不存在 → GetByID 失败 → AddUsage 返回普通 error // 用户不存在 → GetByID 失败 → AddUsage 返回普通 error
userRepo := newStubUserRepoForHandler() userRepo := newStubUserRepoForHandler()
quotaService := service.NewSoraQuotaService(userRepo, nil, nil) quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} h := &SoraClientHandler{genService: genService, objectStorage: objectStorage, quotaService: quotaService}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}} c.Params = gin.Params{{Key: "id", Value: "1"}}
...@@ -3022,9 +3029,9 @@ func TestSaveToStorage_EmptyMediaURLs(t *testing.T) { ...@@ -3022,9 +3029,9 @@ func TestSaveToStorage_EmptyMediaURLs(t *testing.T) {
MediaURL: "", MediaURL: "",
MediaURLs: []string{}, MediaURLs: []string{},
} }
s3Storage := newS3StorageForHandler(fakeS3.URL) objectStorage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil) genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} h := &SoraClientHandler{genService: genService, objectStorage: objectStorage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}} c.Params = gin.Params{{Key: "id", Value: "1"}}
...@@ -3049,9 +3056,9 @@ func TestSaveToStorage_MultiURL_SecondUploadFails(t *testing.T) { ...@@ -3049,9 +3056,9 @@ func TestSaveToStorage_MultiURL_SecondUploadFails(t *testing.T) {
MediaURL: sourceServer.URL + "/v1.mp4", MediaURL: sourceServer.URL + "/v1.mp4",
MediaURLs: []string{sourceServer.URL + "/v1.mp4", sourceServer.URL + "/v2.mp4"}, MediaURLs: []string{sourceServer.URL + "/v1.mp4", sourceServer.URL + "/v2.mp4"},
} }
s3Storage := newS3StorageForHandler(fakeS3.URL) objectStorage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil) genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} h := &SoraClientHandler{genService: genService, objectStorage: objectStorage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}} c.Params = gin.Params{{Key: "id", Value: "1"}}
...@@ -3074,7 +3081,7 @@ func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) { ...@@ -3074,7 +3081,7 @@ func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) {
MediaURL: sourceServer.URL + "/v.mp4", MediaURL: sourceServer.URL + "/v.mp4",
} }
repo.updateErr = fmt.Errorf("db error") repo.updateErr = fmt.Errorf("db error")
s3Storage := newS3StorageForHandler(fakeS3.URL) objectStorage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil) genService := service.NewSoraGenerationService(repo, nil, nil)
userRepo := newStubUserRepoForHandler() userRepo := newStubUserRepoForHandler()
...@@ -3084,7 +3091,7 @@ func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) { ...@@ -3084,7 +3091,7 @@ func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) {
SoraStorageUsedBytes: 0, SoraStorageUsedBytes: 0,
} }
quotaService := service.NewSoraQuotaService(userRepo, nil, nil) quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} h := &SoraClientHandler{genService: genService, objectStorage: objectStorage, quotaService: quotaService}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}} c.Params = gin.Params{{Key: "id", Value: "1"}}
...@@ -3097,8 +3104,8 @@ func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) { ...@@ -3097,8 +3104,8 @@ func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) {
func TestCleanupStoredMedia_WithS3Storage_ActualDelete(t *testing.T) { func TestCleanupStoredMedia_WithS3Storage_ActualDelete(t *testing.T) {
fakeS3 := newFakeS3Server("ok") fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close() defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL) objectStorage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage} h := &SoraClientHandler{objectStorage: objectStorage}
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1", "key2"}, nil) h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1", "key2"}, nil)
} }
...@@ -3106,8 +3113,8 @@ func TestCleanupStoredMedia_WithS3Storage_ActualDelete(t *testing.T) { ...@@ -3106,8 +3113,8 @@ func TestCleanupStoredMedia_WithS3Storage_ActualDelete(t *testing.T) {
func TestCleanupStoredMedia_S3DeleteFails_LogOnly(t *testing.T) { func TestCleanupStoredMedia_S3DeleteFails_LogOnly(t *testing.T) {
fakeS3 := newFakeS3Server("fail") fakeS3 := newFakeS3Server("fail")
defer fakeS3.Close() defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL) objectStorage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage} h := &SoraClientHandler{objectStorage: objectStorage}
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil) h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil)
} }
......
...@@ -30,6 +30,8 @@ import ( ...@@ -30,6 +30,8 @@ import (
) )
// SoraGatewayHandler handles Sora chat completions requests // SoraGatewayHandler handles Sora chat completions requests
//
// NOTE: Sora 平台计划后续移除,不集成渠道(Channel)功能。
type SoraGatewayHandler struct { type SoraGatewayHandler struct {
gatewayService *service.GatewayService gatewayService *service.GatewayService
soraGatewayService *service.SoraGatewayService soraGatewayService *service.SoraGatewayService
......
...@@ -175,6 +175,13 @@ type UserBreakdownDimension struct { ...@@ -175,6 +175,13 @@ type UserBreakdownDimension struct {
ModelType string // "requested", "upstream", or "mapping" ModelType string // "requested", "upstream", or "mapping"
Endpoint string // filter by endpoint value (non-empty to enable) Endpoint string // filter by endpoint value (non-empty to enable)
EndpointType string // "inbound", "upstream", or "path" EndpointType string // "inbound", "upstream", or "path"
// Additional filter conditions
UserID int64 // filter by user_id (>0 to enable)
APIKeyID int64 // filter by api_key_id (>0 to enable)
AccountID int64 // filter by account_id (>0 to enable)
RequestType *int16 // filter by request_type (non-nil to enable)
Stream *bool // filter by stream flag (non-nil to enable)
BillingType *int8 // filter by billing_type (non-nil to enable)
} }
// APIKeyUsageTrendPoint represents API key usage trend data point // APIKeyUsageTrendPoint represents API key usage trend data point
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"strings" "strings"
...@@ -274,7 +275,8 @@ func replaceModelPricingTx(ctx context.Context, exec dbExec, channelID int64, pr ...@@ -274,7 +275,8 @@ func replaceModelPricingTx(ctx context.Context, exec dbExec, channelID int64, pr
// isUniqueViolation 检查 pq 唯一约束违反错误 // isUniqueViolation 检查 pq 唯一约束违反错误
func isUniqueViolation(err error) bool { func isUniqueViolation(err error) bool {
if pqErr, ok := err.(*pq.Error); ok { var pqErr *pq.Error
if errors.As(err, &pqErr) && pqErr != nil {
return pqErr.Code == "23505" return pqErr.Code == "23505"
} }
return false return false
......
//go:build unit
package repository
import (
"encoding/json"
"errors"
"fmt"
"testing"
"github.com/lib/pq"
"github.com/stretchr/testify/require"
)
// --- marshalModelMapping ---
// TestMarshalModelMapping verifies that marshalModelMapping serializes nil
// and empty maps to the literal "{}" and that populated maps survive a JSON
// round trip unchanged.
func TestMarshalModelMapping(t *testing.T) {
	cases := []struct {
		name     string
		input    map[string]map[string]string
		wantJSON string // exact expected JSON; empty string means round-trip check
	}{
		{name: "empty map", input: map[string]map[string]string{}, wantJSON: "{}"},
		{name: "nil map", input: nil, wantJSON: "{}"},
		{
			name:  "populated map",
			input: map[string]map[string]string{"openai": {"gpt-4": "gpt-4-turbo"}},
		},
		{
			name: "nested values",
			input: map[string]map[string]string{
				"openai":    {"*": "gpt-5.4"},
				"anthropic": {"claude-old": "claude-new"},
			},
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := marshalModelMapping(tc.input)
			require.NoError(t, err)
			if tc.wantJSON != "" {
				require.Equal(t, []byte(tc.wantJSON), got)
				return
			}
			// No exact expectation: decode the output and compare with the input.
			var decoded map[string]map[string]string
			require.NoError(t, json.Unmarshal(got, &decoded))
			require.Equal(t, tc.input, decoded)
		})
	}
}
// --- unmarshalModelMapping ---
// TestUnmarshalModelMapping verifies that unmarshalModelMapping returns nil
// for missing or malformed input and decodes well-formed JSON objects into
// the expected nested map.
func TestUnmarshalModelMapping(t *testing.T) {
	cases := []struct {
		name    string
		input   []byte
		wantNil bool
		want    map[string]map[string]string
	}{
		{name: "nil data", input: nil, wantNil: true},
		{name: "empty data", input: []byte{}, wantNil: true},
		{name: "invalid JSON", input: []byte("not-json"), wantNil: true},
		{name: "type error - number", input: []byte("42"), wantNil: true},
		{name: "type error - array", input: []byte("[1,2,3]"), wantNil: true},
		{
			name:  "valid JSON",
			input: []byte(`{"openai":{"gpt-4":"gpt-4-turbo"},"anthropic":{"old":"new"}}`),
			want: map[string]map[string]string{
				"openai":    {"gpt-4": "gpt-4-turbo"},
				"anthropic": {"old": "new"},
			},
		},
		{name: "empty object", input: []byte("{}"), want: map[string]map[string]string{}},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := unmarshalModelMapping(tc.input)
			if tc.wantNil {
				require.Nil(t, got)
				return
			}
			require.NotNil(t, got)
			require.Equal(t, tc.want, got)
		})
	}
}
// --- escapeLike ---
func TestEscapeLike(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{
name: "no special chars",
input: "hello",
want: "hello",
},
{
name: "backslash",
input: `a\b`,
want: `a\\b`,
},
{
name: "percent",
input: "50%",
want: `50\%`,
},
{
name: "underscore",
input: "a_b",
want: `a\_b`,
},
{
name: "all special chars",
input: `a\b%c_d`,
want: `a\\b\%c\_d`,
},
{
name: "empty string",
input: "",
want: "",
},
{
name: "consecutive special chars",
input: "%_%",
want: `\%\_\%`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, escapeLike(tt.input))
})
}
}
// --- isUniqueViolation ---
// TestIsUniqueViolation verifies that isUniqueViolation recognizes pq error
// code 23505 (including when wrapped) and rejects other codes, generic
// errors, typed-nil *pq.Error values, and bare nil.
func TestIsUniqueViolation(t *testing.T) {
	// A typed nil pointer stored in an error interface is non-nil; the
	// function must still report false for it.
	var typedNil *pq.Error

	cases := []struct {
		name string
		err  error
		want bool
	}{
		{name: "unique violation code 23505", err: &pq.Error{Code: "23505"}, want: true},
		{name: "different pq error code", err: &pq.Error{Code: "23503"}, want: false},
		{name: "non-pq error", err: errors.New("some generic error"), want: false},
		{name: "typed nil pq.Error", err: typedNil, want: false},
		{name: "bare nil", err: nil, want: false},
		{name: "wrapped pq error with 23505", err: fmt.Errorf("wrapped: %w", &pq.Error{Code: "23505"}), want: true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			require.Equal(t, tc.want, isUniqueViolation(tc.err))
		})
	}
}
...@@ -3144,6 +3144,30 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim ...@@ -3144,6 +3144,30 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim
query += fmt.Sprintf(" AND %s = $%d", col, len(args)+1) query += fmt.Sprintf(" AND %s = $%d", col, len(args)+1)
args = append(args, dim.Endpoint) args = append(args, dim.Endpoint)
} }
if dim.UserID > 0 {
query += fmt.Sprintf(" AND ul.user_id = $%d", len(args)+1)
args = append(args, dim.UserID)
}
if dim.APIKeyID > 0 {
query += fmt.Sprintf(" AND ul.api_key_id = $%d", len(args)+1)
args = append(args, dim.APIKeyID)
}
if dim.AccountID > 0 {
query += fmt.Sprintf(" AND ul.account_id = $%d", len(args)+1)
args = append(args, dim.AccountID)
}
if dim.RequestType != nil {
query += fmt.Sprintf(" AND ul.request_type = $%d", len(args)+1)
args = append(args, *dim.RequestType)
}
if dim.Stream != nil {
query += fmt.Sprintf(" AND ul.stream = $%d", len(args)+1)
args = append(args, *dim.Stream)
}
if dim.BillingType != nil {
query += fmt.Sprintf(" AND ul.billing_type = $%d", len(args)+1)
args = append(args, *dim.BillingType)
}
query += " GROUP BY ul.user_id, u.email ORDER BY actual_cost DESC" query += " GROUP BY ul.user_id, u.email ORDER BY actual_cost DESC"
if limit > 0 { if limit > 0 {
......
...@@ -80,6 +80,10 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { ...@@ -80,6 +80,10 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
sqlmock.AnyArg(), // inbound_endpoint sqlmock.AnyArg(), // inbound_endpoint
sqlmock.AnyArg(), // upstream_endpoint sqlmock.AnyArg(), // upstream_endpoint
log.CacheTTLOverridden, log.CacheTTLOverridden,
sqlmock.AnyArg(), // channel_id
sqlmock.AnyArg(), // model_mapping_chain
sqlmock.AnyArg(), // billing_tier
sqlmock.AnyArg(), // billing_mode
createdAt, createdAt,
). ).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt)) WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt))
...@@ -153,6 +157,10 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { ...@@ -153,6 +157,10 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
sqlmock.AnyArg(), sqlmock.AnyArg(),
sqlmock.AnyArg(), sqlmock.AnyArg(),
log.CacheTTLOverridden, log.CacheTTLOverridden,
sqlmock.AnyArg(), // channel_id
sqlmock.AnyArg(), // model_mapping_chain
sqlmock.AnyArg(), // billing_tier
sqlmock.AnyArg(), // billing_mode
createdAt, createdAt,
). ).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt)) WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt))
...@@ -463,6 +471,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { ...@@ -463,6 +471,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{}, sql.NullString{},
sql.NullString{}, sql.NullString{},
false, false,
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
now, now,
}}) }})
require.NoError(t, err) require.NoError(t, err)
...@@ -506,6 +518,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { ...@@ -506,6 +518,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{}, sql.NullString{},
sql.NullString{}, sql.NullString{},
false, false,
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
now, now,
}}) }})
require.NoError(t, err) require.NoError(t, err)
...@@ -549,6 +565,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { ...@@ -549,6 +565,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{}, sql.NullString{},
sql.NullString{}, sql.NullString{},
false, false,
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
now, now,
}}) }})
require.NoError(t, err) require.NoError(t, err)
......
...@@ -175,3 +175,11 @@ func (c *Channel) Clone() *Channel { ...@@ -175,3 +175,11 @@ func (c *Channel) Clone() *Channel {
} }
return &cp return &cp
} }
// ChannelUsageFields carries the channel-related usage-record fields
// (embedded into each platform's RecordUsageInput).
type ChannelUsageFields struct {
	ChannelID          int64  // channel ID (0 = no channel)
	OriginalModel      string // model originally requested by the user (before channel mapping)
	BillingModelSource string // billing model source: "requested" / "upstream"
	ModelMappingChain  string // mapping-chain description, e.g. "a→b→c"
}
...@@ -4,7 +4,6 @@ import ( ...@@ -4,7 +4,6 @@ import (
"context" "context"
"fmt" "fmt"
"log/slog" "log/slog"
"sort"
"strings" "strings"
"sync/atomic" "sync/atomic"
"time" "time"
...@@ -118,6 +117,16 @@ func (r ChannelMappingResult) BuildModelMappingChain(reqModel, upstreamModel str ...@@ -118,6 +117,16 @@ func (r ChannelMappingResult) BuildModelMappingChain(reqModel, upstreamModel str
return reqModel + "→" + r.MappedModel return reqModel + "→" + r.MappedModel
} }
// ToUsageFields converts the channel-mapping result into usage-record fields.
// reqModel is the model the user originally requested; upstreamModel is the
// model actually sent upstream and is only used to build the mapping-chain
// description via BuildModelMappingChain.
func (r ChannelMappingResult) ToUsageFields(reqModel, upstreamModel string) ChannelUsageFields {
	return ChannelUsageFields{
		ChannelID:          r.ChannelID,
		OriginalModel:      reqModel,
		BillingModelSource: r.BillingModelSource,
		ModelMappingChain:  r.BuildModelMappingChain(reqModel, upstreamModel),
	}
}
const ( const (
channelCacheTTL = 60 * time.Second channelCacheTTL = 60 * time.Second
channelErrorTTL = 5 * time.Second // DB 错误时的短缓存 channelErrorTTL = 5 * time.Second // DB 错误时的短缓存
...@@ -266,19 +275,7 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) ...@@ -266,19 +275,7 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
} }
} }
// 通配符条目按前缀长度降序排列(最长前缀优先匹配) // 通配符条目保持配置顺序(最先匹配到优先)
for gpKey, entries := range cache.wildcardByGroupPlatform {
sort.Slice(entries, func(i, j int) bool {
return len(entries[i].prefix) > len(entries[j].prefix)
})
cache.wildcardByGroupPlatform[gpKey] = entries
}
for gpKey, entries := range cache.wildcardMappingByGP {
sort.Slice(entries, func(i, j int) bool {
return len(entries[i].prefix) > len(entries[j].prefix)
})
cache.wildcardMappingByGP[gpKey] = entries
}
s.cache.Store(cache) s.cache.Store(cache)
return cache, nil return cache, nil
...@@ -290,7 +287,7 @@ func (s *ChannelService) invalidateCache() { ...@@ -290,7 +287,7 @@ func (s *ChannelService) invalidateCache() {
s.cacheSF.Forget("channel_cache") s.cacheSF.Forget("channel_cache")
} }
// matchWildcard 在通配符定价中查找匹配项(最长前缀优先) // matchWildcard 在通配符定价中查找匹配项(最先匹配到优先)
func (c *channelCache) matchWildcard(groupID int64, platform, modelLower string) *ChannelModelPricing { func (c *channelCache) matchWildcard(groupID int64, platform, modelLower string) *ChannelModelPricing {
gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform} gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform}
wildcards := c.wildcardByGroupPlatform[gpKey] wildcards := c.wildcardByGroupPlatform[gpKey]
...@@ -302,7 +299,7 @@ func (c *channelCache) matchWildcard(groupID int64, platform, modelLower string) ...@@ -302,7 +299,7 @@ func (c *channelCache) matchWildcard(groupID int64, platform, modelLower string)
return nil return nil
} }
// matchWildcardMapping 在通配符映射中查找匹配项(最长前缀优先) // matchWildcardMapping 在通配符映射中查找匹配项(最先匹配到优先)
func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower string) string { func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower string) string {
gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform} gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform}
wildcards := c.wildcardMappingByGP[gpKey] wildcards := c.wildcardMappingByGP[gpKey]
...@@ -487,7 +484,10 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) ...@@ -487,7 +484,10 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
channel.BillingModelSource = BillingModelSourceRequested channel.BillingModelSource = BillingModelSourceRequested
} }
if err := validateNoDuplicateModels(channel.ModelPricing); err != nil { if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
return nil, err
}
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
return nil, err return nil, err
} }
...@@ -558,7 +558,10 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan ...@@ -558,7 +558,10 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
channel.BillingModelSource = input.BillingModelSource channel.BillingModelSource = input.BillingModelSource
} }
if err := validateNoDuplicateModels(channel.ModelPricing); err != nil { if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
return nil, err
}
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
return nil, err return nil, err
} }
...@@ -610,16 +613,79 @@ func (s *ChannelService) List(ctx context.Context, params pagination.PaginationP ...@@ -610,16 +613,79 @@ func (s *ChannelService) List(ctx context.Context, params pagination.PaginationP
return s.repo.List(ctx, params, status, search) return s.repo.List(ctx, params, status, search)
} }
// validateNoDuplicateModels 检查定价列表中是否有重复模型(同一平台下不允许重复) // modelEntry 表示一个模型模式条目(用于冲突检测)
func validateNoDuplicateModels(pricingList []ChannelModelPricing) error { type modelEntry struct {
seen := make(map[string]bool) pattern string // 原始模式(如 "claude-*" 或 "claude-opus-4")
prefix string // lowercase 前缀(通配符去掉 *,精确名保持原样)
wildcard bool
}
// conflictsBetween reports whether two model patterns overlap, i.e. whether
// some model name could be matched by both. Two exact names conflict only
// when equal; two wildcards conflict when either prefix contains the other;
// a wildcard conflicts with an exact name that starts with its prefix.
func conflictsBetween(a, b modelEntry) bool {
	if !a.wildcard && !b.wildcard {
		return a.prefix == b.prefix
	}
	if a.wildcard && b.wildcard {
		return strings.HasPrefix(a.prefix, b.prefix) || strings.HasPrefix(b.prefix, a.prefix)
	}
	// Exactly one side is a wildcard: they conflict when the exact name
	// falls inside the wildcard's prefix range.
	exact, wild := a, b
	if a.wildcard {
		exact, wild = b, a
	}
	return strings.HasPrefix(exact.prefix, wild.prefix)
}
// toModelEntry normalizes a model pattern for conflict detection: the
// pattern is lowercased, and a trailing "*" marks it as a wildcard whose
// prefix is everything before the "*". Exact names keep the full lowercased
// name as their prefix.
func toModelEntry(pattern string) modelEntry {
	normalized := strings.ToLower(pattern)
	entry := modelEntry{pattern: pattern, prefix: normalized}
	if strings.HasSuffix(normalized, "*") {
		entry.wildcard = true
		entry.prefix = normalized[:len(normalized)-1]
	}
	return entry
}
// validateNoConflictingModels 检查定价列表中是否有冲突模型模式(同一平台下)。
// 冲突包括:精确重复、通配符之间的前缀包含、通配符与精确名的前缀匹配。
func validateNoConflictingModels(pricingList []ChannelModelPricing) error {
byPlatform := make(map[string][]modelEntry)
for _, p := range pricingList { for _, p := range pricingList {
for _, model := range p.Models { for _, model := range p.Models {
key := p.Platform + ":" + strings.ToLower(model) byPlatform[p.Platform] = append(byPlatform[p.Platform], toModelEntry(model))
if seen[key] { }
return infraerrors.BadRequest("DUPLICATE_MODEL", fmt.Sprintf("model '%s' appears in multiple pricing entries for platform '%s'", model, p.Platform)) }
for platform, entries := range byPlatform {
if err := detectConflicts(entries, platform, "MODEL_PATTERN_CONFLICT", "model patterns"); err != nil {
return err
}
}
return nil
}
// validateNoConflictingMappings checks each platform's model mapping for
// conflicting source patterns (exact duplicates, or wildcard prefixes that
// overlap another source pattern).
func validateNoConflictingMappings(mapping map[string]map[string]string) error {
	for platform, sources := range mapping {
		patterns := make([]modelEntry, 0, len(sources))
		for src := range sources {
			patterns = append(patterns, toModelEntry(src))
		}
		err := detectConflicts(patterns, platform, "MAPPING_PATTERN_CONFLICT", "mapping source patterns")
		if err != nil {
			return err
		}
	}
	return nil
}
// detectConflicts 在一组 modelEntry 中检测冲突,返回带有 errCode 和 label 的错误
func detectConflicts(entries []modelEntry, platform, errCode, label string) error {
for i := 0; i < len(entries); i++ {
for j := i + 1; j < len(entries); j++ {
if conflictsBetween(entries[i], entries[j]) {
return infraerrors.BadRequest(errCode,
fmt.Sprintf("%s '%s' and '%s' conflict in platform '%s': overlapping match range",
label, entries[i].pattern, entries[j].pattern, platform))
} }
seen[key] = true
} }
} }
return nil return nil
......
This diff is collapsed.
...@@ -8,13 +8,10 @@ import ( ...@@ -8,13 +8,10 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func channelTestPtrFloat64(v float64) *float64 { return &v }
func channelTestPtrInt(v int) *int { return &v }
func TestGetModelPricing(t *testing.T) { func TestGetModelPricing(t *testing.T) {
ch := &Channel{ ch := &Channel{
ModelPricing: []ChannelModelPricing{ ModelPricing: []ChannelModelPricing{
{ID: 1, Models: []string{"claude-sonnet-4"}, BillingMode: BillingModeToken, InputPrice: channelTestPtrFloat64(3e-6)}, {ID: 1, Models: []string{"claude-sonnet-4"}, BillingMode: BillingModeToken, InputPrice: testPtrFloat64(3e-6)},
{ID: 3, Models: []string{"gpt-5.1"}, BillingMode: BillingModePerRequest}, {ID: 3, Models: []string{"gpt-5.1"}, BillingMode: BillingModePerRequest},
}, },
} }
...@@ -48,7 +45,7 @@ func TestGetModelPricing(t *testing.T) { ...@@ -48,7 +45,7 @@ func TestGetModelPricing(t *testing.T) {
func TestGetModelPricing_ReturnsCopy(t *testing.T) { func TestGetModelPricing_ReturnsCopy(t *testing.T) {
ch := &Channel{ ch := &Channel{
ModelPricing: []ChannelModelPricing{ ModelPricing: []ChannelModelPricing{
{ID: 1, Models: []string{"claude-sonnet-4"}, InputPrice: channelTestPtrFloat64(3e-6)}, {ID: 1, Models: []string{"claude-sonnet-4"}, InputPrice: testPtrFloat64(3e-6)},
}, },
} }
...@@ -73,8 +70,8 @@ func TestGetModelPricing_EmptyPricing(t *testing.T) { ...@@ -73,8 +70,8 @@ func TestGetModelPricing_EmptyPricing(t *testing.T) {
func TestGetIntervalForContext(t *testing.T) { func TestGetIntervalForContext(t *testing.T) {
p := &ChannelModelPricing{ p := &ChannelModelPricing{
Intervals: []PricingInterval{ Intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: channelTestPtrInt(128000), InputPrice: channelTestPtrFloat64(1e-6)}, {MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)},
{MinTokens: 128000, MaxTokens: nil, InputPrice: channelTestPtrFloat64(2e-6)}, {MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)},
}, },
} }
...@@ -84,12 +81,12 @@ func TestGetIntervalForContext(t *testing.T) { ...@@ -84,12 +81,12 @@ func TestGetIntervalForContext(t *testing.T) {
wantPrice *float64 wantPrice *float64
wantNil bool wantNil bool
}{ }{
{"first interval", 50000, channelTestPtrFloat64(1e-6), false}, {"first interval", 50000, testPtrFloat64(1e-6), false},
// (min, max] — 128000 在第一个区间的 max,包含,所以匹配第一个 // (min, max] — 128000 在第一个区间的 max,包含,所以匹配第一个
{"boundary: max of first (inclusive)", 128000, channelTestPtrFloat64(1e-6), false}, {"boundary: max of first (inclusive)", 128000, testPtrFloat64(1e-6), false},
// 128001 > 128000,匹配第二个区间 // 128001 > 128000,匹配第二个区间
{"boundary: just above first max", 128001, channelTestPtrFloat64(2e-6), false}, {"boundary: just above first max", 128001, testPtrFloat64(2e-6), false},
{"unbounded interval", 500000, channelTestPtrFloat64(2e-6), false}, {"unbounded interval", 500000, testPtrFloat64(2e-6), false},
// (0, max] — 0 不匹配任何区间(左开) // (0, max] — 0 不匹配任何区间(左开)
{"zero tokens: no match", 0, nil, true}, {"zero tokens: no match", 0, nil, true},
} }
...@@ -110,7 +107,7 @@ func TestGetIntervalForContext(t *testing.T) { ...@@ -110,7 +107,7 @@ func TestGetIntervalForContext(t *testing.T) {
func TestGetIntervalForContext_NoMatch(t *testing.T) { func TestGetIntervalForContext_NoMatch(t *testing.T) {
p := &ChannelModelPricing{ p := &ChannelModelPricing{
Intervals: []PricingInterval{ Intervals: []PricingInterval{
{MinTokens: 10000, MaxTokens: channelTestPtrInt(50000)}, {MinTokens: 10000, MaxTokens: testPtrInt(50000)},
}, },
} }
require.Nil(t, p.GetIntervalForContext(5000)) // 5000 <= 10000, not > min require.Nil(t, p.GetIntervalForContext(5000)) // 5000 <= 10000, not > min
...@@ -127,9 +124,9 @@ func TestGetIntervalForContext_Empty(t *testing.T) { ...@@ -127,9 +124,9 @@ func TestGetIntervalForContext_Empty(t *testing.T) {
func TestGetTierByLabel(t *testing.T) { func TestGetTierByLabel(t *testing.T) {
p := &ChannelModelPricing{ p := &ChannelModelPricing{
Intervals: []PricingInterval{ Intervals: []PricingInterval{
{TierLabel: "1K", PerRequestPrice: channelTestPtrFloat64(0.04)}, {TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
{TierLabel: "2K", PerRequestPrice: channelTestPtrFloat64(0.08)}, {TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)},
{TierLabel: "HD", PerRequestPrice: channelTestPtrFloat64(0.12)}, {TierLabel: "HD", PerRequestPrice: testPtrFloat64(0.12)},
}, },
} }
...@@ -171,7 +168,7 @@ func TestChannelClone(t *testing.T) { ...@@ -171,7 +168,7 @@ func TestChannelClone(t *testing.T) {
{ {
ID: 100, ID: 100,
Models: []string{"model-a"}, Models: []string{"model-a"},
InputPrice: channelTestPtrFloat64(5e-6), InputPrice: testPtrFloat64(5e-6),
}, },
}, },
} }
...@@ -211,3 +208,102 @@ func TestChannelModelPricingClone(t *testing.T) { ...@@ -211,3 +208,102 @@ func TestChannelModelPricingClone(t *testing.T) {
cloned.Intervals[0].TierLabel = "hacked" cloned.Intervals[0].TierLabel = "hacked"
require.Equal(t, "tier1", original.Intervals[0].TierLabel) require.Equal(t, "tier1", original.Intervals[0].TierLabel)
} }
// --- BillingMode.IsValid ---
// TestBillingModeIsValid verifies that the three known billing modes and the
// empty string are accepted, while any other value is rejected.
func TestBillingModeIsValid(t *testing.T) {
	cases := []struct {
		name string
		mode BillingMode
		want bool
	}{
		{"token", BillingModeToken, true},
		{"per_request", BillingModePerRequest, true},
		{"image", BillingModeImage, true},
		{"empty", BillingMode(""), true},
		{"unknown", BillingMode("unknown"), false},
		{"random", BillingMode("xyz"), false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			require.Equal(t, tc.want, tc.mode.IsValid())
		})
	}
}
// --- Channel.IsActive ---
// TestChannelIsActive verifies that only StatusActive makes a channel report
// itself as active; any other status (including empty) is inactive.
func TestChannelIsActive(t *testing.T) {
	cases := []struct {
		name   string
		status string
		want   bool
	}{
		{"active", StatusActive, true},
		{"disabled", "disabled", false},
		{"empty", "", false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			require.Equal(t, tc.want, (&Channel{Status: tc.status}).IsActive())
		})
	}
}
// --- ChannelModelPricing.Clone edge cases ---
// TestChannelModelPricingClone_EdgeCases covers Clone behavior for nil and
// empty slice fields: nil slices stay nil, empty slices stay empty (non-nil).
func TestChannelModelPricingClone_EdgeCases(t *testing.T) {
	t.Run("nil models", func(t *testing.T) {
		src := ChannelModelPricing{Models: nil}
		require.Nil(t, src.Clone().Models)
	})
	t.Run("nil intervals", func(t *testing.T) {
		src := ChannelModelPricing{Intervals: nil}
		require.Nil(t, src.Clone().Intervals)
	})
	t.Run("empty models", func(t *testing.T) {
		src := ChannelModelPricing{Models: []string{}}
		copied := src.Clone()
		require.NotNil(t, copied.Models)
		require.Empty(t, copied.Models)
	})
}
// --- Channel.Clone edge cases ---
// TestChannelClone_EdgeCases covers Clone behavior for nil map/slice fields
// and verifies the nested ModelMapping map is deep-copied, so mutating the
// clone never leaks back into the original channel.
func TestChannelClone_EdgeCases(t *testing.T) {
	t.Run("nil model mapping", func(t *testing.T) {
		src := &Channel{ID: 1, ModelMapping: nil}
		require.Nil(t, src.Clone().ModelMapping)
	})
	t.Run("nil model pricing", func(t *testing.T) {
		src := &Channel{ID: 1, ModelPricing: nil}
		require.Nil(t, src.Clone().ModelPricing)
	})
	t.Run("deep copy model mapping", func(t *testing.T) {
		src := &Channel{
			ID: 1,
			ModelMapping: map[string]map[string]string{
				"openai": {"gpt-4": "gpt-4-turbo"},
			},
		}
		copied := src.Clone()
		// Mutating the clone's nested map must not affect the source.
		copied.ModelMapping["openai"]["gpt-4"] = "hacked"
		require.Equal(t, "gpt-4-turbo", src.ModelMapping["openai"]["gpt-4"])
	})
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment