Commit 88759407 authored by erio's avatar erio
Browse files

feat(channel): 模型映射源支持通配符匹配

与定价通配符一致,映射源支持 * 后缀通配符(最长前缀优先):
- `*` 匹配所有模型
- `claude-*` 匹配 claude- 开头的模型
- 精确匹配优先于通配符
parent 6c99cc61
...@@ -72,12 +72,19 @@ type wildcardPricingEntry struct { ...@@ -72,12 +72,19 @@ type wildcardPricingEntry struct {
pricing *ChannelModelPricing pricing *ChannelModelPricing
} }
// wildcardMappingEntry 通配符映射条目
type wildcardMappingEntry struct {
prefix string
target string
}
// channelCache 渠道缓存快照(扁平化哈希结构,热路径 O(1) 查找) // channelCache 渠道缓存快照(扁平化哈希结构,热路径 O(1) 查找)
type channelCache struct { type channelCache struct {
// 热路径查找 // 热路径查找
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价 pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价
wildcardByGroupPlatform map[channelGroupPlatformKey][]*wildcardPricingEntry // (groupID, platform) → 通配符定价(前缀长度降序) wildcardByGroupPlatform map[channelGroupPlatformKey][]*wildcardPricingEntry // (groupID, platform) → 通配符定价(前缀长度降序)
mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标 mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标
wildcardMappingByGP map[channelGroupPlatformKey][]*wildcardMappingEntry // (groupID, platform) → 通配符映射(前缀长度降序)
channelByGroupID map[int64]*Channel // groupID → 渠道 channelByGroupID map[int64]*Channel // groupID → 渠道
groupPlatform map[int64]string // groupID → platform groupPlatform map[int64]string // groupID → platform
...@@ -173,6 +180,7 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) ...@@ -173,6 +180,7 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing), pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry), wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
mappingByGroupModel: make(map[channelModelKey]string), mappingByGroupModel: make(map[channelModelKey]string),
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
channelByGroupID: make(map[int64]*Channel), channelByGroupID: make(map[int64]*Channel),
groupPlatform: make(map[int64]string), groupPlatform: make(map[int64]string),
byID: make(map[int64]*Channel), byID: make(map[int64]*Channel),
...@@ -200,6 +208,7 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) ...@@ -200,6 +208,7 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing), pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing),
wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry), wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry),
mappingByGroupModel: make(map[channelModelKey]string), mappingByGroupModel: make(map[channelModelKey]string),
wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry),
channelByGroupID: make(map[int64]*Channel), channelByGroupID: make(map[int64]*Channel),
groupPlatform: groupPlatforms, groupPlatform: groupPlatforms,
byID: make(map[int64]*Channel, len(channels)), byID: make(map[int64]*Channel, len(channels)),
...@@ -240,12 +249,22 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) ...@@ -240,12 +249,22 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
// 只展开该平台的模型映射到 (groupID, platform, model) → target // 只展开该平台的模型映射到 (groupID, platform, model) → target
if platformMapping, ok := ch.ModelMapping[platform]; ok { if platformMapping, ok := ch.ModelMapping[platform]; ok {
for src, dst := range platformMapping { for src, dst := range platformMapping {
if strings.HasSuffix(src, "*") {
// 通配符映射 → 存入 wildcardMappingByGP
prefix := strings.ToLower(strings.TrimSuffix(src, "*"))
gpKey := channelGroupPlatformKey{groupID: gid, platform: platform}
cache.wildcardMappingByGP[gpKey] = append(cache.wildcardMappingByGP[gpKey], &wildcardMappingEntry{
prefix: prefix,
target: dst,
})
} else {
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(src)} key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(src)}
cache.mappingByGroupModel[key] = dst cache.mappingByGroupModel[key] = dst
} }
} }
} }
} }
}
// 通配符条目按前缀长度降序排列(最长前缀优先匹配) // 通配符条目按前缀长度降序排列(最长前缀优先匹配)
for gpKey, entries := range cache.wildcardByGroupPlatform { for gpKey, entries := range cache.wildcardByGroupPlatform {
...@@ -254,6 +273,12 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) ...@@ -254,6 +273,12 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
}) })
cache.wildcardByGroupPlatform[gpKey] = entries cache.wildcardByGroupPlatform[gpKey] = entries
} }
for gpKey, entries := range cache.wildcardMappingByGP {
sort.Slice(entries, func(i, j int) bool {
return len(entries[i].prefix) > len(entries[j].prefix)
})
cache.wildcardMappingByGP[gpKey] = entries
}
s.cache.Store(cache) s.cache.Store(cache)
return cache, nil return cache, nil
...@@ -277,6 +302,18 @@ func (c *channelCache) matchWildcard(groupID int64, platform, modelLower string) ...@@ -277,6 +302,18 @@ func (c *channelCache) matchWildcard(groupID int64, platform, modelLower string)
return nil return nil
} }
// matchWildcardMapping 在通配符映射中查找匹配项(最长前缀优先)
func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower string) string {
gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform}
wildcards := c.wildcardMappingByGP[gpKey]
for _, wc := range wildcards {
if strings.HasPrefix(modelLower, wc.prefix) {
return wc.target
}
}
return ""
}
// GetChannelForGroup 获取分组关联的渠道(热路径 O(1)) // GetChannelForGroup 获取分组关联的渠道(热路径 O(1))
func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) (*Channel, error) { func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) (*Channel, error) {
cache, err := s.loadCache(ctx) cache, err := s.loadCache(ctx)
...@@ -348,6 +385,9 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6 ...@@ -348,6 +385,9 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6
if mapped, ok := cache.mappingByGroupModel[key]; ok { if mapped, ok := cache.mappingByGroupModel[key]; ok {
result.MappedModel = mapped result.MappedModel = mapped
result.Mapped = true result.Mapped = true
} else if mapped := cache.matchWildcardMapping(groupID, platform, strings.ToLower(model)); mapped != "" {
result.MappedModel = mapped
result.Mapped = true
} }
return result return result
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment