Commit dca85c86 authored by erio's avatar erio Committed by 陈曦
Browse files

refactor(channel): split long functions, extract shared validation, move...

refactor(channel): split long functions, extract shared validation, move billing validation to service

- Split Update (98→25 lines), buildCache (54→20 lines), Create (51→25 lines)
  into focused sub-functions: applyUpdateInput, checkGroupConflicts,
  fetchChannelData, populateChannelCache, storeErrorCache, getOldGroupIDs,
  invalidateAuthCacheForGroups
- Extract validateChannelConfig to eliminate duplicated validation calls
  between Create and Update
- Move validatePricingBillingMode from handler to service layer for
  proper separation of concerns
- Add error logging to IsModelRestricted (was silently swallowing errors)
- Add 12 new tests: ToUsageFields, billing mode validation, antigravity
  wildcard mapping isolation, Create/Update mapping conflict integration
parent 70836c70
package admin package admin
import ( import (
"errors"
"fmt"
"strconv" "strconv"
"strings" "strings"
...@@ -235,61 +233,6 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe ...@@ -235,61 +233,6 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
return result return result
} }
// validatePricingBillingMode 校验计费配置
func validatePricingBillingMode(pricing []service.ChannelModelPricing) error {
for _, p := range pricing {
// 按次/图片模式必须配置默认价格或区间
if p.BillingMode == service.BillingModePerRequest || p.BillingMode == service.BillingModeImage {
if p.PerRequestPrice == nil && len(p.Intervals) == 0 {
return errors.New("per-request price or intervals required for per_request/image billing mode")
}
}
// 校验价格不能为负
if err := validatePriceNotNegative("input_price", p.InputPrice); err != nil {
return err
}
if err := validatePriceNotNegative("output_price", p.OutputPrice); err != nil {
return err
}
if err := validatePriceNotNegative("cache_write_price", p.CacheWritePrice); err != nil {
return err
}
if err := validatePriceNotNegative("cache_read_price", p.CacheReadPrice); err != nil {
return err
}
if err := validatePriceNotNegative("image_output_price", p.ImageOutputPrice); err != nil {
return err
}
if err := validatePriceNotNegative("per_request_price", p.PerRequestPrice); err != nil {
return err
}
// 校验 interval:至少有一个价格字段非空
for _, iv := range p.Intervals {
if iv.InputPrice == nil && iv.OutputPrice == nil &&
iv.CacheWritePrice == nil && iv.CacheReadPrice == nil &&
iv.PerRequestPrice == nil {
return fmt.Errorf("interval [%d, %s] has no price fields set for model %v",
iv.MinTokens, formatMaxTokens(iv.MaxTokens), p.Models)
}
}
}
return nil
}
func validatePriceNotNegative(field string, val *float64) error {
if val != nil && *val < 0 {
return fmt.Errorf("%s must be >= 0", field)
}
return nil
}
func formatMaxTokens(max *int) string {
if max == nil {
return "∞"
}
return fmt.Sprintf("%d", *max)
}
// --- Handlers --- // --- Handlers ---
// List handles listing channels with pagination // List handles listing channels with pagination
...@@ -343,10 +286,6 @@ func (h *ChannelHandler) Create(c *gin.Context) { ...@@ -343,10 +286,6 @@ func (h *ChannelHandler) Create(c *gin.Context) {
} }
pricing := pricingRequestToService(req.ModelPricing) pricing := pricingRequestToService(req.ModelPricing)
if err := validatePricingBillingMode(pricing); err != nil {
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
return
}
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{ channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
Name: req.Name, Name: req.Name,
...@@ -391,10 +330,6 @@ func (h *ChannelHandler) Update(c *gin.Context) { ...@@ -391,10 +330,6 @@ func (h *ChannelHandler) Update(c *gin.Context) {
} }
if req.ModelPricing != nil { if req.ModelPricing != nil {
pricing := pricingRequestToService(*req.ModelPricing) pricing := pricingRequestToService(*req.ModelPricing)
if err := validatePricingBillingMode(pricing); err != nil {
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
return
}
input.ModelPricing = &pricing input.ModelPricing = &pricing
} }
......
...@@ -400,103 +400,3 @@ func TestPricingRequestToService_NilPriceFields(t *testing.T) { ...@@ -400,103 +400,3 @@ func TestPricingRequestToService_NilPriceFields(t *testing.T) {
require.Nil(t, r.ImageOutputPrice) require.Nil(t, r.ImageOutputPrice)
require.Nil(t, r.PerRequestPrice) require.Nil(t, r.PerRequestPrice)
} }
// ---------------------------------------------------------------------------
// 3. validatePricingBillingMode
// ---------------------------------------------------------------------------
func TestValidatePricingBillingMode(t *testing.T) {
tests := []struct {
name string
pricing []service.ChannelModelPricing
wantErr bool
}{
{
name: "token mode - valid",
pricing: []service.ChannelModelPricing{
{BillingMode: service.BillingModeToken},
},
wantErr: false,
},
{
name: "per_request with price - valid",
pricing: []service.ChannelModelPricing{
{
BillingMode: service.BillingModePerRequest,
PerRequestPrice: float64Ptr(0.5),
},
},
wantErr: false,
},
{
name: "per_request with intervals - valid",
pricing: []service.ChannelModelPricing{
{
BillingMode: service.BillingModePerRequest,
Intervals: []service.PricingInterval{
{MinTokens: 0, MaxTokens: intPtr(1000), PerRequestPrice: float64Ptr(0.1)},
},
},
},
wantErr: false,
},
{
name: "per_request no price no intervals - invalid",
pricing: []service.ChannelModelPricing{
{BillingMode: service.BillingModePerRequest},
},
wantErr: true,
},
{
name: "image with price - valid",
pricing: []service.ChannelModelPricing{
{
BillingMode: service.BillingModeImage,
PerRequestPrice: float64Ptr(0.2),
},
},
wantErr: false,
},
{
name: "image no price no intervals - invalid",
pricing: []service.ChannelModelPricing{
{BillingMode: service.BillingModeImage},
},
wantErr: true,
},
{
name: "empty list - valid",
pricing: []service.ChannelModelPricing{},
wantErr: false,
},
{
name: "mixed modes with invalid image - invalid",
pricing: []service.ChannelModelPricing{
{
BillingMode: service.BillingModeToken,
InputPrice: float64Ptr(0.01),
},
{
BillingMode: service.BillingModePerRequest,
PerRequestPrice: float64Ptr(0.5),
},
{
BillingMode: service.BillingModeImage,
},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validatePricingBillingMode(tt.pricing)
if tt.wantErr {
require.Error(t, err)
require.Contains(t, err.Error(), "per-request price or intervals required")
} else {
require.NoError(t, err)
}
})
}
}
...@@ -248,40 +248,58 @@ func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform ...@@ -248,40 +248,58 @@ func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform
} }
} }
// storeErrorCache 存入短 TTL 空缓存,防止 DB 错误后紧密重试。
// 通过回退 loadedAt 使剩余 TTL = channelErrorTTL。
func (s *ChannelService) storeErrorCache() {
errorCache := newEmptyChannelCache()
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL))
s.cache.Store(errorCache)
}
// buildCache 从数据库构建渠道缓存。 // buildCache 从数据库构建渠道缓存。
// 使用独立 context 避免请求取消导致空值被长期缓存。 // 使用独立 context 避免请求取消导致空值被长期缓存。
func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) { func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) {
// 断开请求取消链,避免客户端断连导致空值被长期缓存
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), channelCacheDBTimeout) dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), channelCacheDBTimeout)
defer cancel() defer cancel()
channels, err := s.repo.ListAll(dbCtx) channels, groupPlatforms, err := s.fetchChannelData(dbCtx)
if err != nil {
return nil, err
}
cache := populateChannelCache(channels, groupPlatforms)
s.cache.Store(cache)
return cache, nil
}
// fetchChannelData 从数据库加载渠道列表和分组平台映射。
func (s *ChannelService) fetchChannelData(ctx context.Context) ([]Channel, map[int64]string, error) {
channels, err := s.repo.ListAll(ctx)
if err != nil { if err != nil {
// error-TTL:失败时存入短 TTL 空缓存,防止紧密重试
slog.Warn("failed to build channel cache", "error", err) slog.Warn("failed to build channel cache", "error", err)
errorCache := newEmptyChannelCache() s.storeErrorCache()
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) // 使剩余 TTL = errorTTL return nil, nil, fmt.Errorf("list all channels: %w", err)
s.cache.Store(errorCache)
return nil, fmt.Errorf("list all channels: %w", err)
} }
// 收集所有 groupID,批量查询 platform
var allGroupIDs []int64 var allGroupIDs []int64
for i := range channels { for i := range channels {
allGroupIDs = append(allGroupIDs, channels[i].GroupIDs...) allGroupIDs = append(allGroupIDs, channels[i].GroupIDs...)
} }
groupPlatforms := make(map[int64]string) groupPlatforms := make(map[int64]string)
if len(allGroupIDs) > 0 { if len(allGroupIDs) > 0 {
groupPlatforms, err = s.repo.GetGroupPlatforms(dbCtx, allGroupIDs) groupPlatforms, err = s.repo.GetGroupPlatforms(ctx, allGroupIDs)
if err != nil { if err != nil {
slog.Warn("failed to load group platforms for channel cache", "error", err) slog.Warn("failed to load group platforms for channel cache", "error", err)
errorCache := newEmptyChannelCache() s.storeErrorCache()
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) return nil, nil, fmt.Errorf("get group platforms: %w", err)
s.cache.Store(errorCache)
return nil, fmt.Errorf("get group platforms: %w", err)
} }
} }
return channels, groupPlatforms, nil
}
// populateChannelCache 将渠道列表和分组平台映射填充到缓存快照中。
func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *channelCache {
cache := newEmptyChannelCache() cache := newEmptyChannelCache()
cache.groupPlatform = groupPlatforms cache.groupPlatform = groupPlatforms
cache.byID = make(map[int64]*Channel, len(channels)) cache.byID = make(map[int64]*Channel, len(channels))
...@@ -290,7 +308,6 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) ...@@ -290,7 +308,6 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
for i := range channels { for i := range channels {
ch := &channels[i] ch := &channels[i]
cache.byID[ch.ID] = ch cache.byID[ch.ID] = ch
for _, gid := range ch.GroupIDs { for _, gid := range ch.GroupIDs {
cache.channelByGroupID[gid] = ch cache.channelByGroupID[gid] = ch
platform := groupPlatforms[gid] platform := groupPlatforms[gid]
...@@ -298,11 +315,7 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) ...@@ -298,11 +315,7 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
expandMappingToCache(cache, ch, gid, platform) expandMappingToCache(cache, ch, gid, platform)
} }
} }
return cache
// 通配符条目保持配置顺序(最先匹配到优先)
s.cache.Store(cache)
return cache, nil
} }
// invalidateCache 使缓存失效,让下次读取时自然重建 // invalidateCache 使缓存失效,让下次读取时自然重建
...@@ -466,7 +479,10 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6 ...@@ -466,7 +479,10 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6
// 返回 true 表示模型被限制(不在允许列表中)。 // 返回 true 表示模型被限制(不在允许列表中)。
// 如果渠道未启用模型限制或分组无渠道关联,返回 false。 // 如果渠道未启用模型限制或分组无渠道关联,返回 false。
func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool { func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
lk, _ := s.lookupGroupChannel(ctx, groupID) lk, err := s.lookupGroupChannel(ctx, groupID)
if err != nil {
slog.Warn("failed to load channel cache for model restriction check", "group_id", groupID, "error", err)
}
if lk == nil { if lk == nil {
return false return false
} }
...@@ -537,6 +553,91 @@ func ReplaceModelInBody(body []byte, newModel string) []byte { ...@@ -537,6 +553,91 @@ func ReplaceModelInBody(body []byte, newModel string) []byte {
return newBody return newBody
} }
// validateChannelConfig 校验渠道的定价和映射配置(冲突检测 + 区间校验 + 计费模式校验)。
// Create 和 Update 共用此函数,避免重复。
func validateChannelConfig(pricing []ChannelModelPricing, mapping map[string]map[string]string) error {
if err := validateNoConflictingModels(pricing); err != nil {
return err
}
if err := validatePricingIntervals(pricing); err != nil {
return err
}
if err := validateNoConflictingMappings(mapping); err != nil {
return err
}
return validatePricingBillingMode(pricing)
}
// validatePricingBillingMode 校验计费模式配置:按次/图片模式必须配价格或区间,所有价格字段不能为负,区间至少有一个价格字段。
func validatePricingBillingMode(pricing []ChannelModelPricing) error {
for _, p := range pricing {
if err := checkBillingModeRequirements(p); err != nil {
return err
}
if err := checkPricesNotNegative(p); err != nil {
return err
}
if err := checkIntervalsHavePrices(p); err != nil {
return err
}
}
return nil
}
func checkBillingModeRequirements(p ChannelModelPricing) error {
if p.BillingMode == BillingModePerRequest || p.BillingMode == BillingModeImage {
if p.PerRequestPrice == nil && len(p.Intervals) == 0 {
return infraerrors.BadRequest(
"BILLING_MODE_MISSING_PRICE",
"per-request price or intervals required for per_request/image billing mode",
)
}
}
return nil
}
func checkPricesNotNegative(p ChannelModelPricing) error {
checks := []struct {
field string
val *float64
}{
{"input_price", p.InputPrice},
{"output_price", p.OutputPrice},
{"cache_write_price", p.CacheWritePrice},
{"cache_read_price", p.CacheReadPrice},
{"image_output_price", p.ImageOutputPrice},
{"per_request_price", p.PerRequestPrice},
}
for _, c := range checks {
if c.val != nil && *c.val < 0 {
return infraerrors.BadRequest("NEGATIVE_PRICE", fmt.Sprintf("%s must be >= 0", c.field))
}
}
return nil
}
func checkIntervalsHavePrices(p ChannelModelPricing) error {
for _, iv := range p.Intervals {
if iv.InputPrice == nil && iv.OutputPrice == nil &&
iv.CacheWritePrice == nil && iv.CacheReadPrice == nil &&
iv.PerRequestPrice == nil {
return infraerrors.BadRequest(
"INTERVAL_MISSING_PRICE",
fmt.Sprintf("interval [%d, %s] has no price fields set for model %v",
iv.MinTokens, formatMaxTokens(iv.MaxTokens), p.Models),
)
}
}
return nil
}
func formatMaxTokens(max *int) string {
if max == nil {
return "∞"
}
return fmt.Sprintf("%d", *max)
}
// --- CRUD --- // --- CRUD ---
// Create 创建渠道 // Create 创建渠道
...@@ -549,15 +650,8 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) ...@@ -549,15 +650,8 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
return nil, ErrChannelExists return nil, ErrChannelExists
} }
// 检查分组冲突 if err := s.checkGroupConflicts(ctx, 0, input.GroupIDs); err != nil {
if len(input.GroupIDs) > 0 { return nil, err
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, 0, input.GroupIDs)
if err != nil {
return nil, fmt.Errorf("check group conflicts: %w", err)
}
if len(conflicting) > 0 {
return nil, ErrGroupAlreadyInChannel
}
} }
channel := &Channel{ channel := &Channel{
...@@ -574,13 +668,7 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) ...@@ -574,13 +668,7 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
channel.BillingModelSource = BillingModelSourceChannelMapped channel.BillingModelSource = BillingModelSourceChannelMapped
} }
if err := validateNoConflictingModels(channel.ModelPricing); err != nil { if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
return nil, err
}
if err := validatePricingIntervals(channel.ModelPricing); err != nil {
return nil, err
}
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
return nil, err return nil, err
} }
...@@ -604,102 +692,112 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan ...@@ -604,102 +692,112 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
return nil, fmt.Errorf("get channel: %w", err) return nil, fmt.Errorf("get channel: %w", err)
} }
if err := s.applyUpdateInput(ctx, channel, input); err != nil {
return nil, err
}
if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
return nil, err
}
oldGroupIDs := s.getOldGroupIDs(ctx, id)
if err := s.repo.Update(ctx, channel); err != nil {
return nil, fmt.Errorf("update channel: %w", err)
}
s.invalidateCache()
s.invalidateAuthCacheForGroups(ctx, oldGroupIDs, channel.GroupIDs)
return s.repo.GetByID(ctx, id)
}
// applyUpdateInput 将更新请求的字段应用到渠道实体上。
func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel, input *UpdateChannelInput) error {
if input.Name != "" && input.Name != channel.Name { if input.Name != "" && input.Name != channel.Name {
exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, id) exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, channel.ID)
if err != nil { if err != nil {
return nil, fmt.Errorf("check channel exists: %w", err) return fmt.Errorf("check channel exists: %w", err)
} }
if exists { if exists {
return nil, ErrChannelExists return ErrChannelExists
} }
channel.Name = input.Name channel.Name = input.Name
} }
if input.Description != nil { if input.Description != nil {
channel.Description = *input.Description channel.Description = *input.Description
} }
if input.Status != "" { if input.Status != "" {
channel.Status = input.Status channel.Status = input.Status
} }
if input.RestrictModels != nil { if input.RestrictModels != nil {
channel.RestrictModels = *input.RestrictModels channel.RestrictModels = *input.RestrictModels
} }
// 检查分组冲突
if input.GroupIDs != nil { if input.GroupIDs != nil {
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, id, *input.GroupIDs) if err := s.checkGroupConflicts(ctx, channel.ID, *input.GroupIDs); err != nil {
if err != nil { return err
return nil, fmt.Errorf("check group conflicts: %w", err)
}
if len(conflicting) > 0 {
return nil, ErrGroupAlreadyInChannel
} }
channel.GroupIDs = *input.GroupIDs channel.GroupIDs = *input.GroupIDs
} }
if input.ModelPricing != nil { if input.ModelPricing != nil {
channel.ModelPricing = *input.ModelPricing channel.ModelPricing = *input.ModelPricing
} }
if input.ModelMapping != nil { if input.ModelMapping != nil {
channel.ModelMapping = input.ModelMapping channel.ModelMapping = input.ModelMapping
} }
if input.BillingModelSource != "" { if input.BillingModelSource != "" {
channel.BillingModelSource = input.BillingModelSource channel.BillingModelSource = input.BillingModelSource
} }
return nil
}
if err := validateNoConflictingModels(channel.ModelPricing); err != nil { // checkGroupConflicts 检查待关联的分组是否已属于其他渠道。
return nil, err // channelID 为当前渠道 ID(Create 时传 0)。
func (s *ChannelService) checkGroupConflicts(ctx context.Context, channelID int64, groupIDs []int64) error {
if len(groupIDs) == 0 {
return nil
} }
if err := validatePricingIntervals(channel.ModelPricing); err != nil { conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, channelID, groupIDs)
return nil, err if err != nil {
return fmt.Errorf("check group conflicts: %w", err)
} }
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil { if len(conflicting) > 0 {
return nil, err return ErrGroupAlreadyInChannel
} }
return nil
}
// 先获取旧分组,Update 后旧分组关联已删除,无法再查到 // getOldGroupIDs 获取渠道更新前的关联分组 ID(用于失效 auth 缓存)。
var oldGroupIDs []int64 func (s *ChannelService) getOldGroupIDs(ctx context.Context, channelID int64) []int64 {
if s.authCacheInvalidator != nil { if s.authCacheInvalidator == nil {
var err2 error return nil
oldGroupIDs, err2 = s.repo.GetGroupIDs(ctx, id)
if err2 != nil {
slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", id, "error", err2)
}
} }
oldGroupIDs, err := s.repo.GetGroupIDs(ctx, channelID)
if err := s.repo.Update(ctx, channel); err != nil { if err != nil {
return nil, fmt.Errorf("update channel: %w", err) slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", channelID, "error", err)
} }
return oldGroupIDs
}
s.invalidateCache() // invalidateAuthCacheForGroups 对新旧分组去重后逐个失效 auth 缓存。
func (s *ChannelService) invalidateAuthCacheForGroups(ctx context.Context, groupIDSets ...[]int64) {
// 失效新旧分组的 auth 缓存 if s.authCacheInvalidator == nil {
if s.authCacheInvalidator != nil { return
seen := make(map[int64]struct{}, len(oldGroupIDs)+len(channel.GroupIDs)) }
for _, gid := range oldGroupIDs { seen := make(map[int64]struct{})
if _, ok := seen[gid]; !ok { for _, ids := range groupIDSets {
seen[gid] = struct{}{} for _, gid := range ids {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid) if _, ok := seen[gid]; ok {
} continue
}
for _, gid := range channel.GroupIDs {
if _, ok := seen[gid]; !ok {
seen[gid] = struct{}{}
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
} }
seen[gid] = struct{}{}
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
} }
} }
return s.repo.GetByID(ctx, id)
} }
// Delete 删除渠道 // Delete 删除渠道
func (s *ChannelService) Delete(ctx context.Context, id int64) error { func (s *ChannelService) Delete(ctx context.Context, id int64) error {
// 先获取关联分组用于失效缓存
groupIDs, err := s.repo.GetGroupIDs(ctx, id) groupIDs, err := s.repo.GetGroupIDs(ctx, id)
if err != nil { if err != nil {
slog.Warn("failed to get group IDs before delete", "channel_id", id, "error", err) slog.Warn("failed to get group IDs before delete", "channel_id", id, "error", err)
...@@ -710,12 +808,7 @@ func (s *ChannelService) Delete(ctx context.Context, id int64) error { ...@@ -710,12 +808,7 @@ func (s *ChannelService) Delete(ctx context.Context, id int64) error {
} }
s.invalidateCache() s.invalidateCache()
s.invalidateAuthCacheForGroups(ctx, groupIDs)
if s.authCacheInvalidator != nil {
for _, gid := range groupIDs {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
}
}
return nil return nil
} }
......
...@@ -2199,3 +2199,207 @@ func TestGetChannelModelPricing_NonAntigravityUnaffected(t *testing.T) { ...@@ -2199,3 +2199,207 @@ func TestGetChannelModelPricing_NonAntigravityUnaffected(t *testing.T) {
require.Equal(t, int64(601), result.ID) require.Equal(t, int64(601), result.ID)
require.InDelta(t, 5e-6, *result.InputPrice, 1e-12) require.InDelta(t, 5e-6, *result.InputPrice, 1e-12)
} }
// ---------------------------------------------------------------------------
// 10. ToUsageFields
// ---------------------------------------------------------------------------
func TestToUsageFields_NoMapping(t *testing.T) {
r := ChannelMappingResult{
MappedModel: "claude-opus-4",
ChannelID: 1,
Mapped: false,
BillingModelSource: BillingModelSourceRequested,
}
fields := r.ToUsageFields("claude-opus-4", "claude-opus-4")
require.Equal(t, int64(1), fields.ChannelID)
require.Equal(t, "claude-opus-4", fields.OriginalModel)
require.Equal(t, "claude-opus-4", fields.ChannelMappedModel)
require.Equal(t, BillingModelSourceRequested, fields.BillingModelSource)
require.Empty(t, fields.ModelMappingChain)
}
func TestToUsageFields_WithChannelMapping(t *testing.T) {
r := ChannelMappingResult{
MappedModel: "claude-sonnet-4-20250514",
ChannelID: 2,
Mapped: true,
BillingModelSource: BillingModelSourceChannelMapped,
}
fields := r.ToUsageFields("claude-sonnet-4", "claude-sonnet-4-20250514")
require.Equal(t, int64(2), fields.ChannelID)
require.Equal(t, "claude-sonnet-4", fields.OriginalModel)
require.Equal(t, "claude-sonnet-4-20250514", fields.ChannelMappedModel)
require.Equal(t, "claude-sonnet-4→claude-sonnet-4-20250514", fields.ModelMappingChain)
}
func TestToUsageFields_WithUpstreamDifference(t *testing.T) {
r := ChannelMappingResult{
MappedModel: "claude-sonnet-4",
ChannelID: 3,
Mapped: true,
BillingModelSource: BillingModelSourceUpstream,
}
fields := r.ToUsageFields("my-alias", "claude-sonnet-4-20250514")
require.Equal(t, "my-alias", fields.OriginalModel)
require.Equal(t, "claude-sonnet-4", fields.ChannelMappedModel)
require.Equal(t, "my-alias→claude-sonnet-4→claude-sonnet-4-20250514", fields.ModelMappingChain)
}
// ---------------------------------------------------------------------------
// 11. validatePricingBillingMode (moved from handler tests)
// ---------------------------------------------------------------------------
func TestValidatePricingBillingMode(t *testing.T) {
tests := []struct {
name string
pricing []ChannelModelPricing
wantErr bool
errMsg string
}{
{
name: "token mode - valid",
pricing: []ChannelModelPricing{{BillingMode: BillingModeToken}},
},
{
name: "per_request with price - valid",
pricing: []ChannelModelPricing{{
BillingMode: BillingModePerRequest,
PerRequestPrice: testPtrFloat64(0.5),
}},
},
{
name: "per_request with intervals - valid",
pricing: []ChannelModelPricing{{
BillingMode: BillingModePerRequest,
Intervals: []PricingInterval{{MinTokens: 0, MaxTokens: testPtrInt(1000), PerRequestPrice: testPtrFloat64(0.1)}},
}},
},
{
name: "per_request no price no intervals - invalid",
pricing: []ChannelModelPricing{{BillingMode: BillingModePerRequest}},
wantErr: true,
errMsg: "per-request price or intervals required",
},
{
name: "image no price no intervals - invalid",
pricing: []ChannelModelPricing{{BillingMode: BillingModeImage}},
wantErr: true,
errMsg: "per-request price or intervals required",
},
{
name: "empty list - valid",
pricing: []ChannelModelPricing{},
},
{
name: "negative input_price - invalid",
pricing: []ChannelModelPricing{{
BillingMode: BillingModeToken,
InputPrice: testPtrFloat64(-0.01),
}},
wantErr: true,
errMsg: "input_price must be >= 0",
},
{
name: "interval with no price fields - invalid",
pricing: []ChannelModelPricing{{
BillingMode: BillingModePerRequest,
PerRequestPrice: testPtrFloat64(0.5),
Intervals: []PricingInterval{{MinTokens: 0, MaxTokens: testPtrInt(1000)}},
}},
wantErr: true,
errMsg: "has no price fields set",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validatePricingBillingMode(tt.pricing)
if tt.wantErr {
require.Error(t, err)
require.Contains(t, err.Error(), tt.errMsg)
} else {
require.NoError(t, err)
}
})
}
}
// ---------------------------------------------------------------------------
// 12. Antigravity wildcard mapping isolation
// ---------------------------------------------------------------------------
func TestResolveChannelMapping_AntigravityDoesNotSeeWildcardMappingFromOtherPlatforms(t *testing.T) {
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10, 20},
ModelMapping: map[string]map[string]string{
PlatformAnthropic: {"claude-*": "claude-override"},
PlatformGemini: {"gemini-*": "gemini-override"},
},
}
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity, 20: PlatformAnthropic})
svc := newTestChannelService(repo)
// antigravity 分组不应看到 anthropic/gemini 的通配符映射
result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4")
require.False(t, result.Mapped)
require.Equal(t, "claude-opus-4", result.MappedModel)
result = svc.ResolveChannelMapping(context.Background(), 10, "gemini-2.5-pro")
require.False(t, result.Mapped)
require.Equal(t, "gemini-2.5-pro", result.MappedModel)
// anthropic 分组应该能看到 anthropic 的通配符映射
result = svc.ResolveChannelMapping(context.Background(), 20, "claude-opus-4")
require.True(t, result.Mapped)
require.Equal(t, "claude-override", result.MappedModel)
}
// ---------------------------------------------------------------------------
// 13. Create/Update with mapping conflict validation
// ---------------------------------------------------------------------------
func TestCreate_MappingConflict(t *testing.T) {
repo := &mockChannelRepository{}
svc := newTestChannelService(repo)
_, err := svc.Create(context.Background(), &CreateChannelInput{
Name: "test",
ModelMapping: map[string]map[string]string{
PlatformAnthropic: {
"claude-*": "target-a",
"claude-opus-*": "target-b",
},
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "MAPPING_PATTERN_CONFLICT")
}
func TestUpdate_MappingConflict(t *testing.T) {
existingChannel := &Channel{
ID: 1,
Name: "existing",
Status: StatusActive,
}
repo := &mockChannelRepository{
getByIDFn: func(_ context.Context, _ int64) (*Channel, error) {
return existingChannel, nil
},
}
svc := newTestChannelService(repo)
conflictMapping := map[string]map[string]string{
PlatformAnthropic: {
"claude-*": "target-a",
"claude-opus-*": "target-b",
},
}
_, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
ModelMapping: conflictMapping,
})
require.Error(t, err)
require.Contains(t, err.Error(), "MAPPING_PATTERN_CONFLICT")
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment