Unverified Commit 47cd1c52 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #1467 from touwaeriol/refactor/channel-service-cleanup

refactor(channel): split long functions, extract shared validation, move billing validation to service
parents 06e2756e 9151d34d
package admin
import (
"errors"
"fmt"
"strconv"
"strings"
......@@ -235,61 +233,6 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
return result
}
// validatePricingBillingMode 校验计费配置
func validatePricingBillingMode(pricing []service.ChannelModelPricing) error {
for _, p := range pricing {
// 按次/图片模式必须配置默认价格或区间
if p.BillingMode == service.BillingModePerRequest || p.BillingMode == service.BillingModeImage {
if p.PerRequestPrice == nil && len(p.Intervals) == 0 {
return errors.New("per-request price or intervals required for per_request/image billing mode")
}
}
// 校验价格不能为负
if err := validatePriceNotNegative("input_price", p.InputPrice); err != nil {
return err
}
if err := validatePriceNotNegative("output_price", p.OutputPrice); err != nil {
return err
}
if err := validatePriceNotNegative("cache_write_price", p.CacheWritePrice); err != nil {
return err
}
if err := validatePriceNotNegative("cache_read_price", p.CacheReadPrice); err != nil {
return err
}
if err := validatePriceNotNegative("image_output_price", p.ImageOutputPrice); err != nil {
return err
}
if err := validatePriceNotNegative("per_request_price", p.PerRequestPrice); err != nil {
return err
}
// 校验 interval:至少有一个价格字段非空
for _, iv := range p.Intervals {
if iv.InputPrice == nil && iv.OutputPrice == nil &&
iv.CacheWritePrice == nil && iv.CacheReadPrice == nil &&
iv.PerRequestPrice == nil {
return fmt.Errorf("interval [%d, %s] has no price fields set for model %v",
iv.MinTokens, formatMaxTokens(iv.MaxTokens), p.Models)
}
}
}
return nil
}
func validatePriceNotNegative(field string, val *float64) error {
if val != nil && *val < 0 {
return fmt.Errorf("%s must be >= 0", field)
}
return nil
}
func formatMaxTokens(max *int) string {
if max == nil {
return "∞"
}
return fmt.Sprintf("%d", *max)
}
// --- Handlers ---
// List handles listing channels with pagination
......@@ -343,10 +286,6 @@ func (h *ChannelHandler) Create(c *gin.Context) {
}
pricing := pricingRequestToService(req.ModelPricing)
if err := validatePricingBillingMode(pricing); err != nil {
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
return
}
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
Name: req.Name,
......@@ -391,10 +330,6 @@ func (h *ChannelHandler) Update(c *gin.Context) {
}
if req.ModelPricing != nil {
pricing := pricingRequestToService(*req.ModelPricing)
if err := validatePricingBillingMode(pricing); err != nil {
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
return
}
input.ModelPricing = &pricing
}
......
......@@ -400,103 +400,3 @@ func TestPricingRequestToService_NilPriceFields(t *testing.T) {
require.Nil(t, r.ImageOutputPrice)
require.Nil(t, r.PerRequestPrice)
}
// ---------------------------------------------------------------------------
// 3. validatePricingBillingMode
// ---------------------------------------------------------------------------
func TestValidatePricingBillingMode(t *testing.T) {
tests := []struct {
name string
pricing []service.ChannelModelPricing
wantErr bool
}{
{
name: "token mode - valid",
pricing: []service.ChannelModelPricing{
{BillingMode: service.BillingModeToken},
},
wantErr: false,
},
{
name: "per_request with price - valid",
pricing: []service.ChannelModelPricing{
{
BillingMode: service.BillingModePerRequest,
PerRequestPrice: float64Ptr(0.5),
},
},
wantErr: false,
},
{
name: "per_request with intervals - valid",
pricing: []service.ChannelModelPricing{
{
BillingMode: service.BillingModePerRequest,
Intervals: []service.PricingInterval{
{MinTokens: 0, MaxTokens: intPtr(1000), PerRequestPrice: float64Ptr(0.1)},
},
},
},
wantErr: false,
},
{
name: "per_request no price no intervals - invalid",
pricing: []service.ChannelModelPricing{
{BillingMode: service.BillingModePerRequest},
},
wantErr: true,
},
{
name: "image with price - valid",
pricing: []service.ChannelModelPricing{
{
BillingMode: service.BillingModeImage,
PerRequestPrice: float64Ptr(0.2),
},
},
wantErr: false,
},
{
name: "image no price no intervals - invalid",
pricing: []service.ChannelModelPricing{
{BillingMode: service.BillingModeImage},
},
wantErr: true,
},
{
name: "empty list - valid",
pricing: []service.ChannelModelPricing{},
wantErr: false,
},
{
name: "mixed modes with invalid image - invalid",
pricing: []service.ChannelModelPricing{
{
BillingMode: service.BillingModeToken,
InputPrice: float64Ptr(0.01),
},
{
BillingMode: service.BillingModePerRequest,
PerRequestPrice: float64Ptr(0.5),
},
{
BillingMode: service.BillingModeImage,
},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validatePricingBillingMode(tt.pricing)
if tt.wantErr {
require.Error(t, err)
require.Contains(t, err.Error(), "per-request price or intervals required")
} else {
require.NoError(t, err)
}
})
}
}
......@@ -248,40 +248,58 @@ func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform
}
}
// storeErrorCache 存入短 TTL 空缓存,防止 DB 错误后紧密重试。
// 通过回退 loadedAt 使剩余 TTL = channelErrorTTL。
func (s *ChannelService) storeErrorCache() {
errorCache := newEmptyChannelCache()
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL))
s.cache.Store(errorCache)
}
// buildCache 从数据库构建渠道缓存。
// 使用独立 context 避免请求取消导致空值被长期缓存。
func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) {
// 断开请求取消链,避免客户端断连导致空值被长期缓存
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), channelCacheDBTimeout)
defer cancel()
channels, err := s.repo.ListAll(dbCtx)
channels, groupPlatforms, err := s.fetchChannelData(dbCtx)
if err != nil {
return nil, err
}
cache := populateChannelCache(channels, groupPlatforms)
s.cache.Store(cache)
return cache, nil
}
// fetchChannelData 从数据库加载渠道列表和分组平台映射。
func (s *ChannelService) fetchChannelData(ctx context.Context) ([]Channel, map[int64]string, error) {
channels, err := s.repo.ListAll(ctx)
if err != nil {
// error-TTL:失败时存入短 TTL 空缓存,防止紧密重试
slog.Warn("failed to build channel cache", "error", err)
errorCache := newEmptyChannelCache()
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) // 使剩余 TTL = errorTTL
s.cache.Store(errorCache)
return nil, fmt.Errorf("list all channels: %w", err)
s.storeErrorCache()
return nil, nil, fmt.Errorf("list all channels: %w", err)
}
// 收集所有 groupID,批量查询 platform
var allGroupIDs []int64
for i := range channels {
allGroupIDs = append(allGroupIDs, channels[i].GroupIDs...)
}
groupPlatforms := make(map[int64]string)
if len(allGroupIDs) > 0 {
groupPlatforms, err = s.repo.GetGroupPlatforms(dbCtx, allGroupIDs)
groupPlatforms, err = s.repo.GetGroupPlatforms(ctx, allGroupIDs)
if err != nil {
slog.Warn("failed to load group platforms for channel cache", "error", err)
errorCache := newEmptyChannelCache()
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL))
s.cache.Store(errorCache)
return nil, fmt.Errorf("get group platforms: %w", err)
s.storeErrorCache()
return nil, nil, fmt.Errorf("get group platforms: %w", err)
}
}
return channels, groupPlatforms, nil
}
// populateChannelCache 将渠道列表和分组平台映射填充到缓存快照中。
func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *channelCache {
cache := newEmptyChannelCache()
cache.groupPlatform = groupPlatforms
cache.byID = make(map[int64]*Channel, len(channels))
......@@ -290,7 +308,6 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
for i := range channels {
ch := &channels[i]
cache.byID[ch.ID] = ch
for _, gid := range ch.GroupIDs {
cache.channelByGroupID[gid] = ch
platform := groupPlatforms[gid]
......@@ -298,11 +315,7 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
expandMappingToCache(cache, ch, gid, platform)
}
}
// 通配符条目保持配置顺序(最先匹配到优先)
s.cache.Store(cache)
return cache, nil
return cache
}
// invalidateCache 使缓存失效,让下次读取时自然重建
......@@ -466,7 +479,10 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6
// 返回 true 表示模型被限制(不在允许列表中)。
// 如果渠道未启用模型限制或分组无渠道关联,返回 false。
func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
lk, _ := s.lookupGroupChannel(ctx, groupID)
lk, err := s.lookupGroupChannel(ctx, groupID)
if err != nil {
slog.Warn("failed to load channel cache for model restriction check", "group_id", groupID, "error", err)
}
if lk == nil {
return false
}
......@@ -537,6 +553,91 @@ func ReplaceModelInBody(body []byte, newModel string) []byte {
return newBody
}
// validateChannelConfig 校验渠道的定价和映射配置(冲突检测 + 区间校验 + 计费模式校验)。
// Create 和 Update 共用此函数,避免重复。
func validateChannelConfig(pricing []ChannelModelPricing, mapping map[string]map[string]string) error {
if err := validateNoConflictingModels(pricing); err != nil {
return err
}
if err := validatePricingIntervals(pricing); err != nil {
return err
}
if err := validateNoConflictingMappings(mapping); err != nil {
return err
}
return validatePricingBillingMode(pricing)
}
// validatePricingBillingMode 校验计费模式配置:按次/图片模式必须配价格或区间,所有价格字段不能为负,区间至少有一个价格字段。
func validatePricingBillingMode(pricing []ChannelModelPricing) error {
for _, p := range pricing {
if err := checkBillingModeRequirements(p); err != nil {
return err
}
if err := checkPricesNotNegative(p); err != nil {
return err
}
if err := checkIntervalsHavePrices(p); err != nil {
return err
}
}
return nil
}
func checkBillingModeRequirements(p ChannelModelPricing) error {
if p.BillingMode == BillingModePerRequest || p.BillingMode == BillingModeImage {
if p.PerRequestPrice == nil && len(p.Intervals) == 0 {
return infraerrors.BadRequest(
"BILLING_MODE_MISSING_PRICE",
"per-request price or intervals required for per_request/image billing mode",
)
}
}
return nil
}
func checkPricesNotNegative(p ChannelModelPricing) error {
checks := []struct {
field string
val *float64
}{
{"input_price", p.InputPrice},
{"output_price", p.OutputPrice},
{"cache_write_price", p.CacheWritePrice},
{"cache_read_price", p.CacheReadPrice},
{"image_output_price", p.ImageOutputPrice},
{"per_request_price", p.PerRequestPrice},
}
for _, c := range checks {
if c.val != nil && *c.val < 0 {
return infraerrors.BadRequest("NEGATIVE_PRICE", fmt.Sprintf("%s must be >= 0", c.field))
}
}
return nil
}
func checkIntervalsHavePrices(p ChannelModelPricing) error {
for _, iv := range p.Intervals {
if iv.InputPrice == nil && iv.OutputPrice == nil &&
iv.CacheWritePrice == nil && iv.CacheReadPrice == nil &&
iv.PerRequestPrice == nil {
return infraerrors.BadRequest(
"INTERVAL_MISSING_PRICE",
fmt.Sprintf("interval [%d, %s] has no price fields set for model %v",
iv.MinTokens, formatMaxTokens(iv.MaxTokens), p.Models),
)
}
}
return nil
}
func formatMaxTokens(max *int) string {
if max == nil {
return "∞"
}
return fmt.Sprintf("%d", *max)
}
// --- CRUD ---
// Create 创建渠道
......@@ -549,15 +650,8 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
return nil, ErrChannelExists
}
// 检查分组冲突
if len(input.GroupIDs) > 0 {
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, 0, input.GroupIDs)
if err != nil {
return nil, fmt.Errorf("check group conflicts: %w", err)
}
if len(conflicting) > 0 {
return nil, ErrGroupAlreadyInChannel
}
if err := s.checkGroupConflicts(ctx, 0, input.GroupIDs); err != nil {
return nil, err
}
channel := &Channel{
......@@ -574,13 +668,7 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
channel.BillingModelSource = BillingModelSourceChannelMapped
}
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
return nil, err
}
if err := validatePricingIntervals(channel.ModelPricing); err != nil {
return nil, err
}
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
return nil, err
}
......@@ -604,102 +692,112 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
return nil, fmt.Errorf("get channel: %w", err)
}
if err := s.applyUpdateInput(ctx, channel, input); err != nil {
return nil, err
}
if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
return nil, err
}
oldGroupIDs := s.getOldGroupIDs(ctx, id)
if err := s.repo.Update(ctx, channel); err != nil {
return nil, fmt.Errorf("update channel: %w", err)
}
s.invalidateCache()
s.invalidateAuthCacheForGroups(ctx, oldGroupIDs, channel.GroupIDs)
return s.repo.GetByID(ctx, id)
}
// applyUpdateInput 将更新请求的字段应用到渠道实体上。
func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel, input *UpdateChannelInput) error {
if input.Name != "" && input.Name != channel.Name {
exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, id)
exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, channel.ID)
if err != nil {
return nil, fmt.Errorf("check channel exists: %w", err)
return fmt.Errorf("check channel exists: %w", err)
}
if exists {
return nil, ErrChannelExists
return ErrChannelExists
}
channel.Name = input.Name
}
if input.Description != nil {
channel.Description = *input.Description
}
if input.Status != "" {
channel.Status = input.Status
}
if input.RestrictModels != nil {
channel.RestrictModels = *input.RestrictModels
}
// 检查分组冲突
if input.GroupIDs != nil {
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, id, *input.GroupIDs)
if err != nil {
return nil, fmt.Errorf("check group conflicts: %w", err)
}
if len(conflicting) > 0 {
return nil, ErrGroupAlreadyInChannel
if err := s.checkGroupConflicts(ctx, channel.ID, *input.GroupIDs); err != nil {
return err
}
channel.GroupIDs = *input.GroupIDs
}
if input.ModelPricing != nil {
channel.ModelPricing = *input.ModelPricing
}
if input.ModelMapping != nil {
channel.ModelMapping = input.ModelMapping
}
if input.BillingModelSource != "" {
channel.BillingModelSource = input.BillingModelSource
}
return nil
}
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
return nil, err
// checkGroupConflicts 检查待关联的分组是否已属于其他渠道。
// channelID 为当前渠道 ID(Create 时传 0)。
func (s *ChannelService) checkGroupConflicts(ctx context.Context, channelID int64, groupIDs []int64) error {
if len(groupIDs) == 0 {
return nil
}
if err := validatePricingIntervals(channel.ModelPricing); err != nil {
return nil, err
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, channelID, groupIDs)
if err != nil {
return fmt.Errorf("check group conflicts: %w", err)
}
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
return nil, err
if len(conflicting) > 0 {
return ErrGroupAlreadyInChannel
}
return nil
}
// 先获取旧分组,Update 后旧分组关联已删除,无法再查到
var oldGroupIDs []int64
if s.authCacheInvalidator != nil {
var err2 error
oldGroupIDs, err2 = s.repo.GetGroupIDs(ctx, id)
if err2 != nil {
slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", id, "error", err2)
}
// getOldGroupIDs 获取渠道更新前的关联分组 ID(用于失效 auth 缓存)。
func (s *ChannelService) getOldGroupIDs(ctx context.Context, channelID int64) []int64 {
if s.authCacheInvalidator == nil {
return nil
}
if err := s.repo.Update(ctx, channel); err != nil {
return nil, fmt.Errorf("update channel: %w", err)
oldGroupIDs, err := s.repo.GetGroupIDs(ctx, channelID)
if err != nil {
slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", channelID, "error", err)
}
return oldGroupIDs
}
s.invalidateCache()
// 失效新旧分组的 auth 缓存
if s.authCacheInvalidator != nil {
seen := make(map[int64]struct{}, len(oldGroupIDs)+len(channel.GroupIDs))
for _, gid := range oldGroupIDs {
if _, ok := seen[gid]; !ok {
seen[gid] = struct{}{}
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
// invalidateAuthCacheForGroups 对新旧分组去重后逐个失效 auth 缓存。
func (s *ChannelService) invalidateAuthCacheForGroups(ctx context.Context, groupIDSets ...[]int64) {
if s.authCacheInvalidator == nil {
return
}
seen := make(map[int64]struct{})
for _, ids := range groupIDSets {
for _, gid := range ids {
if _, ok := seen[gid]; ok {
continue
}
for _, gid := range channel.GroupIDs {
if _, ok := seen[gid]; !ok {
seen[gid] = struct{}{}
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
}
}
}
return s.repo.GetByID(ctx, id)
}
// Delete 删除渠道
func (s *ChannelService) Delete(ctx context.Context, id int64) error {
// 先获取关联分组用于失效缓存
groupIDs, err := s.repo.GetGroupIDs(ctx, id)
if err != nil {
slog.Warn("failed to get group IDs before delete", "channel_id", id, "error", err)
......@@ -710,12 +808,7 @@ func (s *ChannelService) Delete(ctx context.Context, id int64) error {
}
s.invalidateCache()
if s.authCacheInvalidator != nil {
for _, gid := range groupIDs {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
}
}
s.invalidateAuthCacheForGroups(ctx, groupIDs)
return nil
}
......
......@@ -2199,3 +2199,207 @@ func TestGetChannelModelPricing_NonAntigravityUnaffected(t *testing.T) {
require.Equal(t, int64(601), result.ID)
require.InDelta(t, 5e-6, *result.InputPrice, 1e-12)
}
// ---------------------------------------------------------------------------
// 10. ToUsageFields
// ---------------------------------------------------------------------------
func TestToUsageFields_NoMapping(t *testing.T) {
r := ChannelMappingResult{
MappedModel: "claude-opus-4",
ChannelID: 1,
Mapped: false,
BillingModelSource: BillingModelSourceRequested,
}
fields := r.ToUsageFields("claude-opus-4", "claude-opus-4")
require.Equal(t, int64(1), fields.ChannelID)
require.Equal(t, "claude-opus-4", fields.OriginalModel)
require.Equal(t, "claude-opus-4", fields.ChannelMappedModel)
require.Equal(t, BillingModelSourceRequested, fields.BillingModelSource)
require.Empty(t, fields.ModelMappingChain)
}
func TestToUsageFields_WithChannelMapping(t *testing.T) {
r := ChannelMappingResult{
MappedModel: "claude-sonnet-4-20250514",
ChannelID: 2,
Mapped: true,
BillingModelSource: BillingModelSourceChannelMapped,
}
fields := r.ToUsageFields("claude-sonnet-4", "claude-sonnet-4-20250514")
require.Equal(t, int64(2), fields.ChannelID)
require.Equal(t, "claude-sonnet-4", fields.OriginalModel)
require.Equal(t, "claude-sonnet-4-20250514", fields.ChannelMappedModel)
require.Equal(t, "claude-sonnet-4→claude-sonnet-4-20250514", fields.ModelMappingChain)
}
func TestToUsageFields_WithUpstreamDifference(t *testing.T) {
r := ChannelMappingResult{
MappedModel: "claude-sonnet-4",
ChannelID: 3,
Mapped: true,
BillingModelSource: BillingModelSourceUpstream,
}
fields := r.ToUsageFields("my-alias", "claude-sonnet-4-20250514")
require.Equal(t, "my-alias", fields.OriginalModel)
require.Equal(t, "claude-sonnet-4", fields.ChannelMappedModel)
require.Equal(t, "my-alias→claude-sonnet-4→claude-sonnet-4-20250514", fields.ModelMappingChain)
}
// ---------------------------------------------------------------------------
// 11. validatePricingBillingMode (moved from handler tests)
// ---------------------------------------------------------------------------
func TestValidatePricingBillingMode(t *testing.T) {
tests := []struct {
name string
pricing []ChannelModelPricing
wantErr bool
errMsg string
}{
{
name: "token mode - valid",
pricing: []ChannelModelPricing{{BillingMode: BillingModeToken}},
},
{
name: "per_request with price - valid",
pricing: []ChannelModelPricing{{
BillingMode: BillingModePerRequest,
PerRequestPrice: testPtrFloat64(0.5),
}},
},
{
name: "per_request with intervals - valid",
pricing: []ChannelModelPricing{{
BillingMode: BillingModePerRequest,
Intervals: []PricingInterval{{MinTokens: 0, MaxTokens: testPtrInt(1000), PerRequestPrice: testPtrFloat64(0.1)}},
}},
},
{
name: "per_request no price no intervals - invalid",
pricing: []ChannelModelPricing{{BillingMode: BillingModePerRequest}},
wantErr: true,
errMsg: "per-request price or intervals required",
},
{
name: "image no price no intervals - invalid",
pricing: []ChannelModelPricing{{BillingMode: BillingModeImage}},
wantErr: true,
errMsg: "per-request price or intervals required",
},
{
name: "empty list - valid",
pricing: []ChannelModelPricing{},
},
{
name: "negative input_price - invalid",
pricing: []ChannelModelPricing{{
BillingMode: BillingModeToken,
InputPrice: testPtrFloat64(-0.01),
}},
wantErr: true,
errMsg: "input_price must be >= 0",
},
{
name: "interval with no price fields - invalid",
pricing: []ChannelModelPricing{{
BillingMode: BillingModePerRequest,
PerRequestPrice: testPtrFloat64(0.5),
Intervals: []PricingInterval{{MinTokens: 0, MaxTokens: testPtrInt(1000)}},
}},
wantErr: true,
errMsg: "has no price fields set",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validatePricingBillingMode(tt.pricing)
if tt.wantErr {
require.Error(t, err)
require.Contains(t, err.Error(), tt.errMsg)
} else {
require.NoError(t, err)
}
})
}
}
// ---------------------------------------------------------------------------
// 12. Antigravity wildcard mapping isolation
// ---------------------------------------------------------------------------
func TestResolveChannelMapping_AntigravityDoesNotSeeWildcardMappingFromOtherPlatforms(t *testing.T) {
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10, 20},
ModelMapping: map[string]map[string]string{
PlatformAnthropic: {"claude-*": "claude-override"},
PlatformGemini: {"gemini-*": "gemini-override"},
},
}
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity, 20: PlatformAnthropic})
svc := newTestChannelService(repo)
// antigravity 分组不应看到 anthropic/gemini 的通配符映射
result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4")
require.False(t, result.Mapped)
require.Equal(t, "claude-opus-4", result.MappedModel)
result = svc.ResolveChannelMapping(context.Background(), 10, "gemini-2.5-pro")
require.False(t, result.Mapped)
require.Equal(t, "gemini-2.5-pro", result.MappedModel)
// anthropic 分组应该能看到 anthropic 的通配符映射
result = svc.ResolveChannelMapping(context.Background(), 20, "claude-opus-4")
require.True(t, result.Mapped)
require.Equal(t, "claude-override", result.MappedModel)
}
// ---------------------------------------------------------------------------
// 13. Create/Update with mapping conflict validation
// ---------------------------------------------------------------------------
func TestCreate_MappingConflict(t *testing.T) {
repo := &mockChannelRepository{}
svc := newTestChannelService(repo)
_, err := svc.Create(context.Background(), &CreateChannelInput{
Name: "test",
ModelMapping: map[string]map[string]string{
PlatformAnthropic: {
"claude-*": "target-a",
"claude-opus-*": "target-b",
},
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "MAPPING_PATTERN_CONFLICT")
}
func TestUpdate_MappingConflict(t *testing.T) {
existingChannel := &Channel{
ID: 1,
Name: "existing",
Status: StatusActive,
}
repo := &mockChannelRepository{
getByIDFn: func(_ context.Context, _ int64) (*Channel, error) {
return existingChannel, nil
},
}
svc := newTestChannelService(repo)
conflictMapping := map[string]map[string]string{
PlatformAnthropic: {
"claude-*": "target-a",
"claude-opus-*": "target-b",
},
}
_, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
ModelMapping: conflictMapping,
})
require.Error(t, err)
require.Contains(t, err.Error(), "MAPPING_PATTERN_CONFLICT")
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment