Commit 6de1d0cb authored by erio's avatar erio
Browse files

refactor: split buildCache into sub-functions, reduce nesting 5→2

- Extract newEmptyChannelCache() factory to deduplicate map init
- Extract expandPricingToCache() for model pricing expansion
- Extract expandMappingToCache() for model mapping expansion
- buildCache reduced from 110 to 50 lines
parent 6c718578
...@@ -183,6 +183,67 @@ func (s *ChannelService) loadCache(ctx context.Context) (*channelCache, error) { ...@@ -183,6 +183,67 @@ func (s *ChannelService) loadCache(ctx context.Context) (*channelCache, error) {
return cache, nil return cache, nil
} }
// newEmptyChannelCache 创建空的渠道缓存(所有 map 已初始化)
func newEmptyChannelCache() *channelCache {
return &channelCache{
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
mappingByGroupModel: make(map[channelModelKey]string),
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
channelByGroupID: make(map[int64]*Channel),
groupPlatform: make(map[int64]string),
byID: make(map[int64]*Channel),
}
}
// expandPricingToCache 将渠道的模型定价展开到缓存(按分组+平台维度)。
// antigravity 平台同时服务 Claude 和 Gemini 模型,需匹配 anthropic/gemini 的定价条目。
func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
for j := range ch.ModelPricing {
pricing := &ch.ModelPricing[j]
if !isPlatformPricingMatch(platform, pricing.Platform) {
continue // 跳过非本平台的定价
}
gpKey := channelGroupPlatformKey{groupID: gid, platform: platform}
for _, model := range pricing.Models {
if strings.HasSuffix(model, "*") {
prefix := strings.ToLower(strings.TrimSuffix(model, "*"))
cache.wildcardByGroupPlatform[gpKey] = append(cache.wildcardByGroupPlatform[gpKey], &wildcardPricingEntry{
prefix: prefix,
pricing: pricing,
})
} else {
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(model)}
cache.pricingByGroupModel[key] = pricing
}
}
}
}
// expandMappingToCache 将渠道的模型映射展开到缓存(按分组+平台维度)。
// antigravity 平台同时服务 Claude 和 Gemini 模型。
func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
for _, mappingPlatform := range matchingPlatforms(platform) {
platformMapping, ok := ch.ModelMapping[mappingPlatform]
if !ok {
continue
}
gpKey := channelGroupPlatformKey{groupID: gid, platform: platform}
for src, dst := range platformMapping {
if strings.HasSuffix(src, "*") {
prefix := strings.ToLower(strings.TrimSuffix(src, "*"))
cache.wildcardMappingByGP[gpKey] = append(cache.wildcardMappingByGP[gpKey], &wildcardMappingEntry{
prefix: prefix,
target: dst,
})
} else {
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(src)}
cache.mappingByGroupModel[key] = dst
}
}
}
}
// buildCache 从数据库构建渠道缓存。 // buildCache 从数据库构建渠道缓存。
// 使用独立 context 避免请求取消导致空值被长期缓存。 // 使用独立 context 避免请求取消导致空值被长期缓存。
func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) { func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) {
...@@ -194,16 +255,8 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) ...@@ -194,16 +255,8 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
if err != nil { if err != nil {
// error-TTL:失败时存入短 TTL 空缓存,防止紧密重试 // error-TTL:失败时存入短 TTL 空缓存,防止紧密重试
slog.Warn("failed to build channel cache", "error", err) slog.Warn("failed to build channel cache", "error", err)
errorCache := &channelCache{ errorCache := newEmptyChannelCache()
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing), errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) // 使剩余 TTL = errorTTL
wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
mappingByGroupModel: make(map[channelModelKey]string),
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
channelByGroupID: make(map[int64]*Channel),
groupPlatform: make(map[int64]string),
byID: make(map[int64]*Channel),
loadedAt: time.Now().Add(-(channelCacheTTL - channelErrorTTL)), // 使剩余 TTL = errorTTL
}
s.cache.Store(errorCache) s.cache.Store(errorCache)
return nil, fmt.Errorf("list all channels: %w", err) return nil, fmt.Errorf("list all channels: %w", err)
} }
...@@ -222,71 +275,20 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) ...@@ -222,71 +275,20 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
} }
} }
cache := &channelCache{ cache := newEmptyChannelCache()
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing), cache.groupPlatform = groupPlatforms
wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry), cache.byID = make(map[int64]*Channel, len(channels))
mappingByGroupModel: make(map[channelModelKey]string), cache.loadedAt = time.Now()
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
channelByGroupID: make(map[int64]*Channel),
groupPlatform: groupPlatforms,
byID: make(map[int64]*Channel, len(channels)),
loadedAt: time.Now(),
}
for i := range channels { for i := range channels {
ch := &channels[i] ch := &channels[i]
cache.byID[ch.ID] = ch cache.byID[ch.ID] = ch
// 展开到分组维度
for _, gid := range ch.GroupIDs { for _, gid := range ch.GroupIDs {
cache.channelByGroupID[gid] = ch cache.channelByGroupID[gid] = ch
platform := groupPlatforms[gid] // e.g. "anthropic" platform := groupPlatforms[gid]
expandPricingToCache(cache, ch, gid, platform)
// 只展开该平台的模型定价到 (groupID, platform, model) → *ChannelModelPricing expandMappingToCache(cache, ch, gid, platform)
// antigravity 平台同时服务 Claude 和 Gemini 模型,需匹配 anthropic/gemini 的定价条目
for j := range ch.ModelPricing {
pricing := &ch.ModelPricing[j]
if !isPlatformPricingMatch(platform, pricing.Platform) {
continue // 跳过非本平台的定价
}
for _, model := range pricing.Models {
if strings.HasSuffix(model, "*") {
// 通配符模型 → 存入 wildcardByGroupPlatform
prefix := strings.ToLower(strings.TrimSuffix(model, "*"))
gpKey := channelGroupPlatformKey{groupID: gid, platform: platform}
cache.wildcardByGroupPlatform[gpKey] = append(cache.wildcardByGroupPlatform[gpKey], &wildcardPricingEntry{
prefix: prefix,
pricing: pricing,
})
} else {
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(model)}
cache.pricingByGroupModel[key] = pricing
}
}
}
// 只展开该平台的模型映射到 (groupID, platform, model) → target
// antigravity 平台同时服务 Claude 和 Gemini 模型
for _, mappingPlatform := range matchingPlatforms(platform) {
platformMapping, ok := ch.ModelMapping[mappingPlatform]
if !ok {
continue
}
for src, dst := range platformMapping {
if strings.HasSuffix(src, "*") {
// 通配符映射 → 存入 wildcardMappingByGP
prefix := strings.ToLower(strings.TrimSuffix(src, "*"))
gpKey := channelGroupPlatformKey{groupID: gid, platform: platform}
cache.wildcardMappingByGP[gpKey] = append(cache.wildcardMappingByGP[gpKey], &wildcardMappingEntry{
prefix: prefix,
target: dst,
})
} else {
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(src)}
cache.mappingByGroupModel[key] = dst
}
}
}
} }
} }
...@@ -362,26 +364,48 @@ func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) ...@@ -362,26 +364,48 @@ func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64)
return ch.Clone(), nil return ch.Clone(), nil
} }
// channelLookup 热路径公共查找结果
type channelLookup struct {
cache *channelCache
channel *Channel
platform string
}
// lookupGroupChannel 加载缓存并查找分组对应的渠道信息(公共热路径前置逻辑)。
// 返回 nil 且 err==nil 表示分组无活跃渠道;err!=nil 表示缓存加载失败。
func (s *ChannelService) lookupGroupChannel(ctx context.Context, groupID int64) (*channelLookup, error) {
cache, err := s.loadCache(ctx)
if err != nil {
return nil, err
}
ch, ok := cache.channelByGroupID[groupID]
if !ok || !ch.IsActive() {
return nil, nil
}
return &channelLookup{
cache: cache,
channel: ch,
platform: cache.groupPlatform[groupID],
}, nil
}
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1)) // GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1))
func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing { func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing {
cache, err := s.loadCache(ctx) lk, err := s.lookupGroupChannel(ctx, groupID)
if err != nil { if err != nil {
slog.Warn("failed to load channel cache", "group_id", groupID, "error", err) slog.Warn("failed to load channel cache", "group_id", groupID, "error", err)
return nil return nil
} }
if lk == nil {
// 检查渠道是否启用
ch, ok := cache.channelByGroupID[groupID]
if !ok || !ch.IsActive() {
return nil return nil
} }
platform := cache.groupPlatform[groupID] modelLower := strings.ToLower(model)
key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)} key := channelModelKey{groupID: groupID, platform: lk.platform, model: modelLower}
pricing, ok := cache.pricingByGroupModel[key] pricing, ok := lk.cache.pricingByGroupModel[key]
if !ok { if !ok {
// 精确查找失败,尝试通配符匹配 // 精确查找失败,尝试通配符匹配
pricing = cache.matchWildcard(groupID, platform, strings.ToLower(model)) pricing = lk.cache.matchWildcard(groupID, lk.platform, modelLower)
if pricing == nil { if pricing == nil {
return nil return nil
} }
...@@ -394,31 +418,57 @@ func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int ...@@ -394,31 +418,57 @@ func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int
// ResolveChannelMapping 解析渠道级模型映射(热路径 O(1)) // ResolveChannelMapping 解析渠道级模型映射(热路径 O(1))
// 返回映射结果,包含映射后的模型名、渠道 ID、计费模型来源。 // 返回映射结果,包含映射后的模型名、渠道 ID、计费模型来源。
func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult { func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult {
cache, err := s.loadCache(ctx) lk, _ := s.lookupGroupChannel(ctx, groupID)
if err != nil { if lk == nil {
return ChannelMappingResult{MappedModel: model} return ChannelMappingResult{MappedModel: model}
} }
return resolveMapping(lk, groupID, model)
}
ch, ok := cache.channelByGroupID[groupID] // IsModelRestricted 检查模型是否被渠道限制。
if !ok || !ch.IsActive() { // 返回 true 表示模型被限制(不在允许列表中)。
return ChannelMappingResult{MappedModel: model} // 如果渠道未启用模型限制或分组无渠道关联,返回 false。
func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
lk, _ := s.lookupGroupChannel(ctx, groupID)
if lk == nil {
return false
} }
return checkRestricted(lk, groupID, model)
}
// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制(组合方法)。
// 返回映射结果和是否被限制。groupID 为 nil 时跳过。
func (s *ChannelService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
if groupID == nil {
return ChannelMappingResult{MappedModel: model}, false
}
lk, _ := s.lookupGroupChannel(ctx, *groupID)
if lk == nil {
return ChannelMappingResult{MappedModel: model}, false
}
// 先用原始模型检查定价列表限制,再做映射
restricted := checkRestricted(lk, *groupID, model)
mapping := resolveMapping(lk, *groupID, model)
return mapping, restricted
}
// resolveMapping 基于已查找的渠道信息解析模型映射
func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappingResult {
result := ChannelMappingResult{ result := ChannelMappingResult{
MappedModel: model, MappedModel: model,
ChannelID: ch.ID, ChannelID: lk.channel.ID,
BillingModelSource: ch.BillingModelSource, BillingModelSource: lk.channel.BillingModelSource,
} }
if result.BillingModelSource == "" { if result.BillingModelSource == "" {
result.BillingModelSource = BillingModelSourceChannelMapped result.BillingModelSource = BillingModelSourceChannelMapped
} }
platform := cache.groupPlatform[groupID] modelLower := strings.ToLower(model)
key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)} key := channelModelKey{groupID: groupID, platform: lk.platform, model: modelLower}
if mapped, ok := cache.mappingByGroupModel[key]; ok { if mapped, ok := lk.cache.mappingByGroupModel[key]; ok {
result.MappedModel = mapped result.MappedModel = mapped
result.Mapped = true result.Mapped = true
} else if mapped := cache.matchWildcardMapping(groupID, platform, strings.ToLower(model)); mapped != "" { } else if mapped := lk.cache.matchWildcardMapping(groupID, lk.platform, modelLower); mapped != "" {
result.MappedModel = mapped result.MappedModel = mapped
result.Mapped = true result.Mapped = true
} }
...@@ -426,48 +476,24 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6 ...@@ -426,48 +476,24 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6
return result return result
} }
// IsModelRestricted 检查模型是否被渠道限制。 // checkRestricted 基于已查找的渠道信息检查模型是否被限制
// 返回 true 表示模型被限制(不在允许列表中)。 func checkRestricted(lk *channelLookup, groupID int64, model string) bool {
// 如果渠道未启用模型限制或分组无渠道关联,返回 false。 if !lk.channel.RestrictModels {
func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
cache, err := s.loadCache(ctx)
if err != nil {
return false // 缓存加载失败时不限制
}
ch, ok := cache.channelByGroupID[groupID]
if !ok || !ch.IsActive() || !ch.RestrictModels {
return false return false
} }
// 检查模型是否在定价列表中 // 检查模型是否在定价列表中
platform := cache.groupPlatform[groupID] modelLower := strings.ToLower(model)
key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)} key := channelModelKey{groupID: groupID, platform: lk.platform, model: modelLower}
_, exists := cache.pricingByGroupModel[key] if _, exists := lk.cache.pricingByGroupModel[key]; exists {
if exists {
return false return false
} }
// 精确查找失败,尝试通配符匹配 // 精确查找失败,尝试通配符匹配
if cache.matchWildcard(groupID, platform, strings.ToLower(model)) != nil { if lk.cache.matchWildcard(groupID, lk.platform, modelLower) != nil {
return false return false
} }
return true return true
} }
// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制(组合方法)。
// 返回映射结果和是否被限制。groupID 为 nil 时跳过。
func (s *ChannelService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
var mapping ChannelMappingResult
mapping.MappedModel = model
if groupID == nil {
return mapping, false
}
// 先用原始模型检查定价列表限制,再做映射
restricted := s.IsModelRestricted(ctx, *groupID, model)
mapping = s.ResolveChannelMapping(ctx, *groupID, model)
return mapping, restricted
}
// ReplaceModelInBody 替换请求体 JSON 中的 model 字段。 // ReplaceModelInBody 替换请求体 JSON 中的 model 字段。
func ReplaceModelInBody(body []byte, newModel string) []byte { func ReplaceModelInBody(body []byte, newModel string) []byte {
if len(body) == 0 { if len(body) == 0 {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment