Commit 6d3ea64a authored by erio's avatar erio
Browse files

test: add unit tests for channel pricing restriction in scheduling phase

20 test cases covering:
- billingModelForRestriction: 4 cases (requested/channel_mapped/upstream/empty)
- resolveAccountUpstreamModel: 3 cases (antigravity/unsupported/non-antigravity)
- checkChannelPricingRestriction: 10 cases (nil guards, 3 billing sources,
  RestrictModels disabled, no channel)
- isUpstreamModelRestrictedByChannel: 3 cases (restricted/allowed/unsupported)
parent 1fca2bfa
//go:build unit
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
// --- billingModelForRestriction ---
func TestBillingModelForRestriction_Requested(t *testing.T) {
t.Parallel()
got := billingModelForRestriction(BillingModelSourceRequested, "claude-sonnet-4-5", "claude-sonnet-4-6")
require.Equal(t, "claude-sonnet-4-5", got)
}
func TestBillingModelForRestriction_ChannelMapped(t *testing.T) {
t.Parallel()
got := billingModelForRestriction(BillingModelSourceChannelMapped, "claude-sonnet-4-5", "claude-sonnet-4-6")
require.Equal(t, "claude-sonnet-4-6", got)
}
func TestBillingModelForRestriction_Upstream(t *testing.T) {
t.Parallel()
got := billingModelForRestriction(BillingModelSourceUpstream, "claude-sonnet-4-5", "claude-sonnet-4-6")
require.Equal(t, "", got, "upstream should return empty (per-account check needed)")
}
func TestBillingModelForRestriction_Empty(t *testing.T) {
t.Parallel()
got := billingModelForRestriction("", "claude-sonnet-4-5", "claude-sonnet-4-6")
require.Equal(t, "claude-sonnet-4-6", got, "empty source defaults to channel_mapped")
}
// --- resolveAccountUpstreamModel ---
func TestResolveAccountUpstreamModel_Antigravity(t *testing.T) {
t.Parallel()
account := &Account{
Platform: PlatformAntigravity,
}
// Antigravity 平台使用 DefaultAntigravityModelMapping
got := resolveAccountUpstreamModel(account, "claude-sonnet-4-6")
require.Equal(t, "claude-sonnet-4-6", got)
}
func TestResolveAccountUpstreamModel_Antigravity_Unsupported(t *testing.T) {
t.Parallel()
account := &Account{
Platform: PlatformAntigravity,
}
got := resolveAccountUpstreamModel(account, "totally-unknown-model")
require.Equal(t, "", got, "unsupported model should return empty")
}
func TestResolveAccountUpstreamModel_NonAntigravity(t *testing.T) {
t.Parallel()
account := &Account{
Platform: PlatformAnthropic,
}
got := resolveAccountUpstreamModel(account, "claude-sonnet-4-6")
require.Equal(t, "claude-sonnet-4-6", got, "no mapping = passthrough")
}
// --- checkChannelPricingRestriction ---
func TestCheckChannelPricingRestriction_NilGroupID(t *testing.T) {
t.Parallel()
svc := &GatewayService{channelService: &ChannelService{}}
require.False(t, svc.checkChannelPricingRestriction(context.Background(), nil, "claude-sonnet-4"))
}
func TestCheckChannelPricingRestriction_NilChannelService(t *testing.T) {
t.Parallel()
svc := &GatewayService{}
gid := int64(10)
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4"))
}
func TestCheckChannelPricingRestriction_EmptyModel(t *testing.T) {
t.Parallel()
svc := &GatewayService{channelService: &ChannelService{}}
gid := int64(10)
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, ""))
}
func TestCheckChannelPricingRestriction_ChannelMapped_Restricted(t *testing.T) {
t.Parallel()
// 渠道映射 claude-sonnet-4-5 → claude-sonnet-4-6,但定价列表只有 claude-opus-4-6
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
BillingModelSource: BillingModelSourceChannelMapped,
ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
},
ModelMapping: map[string]map[string]string{
"anthropic": {"claude-sonnet-4-5": "claude-sonnet-4-6"},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
svc := &GatewayService{channelService: channelSvc}
gid := int64(10)
require.True(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"),
"mapped model claude-sonnet-4-6 is NOT in pricing → restricted")
}
func TestCheckChannelPricingRestriction_ChannelMapped_Allowed(t *testing.T) {
t.Parallel()
// 渠道映射 claude-sonnet-4-5 → claude-sonnet-4-6,定价列表包含 claude-sonnet-4-6
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
BillingModelSource: BillingModelSourceChannelMapped,
ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}},
},
ModelMapping: map[string]map[string]string{
"anthropic": {"claude-sonnet-4-5": "claude-sonnet-4-6"},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
svc := &GatewayService{channelService: channelSvc}
gid := int64(10)
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"),
"mapped model claude-sonnet-4-6 IS in pricing → allowed")
}
func TestCheckChannelPricingRestriction_Requested_Restricted(t *testing.T) {
t.Parallel()
// billing_model_source=requested,定价列表有 claude-sonnet-4-6 但请求的是 claude-sonnet-4-5
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
BillingModelSource: BillingModelSourceRequested,
ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
svc := &GatewayService{channelService: channelSvc}
gid := int64(10)
require.True(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"),
"requested model claude-sonnet-4-5 is NOT in pricing → restricted")
}
func TestCheckChannelPricingRestriction_Requested_Allowed(t *testing.T) {
t.Parallel()
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
BillingModelSource: BillingModelSourceRequested,
ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-sonnet-4-5"}},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
svc := &GatewayService{channelService: channelSvc}
gid := int64(10)
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "claude-sonnet-4-5"),
"requested model IS in pricing → allowed")
}
func TestCheckChannelPricingRestriction_Upstream_SkipsPreCheck(t *testing.T) {
t.Parallel()
// upstream 模式:预检查始终跳过(返回 false),需逐账号检查
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
BillingModelSource: BillingModelSourceUpstream,
ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
svc := &GatewayService{channelService: channelSvc}
gid := int64(10)
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "unknown-model"),
"upstream mode should skip pre-check (per-account check needed)")
}
func TestCheckChannelPricingRestriction_RestrictModelsDisabled(t *testing.T) {
t.Parallel()
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: false, // 未开启模型限制
ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
svc := &GatewayService{channelService: channelSvc}
gid := int64(10)
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "any-model"),
"RestrictModels=false → always allowed")
}
func TestCheckChannelPricingRestriction_NoChannel(t *testing.T) {
t.Parallel()
// 分组没有关联渠道
repo := &mockChannelRepository{
listAllFn: func(_ context.Context) ([]Channel, error) { return nil, nil },
}
channelSvc := newTestChannelService(repo)
svc := &GatewayService{channelService: channelSvc}
gid := int64(999)
require.False(t, svc.checkChannelPricingRestriction(context.Background(), &gid, "any-model"),
"no channel for group → allowed")
}
// --- isUpstreamModelRestrictedByChannel ---
func TestIsUpstreamModelRestrictedByChannel_Restricted(t *testing.T) {
t.Parallel()
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
svc := &GatewayService{channelService: channelSvc}
account := &Account{Platform: PlatformAntigravity}
// claude-sonnet-4-6 在 DefaultAntigravityModelMapping 中,映射后仍为 claude-sonnet-4-6
// 但定价列表只有 claude-opus-4-6
require.True(t, svc.isUpstreamModelRestrictedByChannel(context.Background(), 10, account, "claude-sonnet-4-6"),
"upstream model claude-sonnet-4-6 NOT in pricing → restricted")
}
func TestIsUpstreamModelRestrictedByChannel_Allowed(t *testing.T) {
t.Parallel()
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
svc := &GatewayService{channelService: channelSvc}
account := &Account{Platform: PlatformAntigravity}
require.False(t, svc.isUpstreamModelRestrictedByChannel(context.Background(), 10, account, "claude-sonnet-4-6"),
"upstream model claude-sonnet-4-6 IS in pricing → allowed")
}
func TestIsUpstreamModelRestrictedByChannel_UnsupportedModel(t *testing.T) {
t.Parallel()
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
ModelPricing: []ChannelModelPricing{
{Platform: "anthropic", Models: []string{"claude-opus-4-6"}},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{10: "anthropic"}))
svc := &GatewayService{channelService: channelSvc}
account := &Account{Platform: PlatformAntigravity}
// totally-unknown-model 不在 DefaultAntigravityModelMapping 中 → 映射结果为空
require.False(t, svc.isUpstreamModelRestrictedByChannel(context.Background(), 10, account, "totally-unknown-model"),
"unmappable model → upstream model empty → not restricted (account filter handles this)")
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment