Commit 2555951b authored by erio's avatar erio
Browse files

feat(channel): 渠道管理全链路集成 — 模型映射、定价、限制、用量统计

- 渠道模型映射:支持精确匹配和通配符映射,按平台隔离
- 渠道模型定价:支持 token/按次/图片三种计费模式,区间分层定价
- 模型限制:渠道可限制仅允许定价列表中的模型
- 计费模型来源:支持 requested/upstream 两种计费模型选择
- 用量统计:usage_logs 新增 channel_id/model_mapping_chain/billing_tier/billing_mode 字段
- Dashboard 支持 model_source 维度(requested/upstream/mapping)查看模型统计
- 全部 gateway handler 统一接入 ResolveChannelMappingAndRestrict
- 修复测试:同步 SoraGenerationRepository 接口、SQL INSERT 参数、scan 字段
parent 669bff78
......@@ -26,37 +26,37 @@ func NewChannelHandler(channelService *service.ChannelService, billingService *s
// --- Request / Response types ---
type createChannelRequest struct {
Name string `json:"name" binding:"required,max=100"`
Description string `json:"description"`
GroupIDs []int64 `json:"group_ids"`
ModelPricing []channelModelPricingRequest `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"`
RestrictModels bool `json:"restrict_models"`
Name string `json:"name" binding:"required,max=100"`
Description string `json:"description"`
GroupIDs []int64 `json:"group_ids"`
ModelPricing []channelModelPricingRequest `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"`
RestrictModels bool `json:"restrict_models"`
}
type updateChannelRequest struct {
Name string `json:"name" binding:"omitempty,max=100"`
Description *string `json:"description"`
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
GroupIDs *[]int64 `json:"group_ids"`
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"`
RestrictModels *bool `json:"restrict_models"`
Name string `json:"name" binding:"omitempty,max=100"`
Description *string `json:"description"`
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
GroupIDs *[]int64 `json:"group_ids"`
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"`
RestrictModels *bool `json:"restrict_models"`
}
type channelModelPricingRequest struct {
Platform string `json:"platform" binding:"omitempty,max=50"`
Models []string `json:"models" binding:"required,min=1,max=100"`
BillingMode string `json:"billing_mode" binding:"omitempty,oneof=token per_request image"`
InputPrice *float64 `json:"input_price" binding:"omitempty,min=0"`
OutputPrice *float64 `json:"output_price" binding:"omitempty,min=0"`
CacheWritePrice *float64 `json:"cache_write_price" binding:"omitempty,min=0"`
CacheReadPrice *float64 `json:"cache_read_price" binding:"omitempty,min=0"`
ImageOutputPrice *float64 `json:"image_output_price" binding:"omitempty,min=0"`
PerRequestPrice *float64 `json:"per_request_price" binding:"omitempty,min=0"`
Intervals []pricingIntervalRequest `json:"intervals"`
Platform string `json:"platform" binding:"omitempty,max=50"`
Models []string `json:"models" binding:"required,min=1,max=100"`
BillingMode string `json:"billing_mode" binding:"omitempty,oneof=token per_request image"`
InputPrice *float64 `json:"input_price" binding:"omitempty,min=0"`
OutputPrice *float64 `json:"output_price" binding:"omitempty,min=0"`
CacheWritePrice *float64 `json:"cache_write_price" binding:"omitempty,min=0"`
CacheReadPrice *float64 `json:"cache_read_price" binding:"omitempty,min=0"`
ImageOutputPrice *float64 `json:"image_output_price" binding:"omitempty,min=0"`
PerRequestPrice *float64 `json:"per_request_price" binding:"omitempty,min=0"`
Intervals []pricingIntervalRequest `json:"intervals"`
}
type pricingIntervalRequest struct {
......@@ -72,31 +72,31 @@ type pricingIntervalRequest struct {
}
type channelResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Status string `json:"status"`
BillingModelSource string `json:"billing_model_source"`
RestrictModels bool `json:"restrict_models"`
GroupIDs []int64 `json:"group_ids"`
ModelPricing []channelModelPricingResponse `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
ID int64 `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Status string `json:"status"`
BillingModelSource string `json:"billing_model_source"`
RestrictModels bool `json:"restrict_models"`
GroupIDs []int64 `json:"group_ids"`
ModelPricing []channelModelPricingResponse `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
type channelModelPricingResponse struct {
ID int64 `json:"id"`
Platform string `json:"platform"`
Models []string `json:"models"`
BillingMode string `json:"billing_mode"`
InputPrice *float64 `json:"input_price"`
OutputPrice *float64 `json:"output_price"`
CacheWritePrice *float64 `json:"cache_write_price"`
CacheReadPrice *float64 `json:"cache_read_price"`
ImageOutputPrice *float64 `json:"image_output_price"`
PerRequestPrice *float64 `json:"per_request_price"`
Intervals []pricingIntervalResponse `json:"intervals"`
ID int64 `json:"id"`
Platform string `json:"platform"`
Models []string `json:"models"`
BillingMode string `json:"billing_mode"`
InputPrice *float64 `json:"input_price"`
OutputPrice *float64 `json:"output_price"`
CacheWritePrice *float64 `json:"cache_write_price"`
CacheReadPrice *float64 `json:"cache_read_price"`
ImageOutputPrice *float64 `json:"image_output_price"`
PerRequestPrice *float64 `json:"per_request_price"`
Intervals []pricingIntervalResponse `json:"intervals"`
}
type pricingIntervalResponse struct {
......@@ -117,15 +117,15 @@ func channelToResponse(ch *service.Channel) *channelResponse {
return nil
}
resp := &channelResponse{
ID: ch.ID,
Name: ch.Name,
Description: ch.Description,
Status: ch.Status,
ID: ch.ID,
Name: ch.Name,
Description: ch.Description,
Status: ch.Status,
RestrictModels: ch.RestrictModels,
GroupIDs: ch.GroupIDs,
ModelMapping: ch.ModelMapping,
CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
UpdatedAt: ch.UpdatedAt.Format("2006-01-02T15:04:05Z"),
GroupIDs: ch.GroupIDs,
ModelMapping: ch.ModelMapping,
CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
UpdatedAt: ch.UpdatedAt.Format("2006-01-02T15:04:05Z"),
}
resp.BillingModelSource = ch.BillingModelSource
if resp.BillingModelSource == "" {
......@@ -298,9 +298,9 @@ func (h *ChannelHandler) Create(c *gin.Context) {
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
Name: req.Name,
Description: req.Description,
GroupIDs: req.GroupIDs,
ModelPricing: pricing,
ModelMapping: req.ModelMapping,
GroupIDs: req.GroupIDs,
ModelPricing: pricing,
ModelMapping: req.ModelMapping,
BillingModelSource: req.BillingModelSource,
RestrictModels: req.RestrictModels,
})
......@@ -331,8 +331,8 @@ func (h *ChannelHandler) Update(c *gin.Context) {
Name: req.Name,
Description: req.Description,
Status: req.Status,
GroupIDs: req.GroupIDs,
ModelMapping: req.ModelMapping,
GroupIDs: req.GroupIDs,
ModelMapping: req.ModelMapping,
BillingModelSource: req.BillingModelSource,
RestrictModels: req.RestrictModels,
}
......
//go:build unit
package admin
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// helpers
// ---------------------------------------------------------------------------
func float64Ptr(v float64) *float64 { return &v }
func intPtr(v int) *int { return &v }
// ---------------------------------------------------------------------------
// 1. channelToResponse
// ---------------------------------------------------------------------------
func TestChannelToResponse_NilInput(t *testing.T) {
require.Nil(t, channelToResponse(nil))
}
func TestChannelToResponse_FullChannel(t *testing.T) {
now := time.Date(2025, 6, 1, 12, 0, 0, 0, time.UTC)
ch := &service.Channel{
ID: 42,
Name: "test-channel",
Description: "desc",
Status: "active",
BillingModelSource: "upstream",
RestrictModels: true,
CreatedAt: now,
UpdatedAt: now.Add(time.Hour),
GroupIDs: []int64{1, 2, 3},
ModelPricing: []service.ChannelModelPricing{
{
ID: 10,
Platform: "openai",
Models: []string{"gpt-4"},
BillingMode: service.BillingModeToken,
InputPrice: float64Ptr(0.01),
OutputPrice: float64Ptr(0.03),
CacheWritePrice: float64Ptr(0.005),
CacheReadPrice: float64Ptr(0.002),
PerRequestPrice: float64Ptr(0.5),
},
},
ModelMapping: map[string]map[string]string{
"anthropic": {"claude-3-haiku": "claude-haiku-3"},
},
}
resp := channelToResponse(ch)
require.NotNil(t, resp)
require.Equal(t, int64(42), resp.ID)
require.Equal(t, "test-channel", resp.Name)
require.Equal(t, "desc", resp.Description)
require.Equal(t, "active", resp.Status)
require.Equal(t, "upstream", resp.BillingModelSource)
require.True(t, resp.RestrictModels)
require.Equal(t, []int64{1, 2, 3}, resp.GroupIDs)
require.Equal(t, "2025-06-01T12:00:00Z", resp.CreatedAt)
require.Equal(t, "2025-06-01T13:00:00Z", resp.UpdatedAt)
// model mapping
require.Len(t, resp.ModelMapping, 1)
require.Equal(t, "claude-haiku-3", resp.ModelMapping["anthropic"]["claude-3-haiku"])
// pricing
require.Len(t, resp.ModelPricing, 1)
p := resp.ModelPricing[0]
require.Equal(t, int64(10), p.ID)
require.Equal(t, "openai", p.Platform)
require.Equal(t, []string{"gpt-4"}, p.Models)
require.Equal(t, "token", p.BillingMode)
require.Equal(t, float64Ptr(0.01), p.InputPrice)
require.Equal(t, float64Ptr(0.03), p.OutputPrice)
require.Equal(t, float64Ptr(0.005), p.CacheWritePrice)
require.Equal(t, float64Ptr(0.002), p.CacheReadPrice)
require.Equal(t, float64Ptr(0.5), p.PerRequestPrice)
require.Empty(t, p.Intervals)
}
func TestChannelToResponse_EmptyDefaults(t *testing.T) {
now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
ch := &service.Channel{
ID: 1,
Name: "ch",
BillingModelSource: "",
CreatedAt: now,
UpdatedAt: now,
GroupIDs: nil,
ModelMapping: nil,
ModelPricing: []service.ChannelModelPricing{
{
Platform: "",
BillingMode: "",
Models: []string{"m1"},
},
},
}
resp := channelToResponse(ch)
require.Equal(t, "requested", resp.BillingModelSource)
require.NotNil(t, resp.GroupIDs)
require.Empty(t, resp.GroupIDs)
require.NotNil(t, resp.ModelMapping)
require.Empty(t, resp.ModelMapping)
require.Len(t, resp.ModelPricing, 1)
require.Equal(t, "anthropic", resp.ModelPricing[0].Platform)
require.Equal(t, "token", resp.ModelPricing[0].BillingMode)
}
func TestChannelToResponse_NilModels(t *testing.T) {
now := time.Now()
ch := &service.Channel{
ID: 1,
Name: "ch",
CreatedAt: now,
UpdatedAt: now,
ModelPricing: []service.ChannelModelPricing{
{
Models: nil,
},
},
}
resp := channelToResponse(ch)
require.Len(t, resp.ModelPricing, 1)
require.NotNil(t, resp.ModelPricing[0].Models)
require.Empty(t, resp.ModelPricing[0].Models)
}
func TestChannelToResponse_WithIntervals(t *testing.T) {
now := time.Now()
ch := &service.Channel{
ID: 1,
Name: "ch",
CreatedAt: now,
UpdatedAt: now,
ModelPricing: []service.ChannelModelPricing{
{
Models: []string{"m1"},
BillingMode: service.BillingModePerRequest,
Intervals: []service.PricingInterval{
{
ID: 100,
MinTokens: 0,
MaxTokens: intPtr(1000),
TierLabel: "1K",
InputPrice: float64Ptr(0.01),
OutputPrice: float64Ptr(0.02),
CacheWritePrice: float64Ptr(0.003),
CacheReadPrice: float64Ptr(0.001),
PerRequestPrice: float64Ptr(0.1),
SortOrder: 1,
},
{
ID: 101,
MinTokens: 1000,
MaxTokens: nil,
TierLabel: "unlimited",
SortOrder: 2,
},
},
},
},
}
resp := channelToResponse(ch)
require.Len(t, resp.ModelPricing, 1)
intervals := resp.ModelPricing[0].Intervals
require.Len(t, intervals, 2)
iv0 := intervals[0]
require.Equal(t, int64(100), iv0.ID)
require.Equal(t, 0, iv0.MinTokens)
require.Equal(t, intPtr(1000), iv0.MaxTokens)
require.Equal(t, "1K", iv0.TierLabel)
require.Equal(t, float64Ptr(0.01), iv0.InputPrice)
require.Equal(t, float64Ptr(0.02), iv0.OutputPrice)
require.Equal(t, float64Ptr(0.003), iv0.CacheWritePrice)
require.Equal(t, float64Ptr(0.001), iv0.CacheReadPrice)
require.Equal(t, float64Ptr(0.1), iv0.PerRequestPrice)
require.Equal(t, 1, iv0.SortOrder)
iv1 := intervals[1]
require.Equal(t, int64(101), iv1.ID)
require.Equal(t, 1000, iv1.MinTokens)
require.Nil(t, iv1.MaxTokens)
require.Equal(t, "unlimited", iv1.TierLabel)
require.Equal(t, 2, iv1.SortOrder)
}
func TestChannelToResponse_MultipleEntries(t *testing.T) {
now := time.Now()
ch := &service.Channel{
ID: 1,
Name: "multi",
CreatedAt: now,
UpdatedAt: now,
ModelPricing: []service.ChannelModelPricing{
{
ID: 1,
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: service.BillingModeToken,
InputPrice: float64Ptr(0.003),
OutputPrice: float64Ptr(0.015),
},
{
ID: 2,
Platform: "openai",
Models: []string{"gpt-4", "gpt-4o"},
BillingMode: service.BillingModePerRequest,
PerRequestPrice: float64Ptr(1.0),
},
{
ID: 3,
Platform: "gemini",
Models: []string{"gemini-2.5-pro"},
BillingMode: service.BillingModeImage,
ImageOutputPrice: float64Ptr(0.05),
PerRequestPrice: float64Ptr(0.2),
},
},
}
resp := channelToResponse(ch)
require.Len(t, resp.ModelPricing, 3)
require.Equal(t, int64(1), resp.ModelPricing[0].ID)
require.Equal(t, "anthropic", resp.ModelPricing[0].Platform)
require.Equal(t, []string{"claude-sonnet-4"}, resp.ModelPricing[0].Models)
require.Equal(t, "token", resp.ModelPricing[0].BillingMode)
require.Equal(t, int64(2), resp.ModelPricing[1].ID)
require.Equal(t, "openai", resp.ModelPricing[1].Platform)
require.Equal(t, []string{"gpt-4", "gpt-4o"}, resp.ModelPricing[1].Models)
require.Equal(t, "per_request", resp.ModelPricing[1].BillingMode)
require.Equal(t, int64(3), resp.ModelPricing[2].ID)
require.Equal(t, "gemini", resp.ModelPricing[2].Platform)
require.Equal(t, []string{"gemini-2.5-pro"}, resp.ModelPricing[2].Models)
require.Equal(t, "image", resp.ModelPricing[2].BillingMode)
require.Equal(t, float64Ptr(0.05), resp.ModelPricing[2].ImageOutputPrice)
}
// ---------------------------------------------------------------------------
// 2. pricingRequestToService
// ---------------------------------------------------------------------------
func TestPricingRequestToService_Defaults(t *testing.T) {
tests := []struct {
name string
req channelModelPricingRequest
wantField string // which default field to check
wantValue string
}{
{
name: "empty billing mode defaults to token",
req: channelModelPricingRequest{
Models: []string{"m1"},
BillingMode: "",
},
wantField: "BillingMode",
wantValue: string(service.BillingModeToken),
},
{
name: "empty platform defaults to anthropic",
req: channelModelPricingRequest{
Models: []string{"m1"},
Platform: "",
},
wantField: "Platform",
wantValue: "anthropic",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := pricingRequestToService([]channelModelPricingRequest{tt.req})
require.Len(t, result, 1)
switch tt.wantField {
case "BillingMode":
require.Equal(t, service.BillingMode(tt.wantValue), result[0].BillingMode)
case "Platform":
require.Equal(t, tt.wantValue, result[0].Platform)
}
})
}
}
func TestPricingRequestToService_WithAllFields(t *testing.T) {
reqs := []channelModelPricingRequest{
{
Platform: "openai",
Models: []string{"gpt-4", "gpt-4o"},
BillingMode: "per_request",
InputPrice: float64Ptr(0.01),
OutputPrice: float64Ptr(0.03),
CacheWritePrice: float64Ptr(0.005),
CacheReadPrice: float64Ptr(0.002),
ImageOutputPrice: float64Ptr(0.04),
PerRequestPrice: float64Ptr(0.5),
},
}
result := pricingRequestToService(reqs)
require.Len(t, result, 1)
r := result[0]
require.Equal(t, "openai", r.Platform)
require.Equal(t, []string{"gpt-4", "gpt-4o"}, r.Models)
require.Equal(t, service.BillingModePerRequest, r.BillingMode)
require.Equal(t, float64Ptr(0.01), r.InputPrice)
require.Equal(t, float64Ptr(0.03), r.OutputPrice)
require.Equal(t, float64Ptr(0.005), r.CacheWritePrice)
require.Equal(t, float64Ptr(0.002), r.CacheReadPrice)
require.Equal(t, float64Ptr(0.04), r.ImageOutputPrice)
require.Equal(t, float64Ptr(0.5), r.PerRequestPrice)
}
func TestPricingRequestToService_WithIntervals(t *testing.T) {
reqs := []channelModelPricingRequest{
{
Models: []string{"m1"},
BillingMode: "per_request",
Intervals: []pricingIntervalRequest{
{
MinTokens: 0,
MaxTokens: intPtr(2000),
TierLabel: "small",
InputPrice: float64Ptr(0.01),
OutputPrice: float64Ptr(0.02),
CacheWritePrice: float64Ptr(0.003),
CacheReadPrice: float64Ptr(0.001),
PerRequestPrice: float64Ptr(0.1),
SortOrder: 1,
},
{
MinTokens: 2000,
MaxTokens: nil,
TierLabel: "large",
SortOrder: 2,
},
},
},
}
result := pricingRequestToService(reqs)
require.Len(t, result, 1)
require.Len(t, result[0].Intervals, 2)
iv0 := result[0].Intervals[0]
require.Equal(t, 0, iv0.MinTokens)
require.Equal(t, intPtr(2000), iv0.MaxTokens)
require.Equal(t, "small", iv0.TierLabel)
require.Equal(t, float64Ptr(0.01), iv0.InputPrice)
require.Equal(t, float64Ptr(0.02), iv0.OutputPrice)
require.Equal(t, float64Ptr(0.003), iv0.CacheWritePrice)
require.Equal(t, float64Ptr(0.001), iv0.CacheReadPrice)
require.Equal(t, float64Ptr(0.1), iv0.PerRequestPrice)
require.Equal(t, 1, iv0.SortOrder)
iv1 := result[0].Intervals[1]
require.Equal(t, 2000, iv1.MinTokens)
require.Nil(t, iv1.MaxTokens)
require.Equal(t, "large", iv1.TierLabel)
require.Equal(t, 2, iv1.SortOrder)
}
func TestPricingRequestToService_EmptySlice(t *testing.T) {
result := pricingRequestToService([]channelModelPricingRequest{})
require.NotNil(t, result)
require.Empty(t, result)
}
func TestPricingRequestToService_NilPriceFields(t *testing.T) {
reqs := []channelModelPricingRequest{
{
Models: []string{"m1"},
BillingMode: "token",
// all price fields are nil by default
},
}
result := pricingRequestToService(reqs)
require.Len(t, result, 1)
r := result[0]
require.Nil(t, r.InputPrice)
require.Nil(t, r.OutputPrice)
require.Nil(t, r.CacheWritePrice)
require.Nil(t, r.CacheReadPrice)
require.Nil(t, r.ImageOutputPrice)
require.Nil(t, r.PerRequestPrice)
}
// ---------------------------------------------------------------------------
// 3. validatePricingBillingMode
// ---------------------------------------------------------------------------
func TestValidatePricingBillingMode(t *testing.T) {
tests := []struct {
name string
pricing []service.ChannelModelPricing
wantErr bool
}{
{
name: "token mode - valid",
pricing: []service.ChannelModelPricing{
{BillingMode: service.BillingModeToken},
},
wantErr: false,
},
{
name: "per_request with price - valid",
pricing: []service.ChannelModelPricing{
{
BillingMode: service.BillingModePerRequest,
PerRequestPrice: float64Ptr(0.5),
},
},
wantErr: false,
},
{
name: "per_request with intervals - valid",
pricing: []service.ChannelModelPricing{
{
BillingMode: service.BillingModePerRequest,
Intervals: []service.PricingInterval{
{MinTokens: 0, MaxTokens: intPtr(1000), PerRequestPrice: float64Ptr(0.1)},
},
},
},
wantErr: false,
},
{
name: "per_request no price no intervals - invalid",
pricing: []service.ChannelModelPricing{
{BillingMode: service.BillingModePerRequest},
},
wantErr: true,
},
{
name: "image with price - valid",
pricing: []service.ChannelModelPricing{
{
BillingMode: service.BillingModeImage,
PerRequestPrice: float64Ptr(0.2),
},
},
wantErr: false,
},
{
name: "image no price no intervals - invalid",
pricing: []service.ChannelModelPricing{
{BillingMode: service.BillingModeImage},
},
wantErr: true,
},
{
name: "empty list - valid",
pricing: []service.ChannelModelPricing{},
wantErr: false,
},
{
name: "mixed modes with invalid image - invalid",
pricing: []service.ChannelModelPricing{
{
BillingMode: service.BillingModeToken,
InputPrice: float64Ptr(0.01),
},
{
BillingMode: service.BillingModePerRequest,
PerRequestPrice: float64Ptr(0.5),
},
{
BillingMode: service.BillingModeImage,
},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validatePricingBillingMode(tt.pricing)
if tt.wantErr {
require.Error(t, err)
require.Contains(t, err.Error(), "Per-request price or intervals required")
} else {
require.NoError(t, err)
}
})
}
}
......@@ -636,6 +636,40 @@ func (h *DashboardHandler) GetUserBreakdown(c *gin.Context) {
dim.Endpoint = c.Query("endpoint")
dim.EndpointType = c.DefaultQuery("endpoint_type", "inbound")
// Additional filter conditions
if v := c.Query("user_id"); v != "" {
if id, err := strconv.ParseInt(v, 10, 64); err == nil {
dim.UserID = id
}
}
if v := c.Query("api_key_id"); v != "" {
if id, err := strconv.ParseInt(v, 10, 64); err == nil {
dim.APIKeyID = id
}
}
if v := c.Query("account_id"); v != "" {
if id, err := strconv.ParseInt(v, 10, 64); err == nil {
dim.AccountID = id
}
}
if v := c.Query("request_type"); v != "" {
if rt, err := strconv.ParseInt(v, 10, 16); err == nil {
rtVal := int16(rt)
dim.RequestType = &rtVal
}
}
if v := c.Query("stream"); v != "" {
if s, err := strconv.ParseBool(v); err == nil {
dim.Stream = &s
}
}
if v := c.Query("billing_type"); v != "" {
if bt, err := strconv.ParseInt(v, 10, 8); err == nil {
btVal := int8(bt)
dim.BillingType = &btVal
}
}
limit := 50
if v := c.Query("limit"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 && n <= 200 {
......
......@@ -485,10 +485,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
RequestPayloadHash: requestPayloadHash,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
ChannelID: channelMapping.ChannelID,
OriginalModel: reqModel,
BillingModelSource: channelMapping.BillingModelSource,
ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel),
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
logger.L().With(
zap.String("component", "handler.gateway.messages"),
......@@ -828,10 +825,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
RequestPayloadHash: requestPayloadHash,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
ChannelID: channelMapping.ChannelID,
OriginalModel: reqModel,
BillingModelSource: channelMapping.BillingModelSource,
ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel),
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
logger.L().With(
zap.String("component", "handler.gateway.messages"),
......
......@@ -266,10 +266,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
ChannelID: channelMapping.ChannelID,
OriginalModel: reqModel,
BillingModelSource: channelMapping.BillingModelSource,
ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel),
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
reqLog.Error("gateway.cc.record_usage_failed",
zap.Int64("account_id", account.ID),
......
......@@ -272,10 +272,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
ChannelID: channelMapping.ChannelID,
OriginalModel: reqModel,
BillingModelSource: channelMapping.BillingModelSource,
ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel),
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
reqLog.Error("gateway.responses.record_usage_failed",
zap.Int64("account_id", account.ID),
......
......@@ -534,10 +534,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
LongContextMultiplier: 2.0, // 超出部分双倍计费
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
ChannelID: channelMapping.ChannelID,
OriginalModel: reqModel,
BillingModelSource: channelMapping.BillingModelSource,
ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel),
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
logger.L().With(
zap.String("component", "handler.gemini_v1beta.models"),
......
......@@ -278,10 +278,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
ChannelID: channelMapping.ChannelID,
OriginalModel: reqModel,
BillingModelSource: channelMapping.BillingModelSource,
ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel),
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.chat_completions"),
......
......@@ -391,10 +391,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
ChannelID: channelMapping.ChannelID,
OriginalModel: reqModel,
BillingModelSource: channelMapping.BillingModelSource,
ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel),
ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.responses"),
......@@ -787,10 +784,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
ChannelID: channelMappingMsg.ChannelID,
OriginalModel: reqModel,
BillingModelSource: channelMappingMsg.BillingModelSource,
ModelMappingChain: channelMappingMsg.BuildModelMappingChain(reqModel, result.UpstreamModel),
ChannelUsageFields: channelMappingMsg.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.messages"),
......@@ -1298,10 +1292,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
IPAddress: clientIP,
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
APIKeyService: h.apiKeyService,
ChannelID: channelMappingWS.ChannelID,
OriginalModel: reqModel,
BillingModelSource: channelMappingWS.BillingModelSource,
ModelMappingChain: channelMappingWS.BuildModelMappingChain(reqModel, result.UpstreamModel),
ChannelUsageFields: channelMappingWS.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
reqLog.Error("openai.websocket_record_usage_failed",
zap.Int64("account_id", account.ID),
......
......@@ -125,6 +125,13 @@ func (r *stubSoraGenRepo) CountByUserAndStatus(_ context.Context, _ int64, _ []s
return r.countValue, nil
}
func (r *stubSoraGenRepo) CountByStorageType(_ context.Context, _ string, _ []string) (int64, error) {
if r.countErr != nil {
return 0, r.countErr
}
return r.countValue, nil
}
// ==================== 辅助函数 ====================
func newTestSoraClientHandler(repo *stubSoraGenRepo) *SoraClientHandler {
......@@ -1657,8 +1664,8 @@ func TestStoreMediaWithDegradation_S3SuccessSingleURL(t *testing.T) {
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
objectStorage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{objectStorage: objectStorage}
storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(
context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
......@@ -1679,8 +1686,8 @@ func TestStoreMediaWithDegradation_S3SuccessMultiURL(t *testing.T) {
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
objectStorage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{objectStorage: objectStorage}
urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"}
storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(
......@@ -1704,8 +1711,8 @@ func TestStoreMediaWithDegradation_S3DownloadFails(t *testing.T) {
}))
defer badSource.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
objectStorage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{objectStorage: objectStorage}
_, _, storageType, _, _ := h.storeMediaWithDegradation(
context.Background(), 1, "video", badSource.URL+"/missing.mp4", nil,
......@@ -1719,8 +1726,8 @@ func TestStoreMediaWithDegradation_S3FailsSingleURL(t *testing.T) {
fakeS3 := newFakeS3Server("fail")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
objectStorage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{objectStorage: objectStorage}
_, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
......@@ -1736,8 +1743,8 @@ func TestStoreMediaWithDegradation_S3PartialFailureCleanup(t *testing.T) {
fakeS3 := newFakeS3Server("fail-second")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
objectStorage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{objectStorage: objectStorage}
urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"}
_, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
......@@ -1808,7 +1815,7 @@ func TestStoreMediaWithDegradation_S3FailsFallbackToLocal(t *testing.T) {
fakeS3 := newFakeS3Server("fail")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
objectStorage := newS3StorageForHandler(fakeS3.URL)
cfg := &config.Config{
Sora: config.SoraConfig{
Storage: config.SoraStorageConfig{
......@@ -1821,8 +1828,8 @@ func TestStoreMediaWithDegradation_S3FailsFallbackToLocal(t *testing.T) {
}
mediaStorage := service.NewSoraMediaStorage(cfg)
h := &SoraClientHandler{
s3Storage: s3Storage,
mediaStorage: mediaStorage,
objectStorage: objectStorage,
mediaStorage: mediaStorage,
}
_, _, storageType, _, _ := h.storeMediaWithDegradation(
......@@ -1846,9 +1853,9 @@ func TestSaveToStorage_S3EnabledButUploadFails(t *testing.T) {
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4",
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
objectStorage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
......@@ -1872,9 +1879,9 @@ func TestSaveToStorage_UpstreamURLExpired(t *testing.T) {
StorageType: "upstream",
MediaURL: expiredServer.URL + "/v.mp4",
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
objectStorage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
......@@ -1896,9 +1903,9 @@ func TestSaveToStorage_S3EnabledUploadSuccess(t *testing.T) {
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4",
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
objectStorage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
......@@ -1906,7 +1913,7 @@ func TestSaveToStorage_S3EnabledUploadSuccess(t *testing.T) {
require.Equal(t, http.StatusOK, rec.Code)
resp := parseResponse(t, rec)
data := resp["data"].(map[string]any)
require.Contains(t, data["message"], "S3")
require.Contains(t, data["message"], "云存储")
require.NotEmpty(t, data["object_key"])
// 验证记录已更新为 S3 存储
require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
......@@ -1928,9 +1935,9 @@ func TestSaveToStorage_S3EnabledUploadSuccess_MultiMediaURLs(t *testing.T) {
sourceServer.URL + "/v2.mp4",
},
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
objectStorage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
......@@ -1956,7 +1963,7 @@ func TestSaveToStorage_S3EnabledUploadSuccessWithQuota(t *testing.T) {
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4",
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
objectStorage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
userRepo := newStubUserRepoForHandler()
......@@ -1966,7 +1973,7 @@ func TestSaveToStorage_S3EnabledUploadSuccessWithQuota(t *testing.T) {
SoraStorageUsedBytes: 0,
}
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage, quotaService: quotaService}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
......@@ -1990,9 +1997,9 @@ func TestSaveToStorage_S3UploadSuccessMarkCompletedFails(t *testing.T) {
}
// S3 上传成功后,MarkCompleted 会调用 repo.Update → 失败
repo.updateErr = fmt.Errorf("db error")
s3Storage := newS3StorageForHandler(fakeS3.URL)
objectStorage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
......@@ -2007,8 +2014,8 @@ func TestGetStorageStatus_S3EnabledNotHealthy(t *testing.T) {
fakeS3 := newFakeS3Server("fail")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
objectStorage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{objectStorage: objectStorage}
c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
h.GetStorageStatus(c)
......@@ -2023,8 +2030,8 @@ func TestGetStorageStatus_S3EnabledHealthy(t *testing.T) {
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
objectStorage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{objectStorage: objectStorage}
c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
h.GetStorageStatus(c)
......@@ -2453,7 +2460,7 @@ func TestProcessGeneration_FullSuccessWithS3(t *testing.T) {
},
}
soraGatewayService := newMinimalSoraGatewayService(soraClient)
s3Storage := newS3StorageForHandler(fakeS3.URL)
objectStorage := newS3StorageForHandler(fakeS3.URL)
userRepo := newStubUserRepoForHandler()
userRepo.users[1] = &service.User{
......@@ -2465,7 +2472,7 @@ func TestProcessGeneration_FullSuccessWithS3(t *testing.T) {
genService: genService,
gatewayService: gatewayService,
soraGatewayService: soraGatewayService,
s3Storage: s3Storage,
objectStorage: objectStorage,
quotaService: quotaService,
}
......@@ -2515,7 +2522,7 @@ func TestProcessGeneration_MarkCompletedFails(t *testing.T) {
// ==================== cleanupStoredMedia 直接测试 ====================
func TestCleanupStoredMedia_S3Path(t *testing.T) {
// S3 清理路径:s3Storage 为 nil 时不 panic
// S3 清理路径:objectStorage 为 nil 时不 panic
h := &SoraClientHandler{}
// 不应 panic
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil)
......@@ -2962,7 +2969,7 @@ func TestSaveToStorage_QuotaExceeded(t *testing.T) {
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4",
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
objectStorage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
// 用户配额已满
......@@ -2973,7 +2980,7 @@ func TestSaveToStorage_QuotaExceeded(t *testing.T) {
SoraStorageUsedBytes: 10,
}
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage, quotaService: quotaService}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
......@@ -2995,13 +3002,13 @@ func TestSaveToStorage_QuotaNonQuotaError(t *testing.T) {
StorageType: "upstream",
MediaURL: sourceServer.URL + "/v.mp4",
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
objectStorage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
// 用户不存在 → GetByID 失败 → AddUsage 返回普通 error
userRepo := newStubUserRepoForHandler()
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage, quotaService: quotaService}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
......@@ -3022,9 +3029,9 @@ func TestSaveToStorage_EmptyMediaURLs(t *testing.T) {
MediaURL: "",
MediaURLs: []string{},
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
objectStorage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
......@@ -3049,9 +3056,9 @@ func TestSaveToStorage_MultiURL_SecondUploadFails(t *testing.T) {
MediaURL: sourceServer.URL + "/v1.mp4",
MediaURLs: []string{sourceServer.URL + "/v1.mp4", sourceServer.URL + "/v2.mp4"},
}
s3Storage := newS3StorageForHandler(fakeS3.URL)
objectStorage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
......@@ -3074,7 +3081,7 @@ func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) {
MediaURL: sourceServer.URL + "/v.mp4",
}
repo.updateErr = fmt.Errorf("db error")
s3Storage := newS3StorageForHandler(fakeS3.URL)
objectStorage := newS3StorageForHandler(fakeS3.URL)
genService := service.NewSoraGenerationService(repo, nil, nil)
userRepo := newStubUserRepoForHandler()
......@@ -3084,7 +3091,7 @@ func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) {
SoraStorageUsedBytes: 0,
}
quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
h := &SoraClientHandler{genService: genService, objectStorage: objectStorage, quotaService: quotaService}
c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
c.Params = gin.Params{{Key: "id", Value: "1"}}
......@@ -3097,8 +3104,8 @@ func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) {
func TestCleanupStoredMedia_WithS3Storage_ActualDelete(t *testing.T) {
fakeS3 := newFakeS3Server("ok")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
objectStorage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{objectStorage: objectStorage}
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1", "key2"}, nil)
}
......@@ -3106,8 +3113,8 @@ func TestCleanupStoredMedia_WithS3Storage_ActualDelete(t *testing.T) {
func TestCleanupStoredMedia_S3DeleteFails_LogOnly(t *testing.T) {
fakeS3 := newFakeS3Server("fail")
defer fakeS3.Close()
s3Storage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{s3Storage: s3Storage}
objectStorage := newS3StorageForHandler(fakeS3.URL)
h := &SoraClientHandler{objectStorage: objectStorage}
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil)
}
......
......@@ -30,6 +30,8 @@ import (
)
// SoraGatewayHandler handles Sora chat completions requests
//
// NOTE: Sora 平台计划后续移除,不集成渠道(Channel)功能。
type SoraGatewayHandler struct {
gatewayService *service.GatewayService
soraGatewayService *service.SoraGatewayService
......
......@@ -175,6 +175,13 @@ type UserBreakdownDimension struct {
ModelType string // "requested", "upstream", or "mapping"
Endpoint string // filter by endpoint value (non-empty to enable)
EndpointType string // "inbound", "upstream", or "path"
// Additional filter conditions
UserID int64 // filter by user_id (>0 to enable)
APIKeyID int64 // filter by api_key_id (>0 to enable)
AccountID int64 // filter by account_id (>0 to enable)
RequestType *int16 // filter by request_type (non-nil to enable)
Stream *bool // filter by stream flag (non-nil to enable)
BillingType *int8 // filter by billing_type (non-nil to enable)
}
// APIKeyUsageTrendPoint represents API key usage trend data point
......
......@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"strings"
......@@ -274,7 +275,8 @@ func replaceModelPricingTx(ctx context.Context, exec dbExec, channelID int64, pr
// isUniqueViolation 检查 pq 唯一约束违反错误
func isUniqueViolation(err error) bool {
if pqErr, ok := err.(*pq.Error); ok {
var pqErr *pq.Error
if errors.As(err, &pqErr) && pqErr != nil {
return pqErr.Code == "23505"
}
return false
......
//go:build unit
package repository
import (
"encoding/json"
"errors"
"fmt"
"testing"
"github.com/lib/pq"
"github.com/stretchr/testify/require"
)
// --- marshalModelMapping ---
func TestMarshalModelMapping(t *testing.T) {
tests := []struct {
name string
input map[string]map[string]string
wantJSON string // expected JSON output (exact match)
}{
{
name: "empty map",
input: map[string]map[string]string{},
wantJSON: "{}",
},
{
name: "nil map",
input: nil,
wantJSON: "{}",
},
{
name: "populated map",
input: map[string]map[string]string{
"openai": {"gpt-4": "gpt-4-turbo"},
},
},
{
name: "nested values",
input: map[string]map[string]string{
"openai": {"*": "gpt-5.4"},
"anthropic": {"claude-old": "claude-new"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := marshalModelMapping(tt.input)
require.NoError(t, err)
if tt.wantJSON != "" {
require.Equal(t, []byte(tt.wantJSON), result)
} else {
// round-trip: unmarshal and compare with input
var parsed map[string]map[string]string
require.NoError(t, json.Unmarshal(result, &parsed))
require.Equal(t, tt.input, parsed)
}
})
}
}
// --- unmarshalModelMapping ---
func TestUnmarshalModelMapping(t *testing.T) {
tests := []struct {
name string
input []byte
wantNil bool
want map[string]map[string]string
}{
{
name: "nil data",
input: nil,
wantNil: true,
},
{
name: "empty data",
input: []byte{},
wantNil: true,
},
{
name: "invalid JSON",
input: []byte("not-json"),
wantNil: true,
},
{
name: "type error - number",
input: []byte("42"),
wantNil: true,
},
{
name: "type error - array",
input: []byte("[1,2,3]"),
wantNil: true,
},
{
name: "valid JSON",
input: []byte(`{"openai":{"gpt-4":"gpt-4-turbo"},"anthropic":{"old":"new"}}`),
want: map[string]map[string]string{
"openai": {"gpt-4": "gpt-4-turbo"},
"anthropic": {"old": "new"},
},
},
{
name: "empty object",
input: []byte("{}"),
want: map[string]map[string]string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := unmarshalModelMapping(tt.input)
if tt.wantNil {
require.Nil(t, result)
} else {
require.NotNil(t, result)
require.Equal(t, tt.want, result)
}
})
}
}
// --- escapeLike ---
func TestEscapeLike(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{
name: "no special chars",
input: "hello",
want: "hello",
},
{
name: "backslash",
input: `a\b`,
want: `a\\b`,
},
{
name: "percent",
input: "50%",
want: `50\%`,
},
{
name: "underscore",
input: "a_b",
want: `a\_b`,
},
{
name: "all special chars",
input: `a\b%c_d`,
want: `a\\b\%c\_d`,
},
{
name: "empty string",
input: "",
want: "",
},
{
name: "consecutive special chars",
input: "%_%",
want: `\%\_\%`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, escapeLike(tt.input))
})
}
}
// --- isUniqueViolation ---
func TestIsUniqueViolation(t *testing.T) {
tests := []struct {
name string
err error
want bool
}{
{
name: "unique violation code 23505",
err: &pq.Error{Code: "23505"},
want: true,
},
{
name: "different pq error code",
err: &pq.Error{Code: "23503"},
want: false,
},
{
name: "non-pq error",
err: errors.New("some generic error"),
want: false,
},
{
name: "typed nil pq.Error",
err: func() error {
var pqErr *pq.Error
return pqErr
}(),
want: false,
},
{
name: "bare nil",
err: nil,
want: false,
},
{
name: "wrapped pq error with 23505",
err: fmt.Errorf("wrapped: %w", &pq.Error{Code: "23505"}),
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, isUniqueViolation(tt.err))
})
}
}
......@@ -3144,6 +3144,30 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim
query += fmt.Sprintf(" AND %s = $%d", col, len(args)+1)
args = append(args, dim.Endpoint)
}
if dim.UserID > 0 {
query += fmt.Sprintf(" AND ul.user_id = $%d", len(args)+1)
args = append(args, dim.UserID)
}
if dim.APIKeyID > 0 {
query += fmt.Sprintf(" AND ul.api_key_id = $%d", len(args)+1)
args = append(args, dim.APIKeyID)
}
if dim.AccountID > 0 {
query += fmt.Sprintf(" AND ul.account_id = $%d", len(args)+1)
args = append(args, dim.AccountID)
}
if dim.RequestType != nil {
query += fmt.Sprintf(" AND ul.request_type = $%d", len(args)+1)
args = append(args, *dim.RequestType)
}
if dim.Stream != nil {
query += fmt.Sprintf(" AND ul.stream = $%d", len(args)+1)
args = append(args, *dim.Stream)
}
if dim.BillingType != nil {
query += fmt.Sprintf(" AND ul.billing_type = $%d", len(args)+1)
args = append(args, *dim.BillingType)
}
query += " GROUP BY ul.user_id, u.email ORDER BY actual_cost DESC"
if limit > 0 {
......
......@@ -80,6 +80,10 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
sqlmock.AnyArg(), // inbound_endpoint
sqlmock.AnyArg(), // upstream_endpoint
log.CacheTTLOverridden,
sqlmock.AnyArg(), // channel_id
sqlmock.AnyArg(), // model_mapping_chain
sqlmock.AnyArg(), // billing_tier
sqlmock.AnyArg(), // billing_mode
createdAt,
).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt))
......@@ -153,6 +157,10 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
sqlmock.AnyArg(),
sqlmock.AnyArg(),
log.CacheTTLOverridden,
sqlmock.AnyArg(), // channel_id
sqlmock.AnyArg(), // model_mapping_chain
sqlmock.AnyArg(), // billing_tier
sqlmock.AnyArg(), // billing_mode
createdAt,
).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt))
......@@ -463,6 +471,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{},
false,
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
now,
}})
require.NoError(t, err)
......@@ -506,6 +518,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{},
false,
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
now,
}})
require.NoError(t, err)
......@@ -549,6 +565,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
sql.NullString{},
false,
sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode
now,
}})
require.NoError(t, err)
......
......@@ -51,15 +51,15 @@ type Channel struct {
type ChannelModelPricing struct {
ID int64
ChannelID int64
Platform string // 所属平台(anthropic/openai/gemini/...)
Models []string // 绑定的模型列表
BillingMode BillingMode // 计费模式
InputPrice *float64 // 每 token 输入价格(USD)— 向后兼容 flat 定价
OutputPrice *float64 // 每 token 输出价格(USD)
CacheWritePrice *float64 // 缓存写入价格
CacheReadPrice *float64 // 缓存读取价格
ImageOutputPrice *float64 // 图片输出价格(向后兼容)
PerRequestPrice *float64 // 默认按次计费价格(USD)
Platform string // 所属平台(anthropic/openai/gemini/...)
Models []string // 绑定的模型列表
BillingMode BillingMode // 计费模式
InputPrice *float64 // 每 token 输入价格(USD)— 向后兼容 flat 定价
OutputPrice *float64 // 每 token 输出价格(USD)
CacheWritePrice *float64 // 缓存写入价格
CacheReadPrice *float64 // 缓存读取价格
ImageOutputPrice *float64 // 图片输出价格(向后兼容)
PerRequestPrice *float64 // 默认按次计费价格(USD)
Intervals []PricingInterval // 区间定价列表
CreatedAt time.Time
UpdatedAt time.Time
......@@ -175,3 +175,11 @@ func (c *Channel) Clone() *Channel {
}
return &cp
}
// ChannelUsageFields 渠道相关的使用记录字段(嵌入到各平台的 RecordUsageInput 中)
type ChannelUsageFields struct {
ChannelID int64 // 渠道 ID(0 = 无渠道)
OriginalModel string // 用户原始请求模型(渠道映射前)
BillingModelSource string // 计费模型来源:"requested" / "upstream"
ModelMappingChain string // 映射链描述,如 "a→b→c"
}
......@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"log/slog"
"sort"
"strings"
"sync/atomic"
"time"
......@@ -17,8 +16,8 @@ import (
)
var (
ErrChannelNotFound = infraerrors.NotFound("CHANNEL_NOT_FOUND", "channel not found")
ErrChannelExists = infraerrors.Conflict("CHANNEL_EXISTS", "channel name already exists")
ErrChannelNotFound = infraerrors.NotFound("CHANNEL_NOT_FOUND", "channel not found")
ErrChannelExists = infraerrors.Conflict("CHANNEL_EXISTS", "channel name already exists")
ErrGroupAlreadyInChannel = infraerrors.Conflict(
"GROUP_ALREADY_IN_CHANNEL",
"one or more groups already belong to another channel",
......@@ -81,12 +80,12 @@ type wildcardMappingEntry struct {
// channelCache 渠道缓存快照(扁平化哈希结构,热路径 O(1) 查找)
type channelCache struct {
// 热路径查找
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价
wildcardByGroupPlatform map[channelGroupPlatformKey][]*wildcardPricingEntry // (groupID, platform) → 通配符定价(前缀长度降序)
mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标
wildcardMappingByGP map[channelGroupPlatformKey][]*wildcardMappingEntry // (groupID, platform) → 通配符映射(前缀长度降序)
channelByGroupID map[int64]*Channel // groupID → 渠道
groupPlatform map[int64]string // groupID → platform
mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标
wildcardMappingByGP map[channelGroupPlatformKey][]*wildcardMappingEntry // (groupID, platform) → 通配符映射(前缀长度降序)
channelByGroupID map[int64]*Channel // groupID → 渠道
groupPlatform map[int64]string // groupID → platform
// 冷路径(CRUD 操作)
byID map[int64]*Channel
......@@ -118,9 +117,19 @@ func (r ChannelMappingResult) BuildModelMappingChain(reqModel, upstreamModel str
return reqModel + "→" + r.MappedModel
}
// ToUsageFields 将渠道映射结果转为使用记录字段
func (r ChannelMappingResult) ToUsageFields(reqModel, upstreamModel string) ChannelUsageFields {
return ChannelUsageFields{
ChannelID: r.ChannelID,
OriginalModel: reqModel,
BillingModelSource: r.BillingModelSource,
ModelMappingChain: r.BuildModelMappingChain(reqModel, upstreamModel),
}
}
const (
channelCacheTTL = 60 * time.Second
channelErrorTTL = 5 * time.Second // DB 错误时的短缓存
channelCacheTTL = 60 * time.Second
channelErrorTTL = 5 * time.Second // DB 错误时的短缓存
channelCacheDBTimeout = 10 * time.Second
)
......@@ -177,14 +186,14 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
// error-TTL:失败时存入短 TTL 空缓存,防止紧密重试
slog.Warn("failed to build channel cache", "error", err)
errorCache := &channelCache{
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
mappingByGroupModel: make(map[channelModelKey]string),
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
channelByGroupID: make(map[int64]*Channel),
groupPlatform: make(map[int64]string),
byID: make(map[int64]*Channel),
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL
mappingByGroupModel: make(map[channelModelKey]string),
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
channelByGroupID: make(map[int64]*Channel),
groupPlatform: make(map[int64]string),
byID: make(map[int64]*Channel),
loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL
}
s.cache.Store(errorCache)
return nil, fmt.Errorf("list all channels: %w", err)
......@@ -205,14 +214,14 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
}
cache := &channelCache{
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
mappingByGroupModel: make(map[channelModelKey]string),
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
channelByGroupID: make(map[int64]*Channel),
groupPlatform: groupPlatforms,
byID: make(map[int64]*Channel, len(channels)),
loadedAt: time.Now(),
mappingByGroupModel: make(map[channelModelKey]string),
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
channelByGroupID: make(map[int64]*Channel),
groupPlatform: groupPlatforms,
byID: make(map[int64]*Channel, len(channels)),
loadedAt: time.Now(),
}
for i := range channels {
......@@ -266,19 +275,7 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
}
}
// 通配符条目按前缀长度降序排列(最长前缀优先匹配)
for gpKey, entries := range cache.wildcardByGroupPlatform {
sort.Slice(entries, func(i, j int) bool {
return len(entries[i].prefix) > len(entries[j].prefix)
})
cache.wildcardByGroupPlatform[gpKey] = entries
}
for gpKey, entries := range cache.wildcardMappingByGP {
sort.Slice(entries, func(i, j int) bool {
return len(entries[i].prefix) > len(entries[j].prefix)
})
cache.wildcardMappingByGP[gpKey] = entries
}
// 通配符条目保持配置顺序(最先匹配到优先)
s.cache.Store(cache)
return cache, nil
......@@ -290,7 +287,7 @@ func (s *ChannelService) invalidateCache() {
s.cacheSF.Forget("channel_cache")
}
// matchWildcard 在通配符定价中查找匹配项(最长前缀优先)
// matchWildcard 在通配符定价中查找匹配项(最先匹配到优先)
func (c *channelCache) matchWildcard(groupID int64, platform, modelLower string) *ChannelModelPricing {
gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform}
wildcards := c.wildcardByGroupPlatform[gpKey]
......@@ -302,7 +299,7 @@ func (c *channelCache) matchWildcard(groupID int64, platform, modelLower string)
return nil
}
// matchWildcardMapping 在通配符映射中查找匹配项(最长前缀优先)
// matchWildcardMapping 在通配符映射中查找匹配项(最先匹配到优先)
func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower string) string {
gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform}
wildcards := c.wildcardMappingByGP[gpKey]
......@@ -479,15 +476,18 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
Status: StatusActive,
BillingModelSource: input.BillingModelSource,
RestrictModels: input.RestrictModels,
GroupIDs: input.GroupIDs,
ModelPricing: input.ModelPricing,
ModelMapping: input.ModelMapping,
GroupIDs: input.GroupIDs,
ModelPricing: input.ModelPricing,
ModelMapping: input.ModelMapping,
}
if channel.BillingModelSource == "" {
channel.BillingModelSource = BillingModelSourceRequested
}
if err := validateNoDuplicateModels(channel.ModelPricing); err != nil {
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
return nil, err
}
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
return nil, err
}
......@@ -558,7 +558,10 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
channel.BillingModelSource = input.BillingModelSource
}
if err := validateNoDuplicateModels(channel.ModelPricing); err != nil {
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
return nil, err
}
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
return nil, err
}
......@@ -610,16 +613,79 @@ func (s *ChannelService) List(ctx context.Context, params pagination.PaginationP
return s.repo.List(ctx, params, status, search)
}
// validateNoDuplicateModels 检查定价列表中是否有重复模型(同一平台下不允许重复)
func validateNoDuplicateModels(pricingList []ChannelModelPricing) error {
seen := make(map[string]bool)
// modelEntry 表示一个模型模式条目(用于冲突检测)
type modelEntry struct {
pattern string // 原始模式(如 "claude-*" 或 "claude-opus-4")
prefix string // lowercase 前缀(通配符去掉 *,精确名保持原样)
wildcard bool
}
// conflictsBetween 检查两个模型模式是否冲突
func conflictsBetween(a, b modelEntry) bool {
switch {
case !a.wildcard && !b.wildcard:
return a.prefix == b.prefix
case a.wildcard && !b.wildcard:
return strings.HasPrefix(b.prefix, a.prefix)
case !a.wildcard && b.wildcard:
return strings.HasPrefix(a.prefix, b.prefix)
default:
return strings.HasPrefix(a.prefix, b.prefix) ||
strings.HasPrefix(b.prefix, a.prefix)
}
}
// toModelEntry 将模型名转换为 modelEntry
func toModelEntry(pattern string) modelEntry {
lower := strings.ToLower(pattern)
isWild := strings.HasSuffix(lower, "*")
prefix := lower
if isWild {
prefix = strings.TrimSuffix(lower, "*")
}
return modelEntry{pattern: pattern, prefix: prefix, wildcard: isWild}
}
// validateNoConflictingModels 检查定价列表中是否有冲突模型模式(同一平台下)。
// 冲突包括:精确重复、通配符之间的前缀包含、通配符与精确名的前缀匹配。
func validateNoConflictingModels(pricingList []ChannelModelPricing) error {
byPlatform := make(map[string][]modelEntry)
for _, p := range pricingList {
for _, model := range p.Models {
key := p.Platform + ":" + strings.ToLower(model)
if seen[key] {
return infraerrors.BadRequest("DUPLICATE_MODEL", fmt.Sprintf("model '%s' appears in multiple pricing entries for platform '%s'", model, p.Platform))
byPlatform[p.Platform] = append(byPlatform[p.Platform], toModelEntry(model))
}
}
for platform, entries := range byPlatform {
if err := detectConflicts(entries, platform, "MODEL_PATTERN_CONFLICT", "model patterns"); err != nil {
return err
}
}
return nil
}
// validateNoConflictingMappings 检查模型映射中是否有冲突的源模式
func validateNoConflictingMappings(mapping map[string]map[string]string) error {
for platform, platformMapping := range mapping {
entries := make([]modelEntry, 0, len(platformMapping))
for src := range platformMapping {
entries = append(entries, toModelEntry(src))
}
if err := detectConflicts(entries, platform, "MAPPING_PATTERN_CONFLICT", "mapping source patterns"); err != nil {
return err
}
}
return nil
}
// detectConflicts 在一组 modelEntry 中检测冲突,返回带有 errCode 和 label 的错误
func detectConflicts(entries []modelEntry, platform, errCode, label string) error {
for i := 0; i < len(entries); i++ {
for j := i + 1; j < len(entries); j++ {
if conflictsBetween(entries[i], entries[j]) {
return infraerrors.BadRequest(errCode,
fmt.Sprintf("%s '%s' and '%s' conflict in platform '%s': overlapping match range",
label, entries[i].pattern, entries[j].pattern, platform))
}
seen[key] = true
}
}
return nil
......
This diff is collapsed.
......@@ -8,13 +8,10 @@ import (
"github.com/stretchr/testify/require"
)
func channelTestPtrFloat64(v float64) *float64 { return &v }
func channelTestPtrInt(v int) *int { return &v }
func TestGetModelPricing(t *testing.T) {
ch := &Channel{
ModelPricing: []ChannelModelPricing{
{ID: 1, Models: []string{"claude-sonnet-4"}, BillingMode: BillingModeToken, InputPrice: channelTestPtrFloat64(3e-6)},
{ID: 1, Models: []string{"claude-sonnet-4"}, BillingMode: BillingModeToken, InputPrice: testPtrFloat64(3e-6)},
{ID: 3, Models: []string{"gpt-5.1"}, BillingMode: BillingModePerRequest},
},
}
......@@ -48,7 +45,7 @@ func TestGetModelPricing(t *testing.T) {
func TestGetModelPricing_ReturnsCopy(t *testing.T) {
ch := &Channel{
ModelPricing: []ChannelModelPricing{
{ID: 1, Models: []string{"claude-sonnet-4"}, InputPrice: channelTestPtrFloat64(3e-6)},
{ID: 1, Models: []string{"claude-sonnet-4"}, InputPrice: testPtrFloat64(3e-6)},
},
}
......@@ -73,23 +70,23 @@ func TestGetModelPricing_EmptyPricing(t *testing.T) {
func TestGetIntervalForContext(t *testing.T) {
p := &ChannelModelPricing{
Intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: channelTestPtrInt(128000), InputPrice: channelTestPtrFloat64(1e-6)},
{MinTokens: 128000, MaxTokens: nil, InputPrice: channelTestPtrFloat64(2e-6)},
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)},
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)},
},
}
tests := []struct {
name string
tokens int
wantPrice *float64
wantNil bool
name string
tokens int
wantPrice *float64
wantNil bool
}{
{"first interval", 50000, channelTestPtrFloat64(1e-6), false},
{"first interval", 50000, testPtrFloat64(1e-6), false},
// (min, max] — 128000 在第一个区间的 max,包含,所以匹配第一个
{"boundary: max of first (inclusive)", 128000, channelTestPtrFloat64(1e-6), false},
{"boundary: max of first (inclusive)", 128000, testPtrFloat64(1e-6), false},
// 128001 > 128000,匹配第二个区间
{"boundary: just above first max", 128001, channelTestPtrFloat64(2e-6), false},
{"unbounded interval", 500000, channelTestPtrFloat64(2e-6), false},
{"boundary: just above first max", 128001, testPtrFloat64(2e-6), false},
{"unbounded interval", 500000, testPtrFloat64(2e-6), false},
// (0, max] — 0 不匹配任何区间(左开)
{"zero tokens: no match", 0, nil, true},
}
......@@ -110,11 +107,11 @@ func TestGetIntervalForContext(t *testing.T) {
func TestGetIntervalForContext_NoMatch(t *testing.T) {
p := &ChannelModelPricing{
Intervals: []PricingInterval{
{MinTokens: 10000, MaxTokens: channelTestPtrInt(50000)},
{MinTokens: 10000, MaxTokens: testPtrInt(50000)},
},
}
require.Nil(t, p.GetIntervalForContext(5000)) // 5000 <= 10000, not > min
require.Nil(t, p.GetIntervalForContext(10000)) // 10000 not > 10000 (left-open)
require.Nil(t, p.GetIntervalForContext(5000)) // 5000 <= 10000, not > min
require.Nil(t, p.GetIntervalForContext(10000)) // 10000 not > 10000 (left-open)
require.NotNil(t, p.GetIntervalForContext(50000)) // 50000 <= 50000 (right-closed)
require.Nil(t, p.GetIntervalForContext(50001)) // 50001 > 50000
}
......@@ -127,9 +124,9 @@ func TestGetIntervalForContext_Empty(t *testing.T) {
func TestGetTierByLabel(t *testing.T) {
p := &ChannelModelPricing{
Intervals: []PricingInterval{
{TierLabel: "1K", PerRequestPrice: channelTestPtrFloat64(0.04)},
{TierLabel: "2K", PerRequestPrice: channelTestPtrFloat64(0.08)},
{TierLabel: "HD", PerRequestPrice: channelTestPtrFloat64(0.12)},
{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
{TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)},
{TierLabel: "HD", PerRequestPrice: testPtrFloat64(0.12)},
},
}
......@@ -171,7 +168,7 @@ func TestChannelClone(t *testing.T) {
{
ID: 100,
Models: []string{"model-a"},
InputPrice: channelTestPtrFloat64(5e-6),
InputPrice: testPtrFloat64(5e-6),
},
},
}
......@@ -211,3 +208,102 @@ func TestChannelModelPricingClone(t *testing.T) {
cloned.Intervals[0].TierLabel = "hacked"
require.Equal(t, "tier1", original.Intervals[0].TierLabel)
}
// --- BillingMode.IsValid ---
func TestBillingModeIsValid(t *testing.T) {
tests := []struct {
name string
mode BillingMode
want bool
}{
{"token", BillingModeToken, true},
{"per_request", BillingModePerRequest, true},
{"image", BillingModeImage, true},
{"empty", BillingMode(""), true},
{"unknown", BillingMode("unknown"), false},
{"random", BillingMode("xyz"), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, tt.mode.IsValid())
})
}
}
// --- Channel.IsActive ---
func TestChannelIsActive(t *testing.T) {
tests := []struct {
name string
status string
want bool
}{
{"active", StatusActive, true},
{"disabled", "disabled", false},
{"empty", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ch := &Channel{Status: tt.status}
require.Equal(t, tt.want, ch.IsActive())
})
}
}
// --- ChannelModelPricing.Clone edge cases ---
func TestChannelModelPricingClone_EdgeCases(t *testing.T) {
t.Run("nil models", func(t *testing.T) {
original := ChannelModelPricing{Models: nil}
cloned := original.Clone()
require.Nil(t, cloned.Models)
})
t.Run("nil intervals", func(t *testing.T) {
original := ChannelModelPricing{Intervals: nil}
cloned := original.Clone()
require.Nil(t, cloned.Intervals)
})
t.Run("empty models", func(t *testing.T) {
original := ChannelModelPricing{Models: []string{}}
cloned := original.Clone()
require.NotNil(t, cloned.Models)
require.Empty(t, cloned.Models)
})
}
// --- Channel.Clone edge cases ---
func TestChannelClone_EdgeCases(t *testing.T) {
t.Run("nil model mapping", func(t *testing.T) {
original := &Channel{ID: 1, ModelMapping: nil}
cloned := original.Clone()
require.Nil(t, cloned.ModelMapping)
})
t.Run("nil model pricing", func(t *testing.T) {
original := &Channel{ID: 1, ModelPricing: nil}
cloned := original.Clone()
require.Nil(t, cloned.ModelPricing)
})
t.Run("deep copy model mapping", func(t *testing.T) {
original := &Channel{
ID: 1,
ModelMapping: map[string]map[string]string{
"openai": {"gpt-4": "gpt-4-turbo"},
},
}
cloned := original.Clone()
// Modify the cloned nested map
cloned.ModelMapping["openai"]["gpt-4"] = "hacked"
// Original must remain unchanged
require.Equal(t, "gpt-4-turbo", original.ModelMapping["openai"]["gpt-4"])
})
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment