Commit 71f61bbc authored by erio's avatar erio
Browse files

fix: resolve 5 audit findings in channel/credits/scheduling

P0-1: Credits degraded response retry + fail-open
- Add isAntigravityDegradedResponse() to detect transient API failures
- Retry up to 3 times with exponential backoff (500ms/1s/2s)
- Invalidate singleflight cache between retries
- Fail-open after exhausting retries instead of 5h circuit break

P1-1: Fix channel restriction pre-check timing conflict
- Swap checkClaudeCodeRestriction before checkChannelPricingRestriction
- Ensures channel restriction is checked against final fallback groupID

P1-2: Add interval pricing validation (frontend + backend)
- Backend: ValidateIntervals() with boundary, price, overlap checks
- Frontend: validateIntervals() with Chinese error messages
- Rules: MinTokens>=0, MaxTokens>MinTokens, prices>=0, no overlap

P2: Fix cross-platform same-model pricing/mapping override
- Store cache keys using original platform instead of group platform
- Lookup across matching platforms (antigravity→anthropic→gemini)
- Prevents anthropic/gemini same-name models from overwriting each other
parent 6d3ea64a
...@@ -855,6 +855,13 @@ func (s *AccountUsageService) GetAntigravityCredits(ctx context.Context, account ...@@ -855,6 +855,13 @@ func (s *AccountUsageService) GetAntigravityCredits(ctx context.Context, account
return s.getAntigravityUsage(ctx, account) return s.getAntigravityUsage(ctx, account)
} }
// InvalidateAntigravityCreditsCache 清除指定账号的 Antigravity 用量缓存,
// 使下次调用 GetAntigravityCredits 时强制重新拉取。
// 用于 credits 降级响应重试场景:避免重试命中同一个降级缓存。
func (s *AccountUsageService) InvalidateAntigravityCreditsCache(accountID int64) {
s.cache.antigravityCache.Delete(accountID)
}
// recalcAntigravityRemainingSeconds 重新计算 Antigravity UsageInfo 中各窗口的 RemainingSeconds // recalcAntigravityRemainingSeconds 重新计算 Antigravity UsageInfo 中各窗口的 RemainingSeconds
// 用于从缓存取出时更新倒计时,避免返回过时的剩余秒数 // 用于从缓存取出时更新倒计时,避免返回过时的剩余秒数
func recalcAntigravityRemainingSeconds(info *UsageInfo) { func recalcAntigravityRemainingSeconds(info *UsageInfo) {
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"io" "io"
"log/slog"
"net/http" "net/http"
"strings" "strings"
"time" "time"
...@@ -17,33 +18,116 @@ const ( ...@@ -17,33 +18,116 @@ const (
// 与普通模型限流完全同构:通过 SetModelRateLimit / isRateLimitActiveForKey 读写。 // 与普通模型限流完全同构:通过 SetModelRateLimit / isRateLimitActiveForKey 读写。
creditsExhaustedKey = "AICredits" creditsExhaustedKey = "AICredits"
creditsExhaustedDuration = 5 * time.Hour creditsExhaustedDuration = 5 * time.Hour
// credits 降级响应重试参数
creditsRetryMaxAttempts = 3
creditsRetryBaseInterval = 500 * time.Millisecond
) )
// creditsRetryableErrorCodes 是降级响应中可重试的错误码集合。
// forbidden 是稳定的封号状态,不属于可恢复的瞬态错误,不重试。
var creditsRetryableErrorCodes = map[string]bool{
errorCodeUnauthenticated: true,
errorCodeRateLimited: true,
errorCodeNetworkError: true,
}
// isAntigravityDegradedResponse 检查 UsageInfo 是否为可重试的降级响应。
// 仅检测 3 个瞬态错误码(unauthenticated/rate_limited/network_error),
// forbidden 是稳定的封号状态,不属于降级。
func isAntigravityDegradedResponse(info *UsageInfo) bool {
if info == nil || info.ErrorCode == "" {
return false
}
return creditsRetryableErrorCodes[info.ErrorCode]
}
// checkAccountCredits 通过共享的 AccountUsageService 缓存检查账号是否有足够的 AI Credits。 // checkAccountCredits 通过共享的 AccountUsageService 缓存检查账号是否有足够的 AI Credits。
// 缓存 TTL 不足时会自动从 Google loadCodeAssist API 刷新。 // 缓存 TTL 不足时会自动从 Google loadCodeAssist API 刷新。
// 返回 true 表示积分可用 // 检测到降级响应时会清除缓存并重试,最终 fail-open(返回 true)
func (s *AntigravityGatewayService) checkAccountCredits( func (s *AntigravityGatewayService) checkAccountCredits(
ctx context.Context, account *Account, ctx context.Context, account *Account,
) bool { ) bool {
if account == nil || account.ID == 0 { if account == nil || account.ID == 0 {
return false return false
} }
if s.accountUsageService == nil { if s.accountUsageService == nil {
return true // 无 usage service 时不阻断 return true // 无 usage service 时不阻断
} }
usageInfo, err := s.accountUsageService.GetAntigravityCredits(ctx, account) usageInfo, err := s.accountUsageService.GetAntigravityCredits(ctx, account)
if err != nil { if err != nil {
logger.LegacyPrintf("service.antigravity_gateway", slog.Error("check_credits: get_credits_failed",
"check_credits: get_credits_failed account=%d err=%v", account.ID, err) "account_id", account.ID, "error", err)
return true // 出错时假设有积分,不阻断 return true // 出错时 fail-open
}
// 非降级响应:直接检查积分余额
if !isAntigravityDegradedResponse(usageInfo) {
return s.logCreditsResult(account, usageInfo)
} }
hasCredits := hasEnoughCredits(usageInfo) // 降级响应:清除缓存后重试
return s.retryCreditsOnDegraded(ctx, account, usageInfo)
}
// retryCreditsOnDegraded 在检测到降级响应后,清除缓存并重试获取 credits。
// 使用指数退避(500ms → 1s → 2s),最多重试 creditsRetryMaxAttempts 次。
// 所有重试失败后 fail-open(返回 true),不做熔断。
func (s *AntigravityGatewayService) retryCreditsOnDegraded(
ctx context.Context, account *Account, lastInfo *UsageInfo,
) bool {
for attempt := 1; attempt <= creditsRetryMaxAttempts; attempt++ {
delay := creditsRetryBaseInterval << (attempt - 1) // 指数退避:500ms, 1s, 2s
slog.Warn("check_credits: degraded response, retrying",
"account_id", account.ID,
"attempt", attempt,
"max_attempts", creditsRetryMaxAttempts,
"error_code", lastInfo.ErrorCode,
"delay", delay,
)
select {
case <-ctx.Done():
slog.Warn("check_credits: context cancelled during retry, fail-open",
"account_id", account.ID, "attempt", attempt)
return true
case <-time.After(delay):
}
// 清除缓存,强制下次 GetAntigravityCredits 重新拉取
s.accountUsageService.InvalidateAntigravityCreditsCache(account.ID)
info, err := s.accountUsageService.GetAntigravityCredits(ctx, account)
if err != nil {
slog.Error("check_credits: retry get_credits_failed",
"account_id", account.ID, "attempt", attempt, "error", err)
continue
}
// 重试成功(不再是降级响应):检查积分余额
if !isAntigravityDegradedResponse(info) {
slog.Info("check_credits: retry succeeded",
"account_id", account.ID, "attempt", attempt)
return s.logCreditsResult(account, info)
}
lastInfo = info
}
// 所有重试失败:fail-open,不做熔断
slog.Warn("check_credits: all retries exhausted, fail-open",
"account_id", account.ID,
"last_error_code", lastInfo.ErrorCode,
)
return true
}
// logCreditsResult 检查积分并记录不足日志,返回是否有积分。
func (s *AntigravityGatewayService) logCreditsResult(account *Account, info *UsageInfo) bool {
hasCredits := hasEnoughCredits(info)
if !hasCredits { if !hasCredits {
logger.LegacyPrintf("service.antigravity_gateway", slog.Warn("check_credits: insufficient credits",
"check_credits: account=%d has_credits=false", account.ID) "account_id", account.ID)
} }
return hasCredits return hasCredits
} }
......
package service package service
import ( import (
"fmt"
"sort"
"strings" "strings"
"time" "time"
) )
...@@ -177,6 +179,94 @@ func (c *Channel) Clone() *Channel { ...@@ -177,6 +179,94 @@ func (c *Channel) Clone() *Channel {
return &cp return &cp
} }
// ValidateIntervals 校验区间列表的合法性。
// 规则:MinTokens >= 0;MaxTokens 若非 nil 则 > 0 且 > MinTokens;
// 所有价格字段 >= 0;区间按 MinTokens 排序后无重叠((min, max] 语义);
// 无界区间(MaxTokens=nil)必须是最后一个。间隙允许(回退默认价格)。
func ValidateIntervals(intervals []PricingInterval) error {
if len(intervals) == 0 {
return nil
}
sorted := make([]PricingInterval, len(intervals))
copy(sorted, intervals)
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].MinTokens < sorted[j].MinTokens
})
for i := range sorted {
if err := validateSingleInterval(&sorted[i], i); err != nil {
return err
}
}
return validateIntervalOverlap(sorted)
}
// validateSingleInterval 校验单个区间的字段合法性
func validateSingleInterval(iv *PricingInterval, idx int) error {
if iv.MinTokens < 0 {
return fmt.Errorf("interval #%d: min_tokens (%d) must be >= 0", idx+1, iv.MinTokens)
}
if iv.MaxTokens != nil {
if *iv.MaxTokens <= 0 {
return fmt.Errorf("interval #%d: max_tokens (%d) must be > 0", idx+1, *iv.MaxTokens)
}
if *iv.MaxTokens <= iv.MinTokens {
return fmt.Errorf("interval #%d: max_tokens (%d) must be > min_tokens (%d)",
idx+1, *iv.MaxTokens, iv.MinTokens)
}
}
return validateIntervalPrices(iv, idx)
}
// validateIntervalPrices 校验区间内所有价格字段 >= 0
func validateIntervalPrices(iv *PricingInterval, idx int) error {
prices := []struct {
name string
val *float64
}{
{"input_price", iv.InputPrice},
{"output_price", iv.OutputPrice},
{"cache_write_price", iv.CacheWritePrice},
{"cache_read_price", iv.CacheReadPrice},
{"per_request_price", iv.PerRequestPrice},
}
for _, p := range prices {
if p.val != nil && *p.val < 0 {
return fmt.Errorf("interval #%d: %s must be >= 0", idx+1, p.name)
}
}
return nil
}
// validateIntervalOverlap 校验排序后的区间列表无重叠,且无界区间在最后
func validateIntervalOverlap(sorted []PricingInterval) error {
for i, iv := range sorted {
// 无界区间必须是最后一个
if iv.MaxTokens == nil && i < len(sorted)-1 {
return fmt.Errorf("interval #%d: unbounded interval (max_tokens=null) must be the last one",
i+1)
}
if i == 0 {
continue
}
prev := sorted[i-1]
// 检查重叠:前一个区间的上界 > 当前区间的下界则重叠
// (min, max] 语义:prev 覆盖 (prev.Min, prev.Max],cur 覆盖 (cur.Min, cur.Max]
if prev.MaxTokens == nil || *prev.MaxTokens > iv.MinTokens {
return fmt.Errorf("interval #%d and #%d overlap: prev max=%s > cur min=%d",
i, i+1, formatMaxTokensLabel(prev.MaxTokens), iv.MinTokens)
}
}
return nil
}
func formatMaxTokensLabel(max *int) string {
if max == nil {
return "∞"
}
return fmt.Sprintf("%d", *max)
}
// ChannelUsageFields 渠道相关的使用记录字段(嵌入到各平台的 RecordUsageInput 中) // ChannelUsageFields 渠道相关的使用记录字段(嵌入到各平台的 RecordUsageInput 中)
type ChannelUsageFields struct { type ChannelUsageFields struct {
ChannelID int64 // 渠道 ID(0 = 无渠道) ChannelID int64 // 渠道 ID(0 = 无渠道)
......
...@@ -198,13 +198,18 @@ func newEmptyChannelCache() *channelCache { ...@@ -198,13 +198,18 @@ func newEmptyChannelCache() *channelCache {
// expandPricingToCache 将渠道的模型定价展开到缓存(按分组+平台维度)。 // expandPricingToCache 将渠道的模型定价展开到缓存(按分组+平台维度)。
// antigravity 平台同时服务 Claude 和 Gemini 模型,需匹配 anthropic/gemini 的定价条目。 // antigravity 平台同时服务 Claude 和 Gemini 模型,需匹配 anthropic/gemini 的定价条目。
// 缓存 key 使用定价条目的原始平台(pricing.Platform),而非分组平台,
// 避免跨平台同名模型(如 anthropic 和 gemini 都有 "model-x")互相覆盖。
// 查找时通过 lookupPricingAcrossPlatforms() 依次尝试所有匹配平台。
func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform string) { func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
for j := range ch.ModelPricing { for j := range ch.ModelPricing {
pricing := &ch.ModelPricing[j] pricing := &ch.ModelPricing[j]
if !isPlatformPricingMatch(platform, pricing.Platform) { if !isPlatformPricingMatch(platform, pricing.Platform) {
continue // 跳过非本平台的定价 continue // 跳过非本平台的定价
} }
gpKey := channelGroupPlatformKey{groupID: gid, platform: platform} // 使用定价条目的原始平台作为缓存 key,防止跨平台同名模型冲突
pricingPlatform := pricing.Platform
gpKey := channelGroupPlatformKey{groupID: gid, platform: pricingPlatform}
for _, model := range pricing.Models { for _, model := range pricing.Models {
if strings.HasSuffix(model, "*") { if strings.HasSuffix(model, "*") {
prefix := strings.ToLower(strings.TrimSuffix(model, "*")) prefix := strings.ToLower(strings.TrimSuffix(model, "*"))
...@@ -213,7 +218,7 @@ func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform ...@@ -213,7 +218,7 @@ func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform
pricing: pricing, pricing: pricing,
}) })
} else { } else {
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(model)} key := channelModelKey{groupID: gid, platform: pricingPlatform, model: strings.ToLower(model)}
cache.pricingByGroupModel[key] = pricing cache.pricingByGroupModel[key] = pricing
} }
} }
...@@ -222,13 +227,15 @@ func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform ...@@ -222,13 +227,15 @@ func expandPricingToCache(cache *channelCache, ch *Channel, gid int64, platform
// expandMappingToCache 将渠道的模型映射展开到缓存(按分组+平台维度)。 // expandMappingToCache 将渠道的模型映射展开到缓存(按分组+平台维度)。
// antigravity 平台同时服务 Claude 和 Gemini 模型。 // antigravity 平台同时服务 Claude 和 Gemini 模型。
// 缓存 key 使用映射条目的原始平台(mappingPlatform),避免跨平台同名映射覆盖。
func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform string) { func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform string) {
for _, mappingPlatform := range matchingPlatforms(platform) { for _, mappingPlatform := range matchingPlatforms(platform) {
platformMapping, ok := ch.ModelMapping[mappingPlatform] platformMapping, ok := ch.ModelMapping[mappingPlatform]
if !ok { if !ok {
continue continue
} }
gpKey := channelGroupPlatformKey{groupID: gid, platform: platform} // 使用映射条目的原始平台作为缓存 key,防止跨平台同名映射冲突
gpKey := channelGroupPlatformKey{groupID: gid, platform: mappingPlatform}
for src, dst := range platformMapping { for src, dst := range platformMapping {
if strings.HasSuffix(src, "*") { if strings.HasSuffix(src, "*") {
prefix := strings.ToLower(strings.TrimSuffix(src, "*")) prefix := strings.ToLower(strings.TrimSuffix(src, "*"))
...@@ -237,7 +244,7 @@ func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform ...@@ -237,7 +244,7 @@ func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform
target: dst, target: dst,
}) })
} else { } else {
key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(src)} key := channelModelKey{groupID: gid, platform: mappingPlatform, model: strings.ToLower(src)}
cache.mappingByGroupModel[key] = dst cache.mappingByGroupModel[key] = dst
} }
} }
...@@ -349,6 +356,43 @@ func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower ...@@ -349,6 +356,43 @@ func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower
return "" return ""
} }
// lookupPricingAcrossPlatforms 在所有匹配平台中查找模型定价。
// antigravity 分组的缓存 key 使用定价条目的原始平台,因此查找时需依次尝试
// matchingPlatforms() 返回的所有平台(antigravity → anthropic → gemini),
// 返回第一个命中的结果。非 antigravity 平台只尝试自身。
func lookupPricingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) *ChannelModelPricing {
for _, p := range matchingPlatforms(groupPlatform) {
key := channelModelKey{groupID: groupID, platform: p, model: modelLower}
if pricing, ok := cache.pricingByGroupModel[key]; ok {
return pricing
}
}
// 精确查找全部失败,依次尝试通配符匹配
for _, p := range matchingPlatforms(groupPlatform) {
if pricing := cache.matchWildcard(groupID, p, modelLower); pricing != nil {
return pricing
}
}
return nil
}
// lookupMappingAcrossPlatforms 在所有匹配平台中查找模型映射。
// 逻辑与 lookupPricingAcrossPlatforms 相同:先精确查找,再通配符。
func lookupMappingAcrossPlatforms(cache *channelCache, groupID int64, groupPlatform, modelLower string) string {
for _, p := range matchingPlatforms(groupPlatform) {
key := channelModelKey{groupID: groupID, platform: p, model: modelLower}
if mapped, ok := cache.mappingByGroupModel[key]; ok {
return mapped
}
}
for _, p := range matchingPlatforms(groupPlatform) {
if mapped := cache.matchWildcardMapping(groupID, p, modelLower); mapped != "" {
return mapped
}
}
return ""
}
// GetChannelForGroup 获取分组关联的渠道(热路径 O(1)) // GetChannelForGroup 获取分组关联的渠道(热路径 O(1))
func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) (*Channel, error) { func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) (*Channel, error) {
cache, err := s.loadCache(ctx) cache, err := s.loadCache(ctx)
...@@ -389,7 +433,9 @@ func (s *ChannelService) lookupGroupChannel(ctx context.Context, groupID int64) ...@@ -389,7 +433,9 @@ func (s *ChannelService) lookupGroupChannel(ctx context.Context, groupID int64)
}, nil }, nil
} }
// GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1)) // GetChannelModelPricing 获取指定分组+模型的渠道定价(热路径 O(1))。
// antigravity 分组依次尝试所有匹配平台(antigravity → anthropic → gemini),
// 确保跨平台同名模型各自独立匹配。
func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing { func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int64, model string) *ChannelModelPricing {
lk, err := s.lookupGroupChannel(ctx, groupID) lk, err := s.lookupGroupChannel(ctx, groupID)
if err != nil { if err != nil {
...@@ -401,14 +447,9 @@ func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int ...@@ -401,14 +447,9 @@ func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int
} }
modelLower := strings.ToLower(model) modelLower := strings.ToLower(model)
key := channelModelKey{groupID: groupID, platform: lk.platform, model: modelLower} pricing := lookupPricingAcrossPlatforms(lk.cache, groupID, lk.platform, modelLower)
pricing, ok := lk.cache.pricingByGroupModel[key] if pricing == nil {
if !ok { return nil
// 精确查找失败,尝试通配符匹配
pricing = lk.cache.matchWildcard(groupID, lk.platform, modelLower)
if pricing == nil {
return nil
}
} }
cp := pricing.Clone() cp := pricing.Clone()
...@@ -453,7 +494,8 @@ func (s *ChannelService) ResolveChannelMappingAndRestrict(ctx context.Context, g ...@@ -453,7 +494,8 @@ func (s *ChannelService) ResolveChannelMappingAndRestrict(ctx context.Context, g
return resolveMapping(lk, *groupID, model), false return resolveMapping(lk, *groupID, model), false
} }
// resolveMapping 基于已查找的渠道信息解析模型映射 // resolveMapping 基于已查找的渠道信息解析模型映射。
// antigravity 分组依次尝试所有匹配平台,确保跨平台同名映射各自独立。
func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappingResult { func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappingResult {
result := ChannelMappingResult{ result := ChannelMappingResult{
MappedModel: model, MappedModel: model,
...@@ -465,11 +507,7 @@ func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappi ...@@ -465,11 +507,7 @@ func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappi
} }
modelLower := strings.ToLower(model) modelLower := strings.ToLower(model)
key := channelModelKey{groupID: groupID, platform: lk.platform, model: modelLower} if mapped := lookupMappingAcrossPlatforms(lk.cache, groupID, lk.platform, modelLower); mapped != "" {
if mapped, ok := lk.cache.mappingByGroupModel[key]; ok {
result.MappedModel = mapped
result.Mapped = true
} else if mapped := lk.cache.matchWildcardMapping(groupID, lk.platform, modelLower); mapped != "" {
result.MappedModel = mapped result.MappedModel = mapped
result.Mapped = true result.Mapped = true
} }
...@@ -477,19 +515,15 @@ func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappi ...@@ -477,19 +515,15 @@ func resolveMapping(lk *channelLookup, groupID int64, model string) ChannelMappi
return result return result
} }
// checkRestricted 基于已查找的渠道信息检查模型是否被限制 // checkRestricted 基于已查找的渠道信息检查模型是否被限制。
// antigravity 分组依次尝试所有匹配平台的定价列表。
func checkRestricted(lk *channelLookup, groupID int64, model string) bool { func checkRestricted(lk *channelLookup, groupID int64, model string) bool {
if !lk.channel.RestrictModels { if !lk.channel.RestrictModels {
return false return false
} }
// 检查模型是否在定价列表中
modelLower := strings.ToLower(model) modelLower := strings.ToLower(model)
key := channelModelKey{groupID: groupID, platform: lk.platform, model: modelLower} // 使用与查找定价相同的跨平台逻辑
if _, exists := lk.cache.pricingByGroupModel[key]; exists { if lookupPricingAcrossPlatforms(lk.cache, groupID, lk.platform, modelLower) != nil {
return false
}
// 精确查找失败,尝试通配符匹配
if lk.cache.matchWildcard(groupID, lk.platform, modelLower) != nil {
return false return false
} }
return true return true
...@@ -550,6 +584,9 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) ...@@ -550,6 +584,9 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
if err := validateNoConflictingModels(channel.ModelPricing); err != nil { if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
return nil, err return nil, err
} }
if err := validatePricingIntervals(channel.ModelPricing); err != nil {
return nil, err
}
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil { if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
return nil, err return nil, err
} }
...@@ -624,6 +661,9 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan ...@@ -624,6 +661,9 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
if err := validateNoConflictingModels(channel.ModelPricing); err != nil { if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
return nil, err return nil, err
} }
if err := validatePricingIntervals(channel.ModelPricing); err != nil {
return nil, err
}
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil { if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
return nil, err return nil, err
} }
...@@ -756,6 +796,19 @@ func validateNoConflictingMappings(mapping map[string]map[string]string) error { ...@@ -756,6 +796,19 @@ func validateNoConflictingMappings(mapping map[string]map[string]string) error {
return nil return nil
} }
func validatePricingIntervals(pricingList []ChannelModelPricing) error {
for _, pricing := range pricingList {
if err := ValidateIntervals(pricing.Intervals); err != nil {
return infraerrors.BadRequest(
"INVALID_PRICING_INTERVALS",
fmt.Sprintf("invalid pricing intervals for platform '%s' models %v: %v",
pricing.Platform, pricing.Models, err),
)
}
}
return nil
}
// detectConflicts 在一组 modelEntry 中检测冲突,返回带有 errCode 和 label 的错误 // detectConflicts 在一组 modelEntry 中检测冲突,返回带有 errCode 和 label 的错误
func detectConflicts(entries []modelEntry, platform, errCode, label string) error { func detectConflicts(entries []modelEntry, platform, errCode, label string) error {
for i := 0; i < len(entries); i++ { for i := 0; i < len(entries); i++ {
......
...@@ -1401,6 +1401,32 @@ func TestCreate_DuplicateModel(t *testing.T) { ...@@ -1401,6 +1401,32 @@ func TestCreate_DuplicateModel(t *testing.T) {
require.Contains(t, err.Error(), "claude-opus-4") require.Contains(t, err.Error(), "claude-opus-4")
} }
func TestCreate_InvalidPricingIntervals(t *testing.T) {
repo := &mockChannelRepository{
existsByNameFn: func(_ context.Context, _ string) (bool, error) {
return false, nil
},
}
svc := newTestChannelService(repo)
_, err := svc.Create(context.Background(), &CreateChannelInput{
Name: "new-channel",
ModelPricing: []ChannelModelPricing{
{
Platform: "anthropic",
Models: []string{"claude-opus-4"},
Intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(2000), InputPrice: testPtrFloat64(1e-6)},
{MinTokens: 1000, MaxTokens: testPtrInt(3000), InputPrice: testPtrFloat64(2e-6)},
},
},
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "INVALID_PRICING_INTERVALS")
require.Contains(t, err.Error(), "overlap")
}
func TestCreate_DefaultBillingModelSource(t *testing.T) { func TestCreate_DefaultBillingModelSource(t *testing.T) {
var capturedChannel *Channel var capturedChannel *Channel
repo := &mockChannelRepository{ repo := &mockChannelRepository{
...@@ -1592,6 +1618,37 @@ func TestUpdate_DuplicateModel(t *testing.T) { ...@@ -1592,6 +1618,37 @@ func TestUpdate_DuplicateModel(t *testing.T) {
require.Contains(t, err.Error(), "claude-opus-4") require.Contains(t, err.Error(), "claude-opus-4")
} }
func TestUpdate_InvalidPricingIntervals(t *testing.T) {
existing := &Channel{
ID: 1,
Name: "original",
Status: StatusActive,
}
repo := &mockChannelRepository{
getByIDFn: func(_ context.Context, _ int64) (*Channel, error) {
return existing.Clone(), nil
},
}
svc := newTestChannelService(repo)
invalidPricing := []ChannelModelPricing{
{
Platform: "anthropic",
Models: []string{"claude-opus-4"},
Intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: nil, InputPrice: testPtrFloat64(1e-6)},
{MinTokens: 2000, MaxTokens: testPtrInt(4000), InputPrice: testPtrFloat64(2e-6)},
},
},
}
_, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
ModelPricing: &invalidPricing,
})
require.Error(t, err)
require.Contains(t, err.Error(), "INVALID_PRICING_INTERVALS")
require.Contains(t, err.Error(), "unbounded")
}
func TestUpdate_InvalidatesChannelCache(t *testing.T) { func TestUpdate_InvalidatesChannelCache(t *testing.T) {
existing := &Channel{ existing := &Channel{
ID: 1, ID: 1,
...@@ -1984,3 +2041,144 @@ func TestResolveChannelMapping_AntigravityCrossPlatform(t *testing.T) { ...@@ -1984,3 +2041,144 @@ func TestResolveChannelMapping_AntigravityCrossPlatform(t *testing.T) {
require.Equal(t, "claude-opus-4-6", result.MappedModel) require.Equal(t, "claude-opus-4-6", result.MappedModel)
require.Equal(t, int64(1), result.ChannelID) require.Equal(t, int64(1), result.ChannelID)
} }
// ===========================================================================
// 11. Antigravity cross-platform same-name model — no overwrite
// ===========================================================================
func TestGetChannelModelPricing_AntigravitySameModelDifferentPlatforms(t *testing.T) {
// anthropic 和 gemini 都定义了同名模型 "shared-model",价格不同。
// antigravity 分组应能分别查到各自的定价,而不是后者覆盖前者。
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
ModelPricing: []ChannelModelPricing{
{ID: 200, Platform: PlatformAnthropic, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(10e-6)},
{ID: 201, Platform: PlatformGemini, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(5e-6)},
},
}
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity})
svc := newTestChannelService(repo)
// antigravity 分组查找 "shared-model":应命中第一个匹配(按 matchingPlatforms 顺序 antigravity→anthropic→gemini)
result := svc.GetChannelModelPricing(context.Background(), 10, "shared-model")
require.NotNil(t, result, "antigravity group should find pricing for shared-model")
// 第一个匹配应该是 anthropic(matchingPlatforms 返回 [antigravity, anthropic, gemini])
require.Equal(t, int64(200), result.ID)
require.InDelta(t, 10e-6, *result.InputPrice, 1e-12)
}
func TestGetChannelModelPricing_AntigravityOnlyGeminiPricing(t *testing.T) {
// 只有 gemini 平台定义了模型 "gemini-model"。
// antigravity 分组应能查到 gemini 的定价。
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
ModelPricing: []ChannelModelPricing{
{ID: 300, Platform: PlatformGemini, Models: []string{"gemini-model"}, InputPrice: testPtrFloat64(2e-6)},
},
}
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity})
svc := newTestChannelService(repo)
result := svc.GetChannelModelPricing(context.Background(), 10, "gemini-model")
require.NotNil(t, result, "antigravity group should find gemini pricing")
require.Equal(t, int64(300), result.ID)
require.InDelta(t, 2e-6, *result.InputPrice, 1e-12)
}
func TestGetChannelModelPricing_AntigravityWildcardCrossPlatformNoOverwrite(t *testing.T) {
// anthropic 和 gemini 都有 "shared-*" 通配符定价,价格不同。
// antigravity 分组查找 "shared-model" 应命中第一个匹配而非被覆盖。
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
ModelPricing: []ChannelModelPricing{
{ID: 400, Platform: PlatformAnthropic, Models: []string{"shared-*"}, InputPrice: testPtrFloat64(10e-6)},
{ID: 401, Platform: PlatformGemini, Models: []string{"shared-*"}, InputPrice: testPtrFloat64(5e-6)},
},
}
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity})
svc := newTestChannelService(repo)
result := svc.GetChannelModelPricing(context.Background(), 10, "shared-model")
require.NotNil(t, result, "antigravity group should find wildcard pricing for shared-model")
// 两个通配符都存在,应命中 anthropic 的(matchingPlatforms 顺序)
require.Equal(t, int64(400), result.ID)
require.InDelta(t, 10e-6, *result.InputPrice, 1e-12)
}
func TestResolveChannelMapping_AntigravitySameModelDifferentPlatforms(t *testing.T) {
// anthropic 和 gemini 都定义了同名模型映射 "alias" → 不同目标。
// antigravity 分组应命中 anthropic 的映射(按 matchingPlatforms 顺序)。
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
ModelMapping: map[string]map[string]string{
PlatformAnthropic: {"alias": "anthropic-target"},
PlatformGemini: {"alias": "gemini-target"},
},
}
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity})
svc := newTestChannelService(repo)
result := svc.ResolveChannelMapping(context.Background(), 10, "alias")
require.True(t, result.Mapped)
require.Equal(t, "anthropic-target", result.MappedModel)
}
func TestCheckRestricted_AntigravitySameModelDifferentPlatforms(t *testing.T) {
// anthropic 和 gemini 都定义了同名模型 "shared-model"。
// antigravity 分组启用了 RestrictModels,"shared-model" 应不被限制。
ch := Channel{
ID: 1,
Status: StatusActive,
RestrictModels: true,
GroupIDs: []int64{10},
ModelPricing: []ChannelModelPricing{
{ID: 500, Platform: PlatformAnthropic, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(10e-6)},
{ID: 501, Platform: PlatformGemini, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(5e-6)},
},
}
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity})
svc := newTestChannelService(repo)
restricted := svc.IsModelRestricted(context.Background(), 10, "shared-model")
require.False(t, restricted, "shared-model should not be restricted for antigravity")
// 未定义的模型应被限制
restricted = svc.IsModelRestricted(context.Background(), 10, "unknown-model")
require.True(t, restricted, "unknown-model should be restricted for antigravity")
}
func TestGetChannelModelPricing_NonAntigravityUnaffected(t *testing.T) {
// 确保非 antigravity 平台的行为不受影响。
// anthropic 分组只能看到 anthropic 的定价,看不到 gemini 的。
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10, 20},
ModelPricing: []ChannelModelPricing{
{ID: 600, Platform: PlatformAnthropic, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(10e-6)},
{ID: 601, Platform: PlatformGemini, Models: []string{"shared-model"}, InputPrice: testPtrFloat64(5e-6)},
},
}
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAnthropic, 20: PlatformGemini})
svc := newTestChannelService(repo)
// anthropic 分组应该只看到 anthropic 的定价
result := svc.GetChannelModelPricing(context.Background(), 10, "shared-model")
require.NotNil(t, result)
require.Equal(t, int64(600), result.ID)
require.InDelta(t, 10e-6, *result.InputPrice, 1e-12)
// gemini 分组应该只看到 gemini 的定价
result = svc.GetChannelModelPricing(context.Background(), 20, "shared-model")
require.NotNil(t, result)
require.Equal(t, int64(601), result.ID)
require.InDelta(t, 5e-6, *result.InputPrice, 1e-12)
}
...@@ -307,3 +307,129 @@ func TestChannelClone_EdgeCases(t *testing.T) { ...@@ -307,3 +307,129 @@ func TestChannelClone_EdgeCases(t *testing.T) {
require.Equal(t, "gpt-4-turbo", original.ModelMapping["openai"]["gpt-4"]) require.Equal(t, "gpt-4-turbo", original.ModelMapping["openai"]["gpt-4"])
}) })
} }
// --- ValidateIntervals ---
func TestValidateIntervals_Empty(t *testing.T) {
require.NoError(t, ValidateIntervals(nil))
require.NoError(t, ValidateIntervals([]PricingInterval{}))
}
func TestValidateIntervals_ValidIntervals(t *testing.T) {
tests := []struct {
name string
intervals []PricingInterval
}{
{
name: "single bounded interval",
intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)},
},
},
{
name: "two intervals with gap",
intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(100000), InputPrice: testPtrFloat64(1e-6)},
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)},
},
},
{
name: "two contiguous intervals",
intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)},
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)},
},
},
{
name: "unsorted input (auto-sorted by validator)",
intervals: []PricingInterval{
{MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)},
{MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)},
},
},
{
name: "single unbounded interval",
intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: nil, InputPrice: testPtrFloat64(1e-6)},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.NoError(t, ValidateIntervals(tt.intervals))
})
}
}
func TestValidateIntervals_NegativeMinTokens(t *testing.T) {
intervals := []PricingInterval{
{MinTokens: -1, MaxTokens: testPtrInt(100), InputPrice: testPtrFloat64(1e-6)},
}
err := ValidateIntervals(intervals)
require.Error(t, err)
require.Contains(t, err.Error(), "min_tokens")
require.Contains(t, err.Error(), ">= 0")
}
func TestValidateIntervals_MaxTokensZero(t *testing.T) {
intervals := []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(0), InputPrice: testPtrFloat64(1e-6)},
}
err := ValidateIntervals(intervals)
require.Error(t, err)
require.Contains(t, err.Error(), "max_tokens")
require.Contains(t, err.Error(), "> 0")
}
func TestValidateIntervals_MaxLessThanMin(t *testing.T) {
intervals := []PricingInterval{
{MinTokens: 100, MaxTokens: testPtrInt(50), InputPrice: testPtrFloat64(1e-6)},
}
err := ValidateIntervals(intervals)
require.Error(t, err)
require.Contains(t, err.Error(), "max_tokens")
require.Contains(t, err.Error(), "> min_tokens")
}
func TestValidateIntervals_MaxEqualsMin(t *testing.T) {
intervals := []PricingInterval{
{MinTokens: 100, MaxTokens: testPtrInt(100), InputPrice: testPtrFloat64(1e-6)},
}
err := ValidateIntervals(intervals)
require.Error(t, err)
require.Contains(t, err.Error(), "max_tokens")
require.Contains(t, err.Error(), "> min_tokens")
}
func TestValidateIntervals_NegativePrice(t *testing.T) {
negPrice := -0.01
intervals := []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(100), InputPrice: &negPrice},
}
err := ValidateIntervals(intervals)
require.Error(t, err)
require.Contains(t, err.Error(), "input_price")
require.Contains(t, err.Error(), ">= 0")
}
func TestValidateIntervals_OverlappingIntervals(t *testing.T) {
intervals := []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(200), InputPrice: testPtrFloat64(1e-6)},
{MinTokens: 100, MaxTokens: testPtrInt(300), InputPrice: testPtrFloat64(2e-6)},
}
err := ValidateIntervals(intervals)
require.Error(t, err)
require.Contains(t, err.Error(), "overlap")
}
func TestValidateIntervals_UnboundedNotLast(t *testing.T) {
intervals := []PricingInterval{
{MinTokens: 0, MaxTokens: nil, InputPrice: testPtrFloat64(1e-6)},
{MinTokens: 128000, MaxTokens: testPtrInt(256000), InputPrice: testPtrFloat64(2e-6)},
}
err := ValidateIntervals(intervals)
require.Error(t, err)
require.Contains(t, err.Error(), "unbounded")
require.Contains(t, err.Error(), "last")
}
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
func TestSelectAccountForModelWithExclusions_UsesFallbackGroupForChannelRestriction(t *testing.T) {
t.Parallel()
groupID := int64(10)
fallbackID := int64(11)
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{fallbackID},
RestrictModels: true,
ModelPricing: []ChannelModelPricing{
{Platform: PlatformAnthropic, Models: []string{"claude-sonnet-4-6"}},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{
fallbackID: PlatformAnthropic,
}))
accountRepo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range accountRepo.accounts {
accountRepo.accountsByID[accountRepo.accounts[i].ID] = &accountRepo.accounts[i]
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
ClaudeCodeOnly: true,
FallbackGroupID: &fallbackID,
Hydrated: true,
},
fallbackID: {
ID: fallbackID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
},
},
}
svc := &GatewayService{
accountRepo: accountRepo,
groupRepo: groupRepo,
channelService: channelSvc,
cfg: testConfig(),
}
ctx := context.WithValue(context.Background(), ctxkey.Group, groupRepo.groups[groupID])
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-sonnet-4-6", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, int64(1), account.ID)
}
func TestSelectAccountWithLoadAwareness_UsesFallbackGroupForChannelRestriction(t *testing.T) {
t.Parallel()
groupID := int64(10)
fallbackID := int64(11)
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{fallbackID},
RestrictModels: true,
ModelPricing: []ChannelModelPricing{
{Platform: PlatformAnthropic, Models: []string{"claude-sonnet-4-6"}},
},
}
channelSvc := newTestChannelService(makeStandardRepo(ch, map[int64]string{
fallbackID: PlatformAnthropic,
}))
accountRepo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range accountRepo.accounts {
accountRepo.accountsByID[accountRepo.accounts[i].ID] = &accountRepo.accounts[i]
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
ClaudeCodeOnly: true,
FallbackGroupID: &fallbackID,
Hydrated: true,
},
fallbackID: {
ID: fallbackID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
},
},
}
svc := &GatewayService{
accountRepo: accountRepo,
groupRepo: groupRepo,
channelService: channelSvc,
cfg: testConfig(),
}
ctx := context.WithValue(context.Background(), ctxkey.Group, groupRepo.groups[groupID])
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-sonnet-4-6", nil, "", 0)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(1), result.Account.ID)
}
...@@ -1178,11 +1178,6 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int ...@@ -1178,11 +1178,6 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts. // SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 渠道定价限制预检查(requested / channel_mapped 基准)
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
}
// 优先检查 context 中的强制平台(/antigravity 路由) // 优先检查 context 中的强制平台(/antigravity 路由)
var platform string var platform string
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
...@@ -1201,6 +1196,12 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context ...@@ -1201,6 +1196,12 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
platform = PlatformAnthropic platform = PlatformAnthropic
} }
// Claude Code 限制可能已将 groupID 解析为 fallback group,
// 渠道限制预检查必须使用解析后的分组。
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
}
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户) // anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
// 注意:强制平台模式不走混合调度 // 注意:强制平台模式不走混合调度
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform { if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
...@@ -1217,11 +1218,6 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context ...@@ -1217,11 +1218,6 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
// metadataUserID: 用于客户端亲和调度,从中提取客户端 ID // metadataUserID: 用于客户端亲和调度,从中提取客户端 ID
// sub2apiUserID: 系统用户 ID,用于二维亲和调度 // sub2apiUserID: 系统用户 ID,用于二维亲和调度
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string, sub2apiUserID int64) (*AccountSelectionResult, error) { func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string, sub2apiUserID int64) (*AccountSelectionResult, error) {
// 渠道定价限制预检查(requested / channel_mapped 基准)
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
}
// 调试日志:记录调度入口参数 // 调试日志:记录调度入口参数
excludedIDsList := make([]int64, 0, len(excludedIDs)) excludedIDsList := make([]int64, 0, len(excludedIDs))
for id := range excludedIDs { for id := range excludedIDs {
...@@ -1242,6 +1238,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1242,6 +1238,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
} }
ctx = s.withGroupContext(ctx, group) ctx = s.withGroupContext(ctx, group)
// Claude Code 限制可能已将 groupID 解析为 fallback group,
// 渠道限制预检查必须使用解析后的分组。
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
}
var stickyAccountID int64 var stickyAccountID int64
if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 { if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 {
stickyAccountID = prefetch stickyAccountID = prefetch
......
//go:build unit
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
func TestOpenAISelectAccountForModelWithExclusions_ChannelMappedRestrictionRejectsEarly(t *testing.T) {
t.Parallel()
channelSvc := newTestChannelService(makeStandardRepo(Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
BillingModelSource: BillingModelSourceChannelMapped,
ModelPricing: []ChannelModelPricing{
{Platform: PlatformOpenAI, Models: []string{"gpt-4o"}},
},
ModelMapping: map[string]map[string]string{
PlatformOpenAI: {"gpt-4.1": "o3-mini"},
},
}, map[int64]string{10: PlatformOpenAI}))
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true},
}},
channelService: channelSvc,
}
groupID := int64(10)
_, err := svc.SelectAccountForModelWithExclusions(context.Background(), &groupID, "", "gpt-4.1", nil)
require.ErrorIs(t, err, ErrNoAvailableAccounts)
require.Contains(t, err.Error(), "channel pricing restriction")
}
func TestOpenAISelectAccountForModelWithExclusions_UpstreamRestrictionSkipsDisallowedAccount(t *testing.T) {
t.Parallel()
channelSvc := newTestChannelService(makeStandardRepo(Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
BillingModelSource: BillingModelSourceUpstream,
ModelPricing: []ChannelModelPricing{
{Platform: PlatformOpenAI, Models: []string{"o3-mini"}},
},
}, map[int64]string{10: PlatformOpenAI}))
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{
{
ID: 1,
Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: true,
Priority: 10,
Credentials: map[string]any{
"model_mapping": map[string]any{"gpt-4.1": "gpt-4o"},
},
},
{
ID: 2,
Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: true,
Priority: 20,
Credentials: map[string]any{
"model_mapping": map[string]any{"gpt-4.1": "o3-mini"},
},
},
}},
channelService: channelSvc,
}
groupID := int64(10)
account, err := svc.SelectAccountForModelWithExclusions(context.Background(), &groupID, "", "gpt-4.1", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, int64(2), account.ID)
}
func TestOpenAISelectAccountForModelWithExclusions_StickyRestrictedUpstreamFallsBack(t *testing.T) {
t.Parallel()
channelSvc := newTestChannelService(makeStandardRepo(Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10},
RestrictModels: true,
BillingModelSource: BillingModelSourceUpstream,
ModelPricing: []ChannelModelPricing{
{Platform: PlatformOpenAI, Models: []string{"o3-mini"}},
},
}, map[int64]string{10: PlatformOpenAI}))
cache := &stubGatewayCache{
sessionBindings: map[string]int64{"openai:sticky-session": 1},
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{
{
ID: 1,
Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: true,
Priority: 10,
Credentials: map[string]any{
"model_mapping": map[string]any{"gpt-4.1": "gpt-4o"},
},
},
{
ID: 2,
Platform: PlatformOpenAI,
Status: StatusActive,
Schedulable: true,
Priority: 20,
Credentials: map[string]any{
"model_mapping": map[string]any{"gpt-4.1": "o3-mini"},
},
},
}},
channelService: channelSvc,
cache: cache,
}
groupID := int64(10)
account, err := svc.SelectAccountForModelWithExclusions(context.Background(), &groupID, "sticky-session", "gpt-4.1", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, int64(2), account.ID)
require.Equal(t, 1, cache.deletedSessions["openai:sticky-session"])
require.Equal(t, int64(2), cache.sessionBindings["openai:sticky-session"])
}
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log/slog"
"math/rand" "math/rand"
"net/http" "net/http"
"sort" "sort"
...@@ -423,6 +424,44 @@ func (s *OpenAIGatewayService) ResolveChannelMappingAndRestrict(ctx context.Cont ...@@ -423,6 +424,44 @@ func (s *OpenAIGatewayService) ResolveChannelMappingAndRestrict(ctx context.Cont
return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model) return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model)
} }
func (s *OpenAIGatewayService) checkChannelPricingRestriction(ctx context.Context, groupID *int64, requestedModel string) bool {
if groupID == nil || s.channelService == nil || requestedModel == "" {
return false
}
mapping := s.channelService.ResolveChannelMapping(ctx, *groupID, requestedModel)
billingModel := billingModelForRestriction(mapping.BillingModelSource, requestedModel, mapping.MappedModel)
if billingModel == "" {
return false
}
return s.channelService.IsModelRestricted(ctx, *groupID, billingModel)
}
func (s *OpenAIGatewayService) isUpstreamModelRestrictedByChannel(ctx context.Context, groupID int64, account *Account, requestedModel string) bool {
if s.channelService == nil {
return false
}
upstreamModel := resolveOpenAIForwardModel(account, requestedModel, "")
if upstreamModel == "" {
return false
}
return s.channelService.IsModelRestricted(ctx, groupID, upstreamModel)
}
func (s *OpenAIGatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Context, groupID *int64) bool {
if groupID == nil || s.channelService == nil {
return false
}
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
if err != nil {
slog.Warn("failed to check openai channel upstream restriction", "group_id", *groupID, "error", err)
return false
}
if ch == nil || !ch.RestrictModels {
return false
}
return ch.BillingModelSource == BillingModelSourceUpstream
}
// ReplaceModelInBody 替换请求体中的 JSON model 字段(通用 gjson/sjson 实现)。 // ReplaceModelInBody 替换请求体中的 JSON model 字段(通用 gjson/sjson 实现)。
func (s *OpenAIGatewayService) ReplaceModelInBody(body []byte, newModel string) []byte { func (s *OpenAIGatewayService) ReplaceModelInBody(body []byte, newModel string) []byte {
return ReplaceModelInBody(body, newModel) return ReplaceModelInBody(body, newModel)
...@@ -1162,6 +1201,10 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C ...@@ -1162,6 +1201,10 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
} }
func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) (*Account, error) { func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) (*Account, error) {
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
}
// 1. 尝试粘性会话命中 // 1. 尝试粘性会话命中
// Try sticky session hit // Try sticky session hit
if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID); account != nil { if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID); account != nil {
...@@ -1177,7 +1220,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C ...@@ -1177,7 +1220,7 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
// 3. 按优先级 + LRU 选择最佳账号 // 3. 按优先级 + LRU 选择最佳账号
// Select by priority + LRU // Select by priority + LRU
selected := s.selectBestAccount(ctx, accounts, requestedModel, excludedIDs) selected := s.selectBestAccount(ctx, groupID, accounts, requestedModel, excludedIDs)
if selected == nil { if selected == nil {
if requestedModel != "" { if requestedModel != "" {
...@@ -1243,6 +1286,11 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID ...@@ -1243,6 +1286,11 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
return nil return nil
} }
if groupID != nil && s.needsUpstreamChannelRestrictionCheck(ctx, groupID) &&
s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel) {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
return nil
}
// 刷新会话 TTL 并返回账号 // 刷新会话 TTL 并返回账号
// Refresh session TTL and return account // Refresh session TTL and return account
...@@ -1255,8 +1303,9 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID ...@@ -1255,8 +1303,9 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
// //
// selectBestAccount selects the best account from candidates (priority + LRU). // selectBestAccount selects the best account from candidates (priority + LRU).
// Returns nil if no available account. // Returns nil if no available account.
func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account { func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *int64, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account {
var selected *Account var selected *Account
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
for i := range accounts { for i := range accounts {
acc := &accounts[i] acc := &accounts[i]
...@@ -1275,6 +1324,9 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts [ ...@@ -1275,6 +1324,9 @@ func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, accounts [
if fresh == nil { if fresh == nil {
continue continue
} }
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
continue
}
// 选择优先级最高且最久未使用的账号 // 选择优先级最高且最久未使用的账号
// Select highest priority and least recently used // Select highest priority and least recently used
...@@ -1326,7 +1378,12 @@ func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool ...@@ -1326,7 +1378,12 @@ func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan. // SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) { func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
}
cfg := s.schedulingConfig() cfg := s.schedulingConfig()
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
var stickyAccountID int64 var stickyAccountID int64
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil { if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil {
...@@ -1402,6 +1459,8 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ...@@ -1402,6 +1459,8 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel) account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel)
if account == nil { if account == nil {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
} else if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel) {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
} else { } else {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
...@@ -1447,6 +1506,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ...@@ -1447,6 +1506,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if requestedModel != "" && !acc.IsModelSupported(requestedModel) { if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue continue
} }
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) {
continue
}
candidates = append(candidates, acc) candidates = append(candidates, acc)
} }
...@@ -1471,6 +1533,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ...@@ -1471,6 +1533,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if fresh == nil { if fresh == nil {
continue continue
} }
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
continue
}
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
if sessionHash != "" { if sessionHash != "" {
...@@ -1525,6 +1590,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ...@@ -1525,6 +1590,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if fresh == nil { if fresh == nil {
continue continue
} }
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
continue
}
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
if sessionHash != "" { if sessionHash != "" {
...@@ -1547,6 +1615,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex ...@@ -1547,6 +1615,9 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
if fresh == nil { if fresh == nil {
continue continue
} }
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel) {
continue
}
return &AccountSelectionResult{ return &AccountSelectionResult{
Account: fresh, Account: fresh,
WaitPlan: &AccountWaitPlan{ WaitPlan: &AccountWaitPlan{
......
...@@ -113,6 +113,70 @@ export function findModelConflict(models: string[]): [string, string] | null { ...@@ -113,6 +113,70 @@ export function findModelConflict(models: string[]): [string, string] | null {
return null return null
} }
// ── 区间校验 ──────────────────────────────────────────────
/** 校验区间列表的合法性,返回错误消息;通过则返回 null */
export function validateIntervals(intervals: IntervalFormEntry[]): string | null {
if (!intervals || intervals.length === 0) return null
// 按 min_tokens 排序(不修改原数组)
const sorted = [...intervals].sort((a, b) => a.min_tokens - b.min_tokens)
for (let i = 0; i < sorted.length; i++) {
const err = validateSingleInterval(sorted[i], i)
if (err) return err
}
return checkIntervalOverlap(sorted)
}
function validateSingleInterval(iv: IntervalFormEntry, idx: number): string | null {
if (iv.min_tokens < 0) {
return `区间 #${idx + 1}: 最小 token 数 (${iv.min_tokens}) 不能为负数`
}
if (iv.max_tokens != null) {
if (iv.max_tokens <= 0) {
return `区间 #${idx + 1}: 最大 token 数 (${iv.max_tokens}) 必须大于 0`
}
if (iv.max_tokens <= iv.min_tokens) {
return `区间 #${idx + 1}: 最大 token 数 (${iv.max_tokens}) 必须大于最小 token 数 (${iv.min_tokens})`
}
}
return validateIntervalPrices(iv, idx)
}
function validateIntervalPrices(iv: IntervalFormEntry, idx: number): string | null {
const prices: [string, number | string | null][] = [
['输入价格', iv.input_price],
['输出价格', iv.output_price],
['缓存写入价格', iv.cache_write_price],
['缓存读取价格', iv.cache_read_price],
['单次价格', iv.per_request_price],
]
for (const [name, val] of prices) {
if (val != null && val !== '' && Number(val) < 0) {
return `区间 #${idx + 1}: ${name}不能为负数`
}
}
return null
}
function checkIntervalOverlap(sorted: IntervalFormEntry[]): string | null {
for (let i = 0; i < sorted.length; i++) {
// 无上限区间必须是最后一个
if (sorted[i].max_tokens == null && i < sorted.length - 1) {
return `区间 #${i + 1}: 无上限区间(最大 token 数为空)只能是最后一个`
}
if (i === 0) continue
const prev = sorted[i - 1]
// (min, max] 语义:前一个区间上界 > 当前区间下界则重叠
if (prev.max_tokens == null || prev.max_tokens > sorted[i].min_tokens) {
const prevMax = prev.max_tokens == null ? '' : String(prev.max_tokens)
return `区间 #${i} 和 #${i + 1} 重叠:前一个区间上界 (${prevMax}) 大于当前区间下界 (${sorted[i].min_tokens})`
}
}
return null
}
/** 平台对应的模型 tag 样式(背景+文字) */ /** 平台对应的模型 tag 样式(背景+文字) */
export function getPlatformTagClass(platform: string): string { export function getPlatformTagClass(platform: string): string {
switch (platform) { switch (platform) {
......
...@@ -418,7 +418,7 @@ import { useAppStore } from '@/stores/app' ...@@ -418,7 +418,7 @@ import { useAppStore } from '@/stores/app'
import { adminAPI } from '@/api/admin' import { adminAPI } from '@/api/admin'
import type { Channel, ChannelModelPricing, CreateChannelRequest, UpdateChannelRequest } from '@/api/admin/channels' import type { Channel, ChannelModelPricing, CreateChannelRequest, UpdateChannelRequest } from '@/api/admin/channels'
import type { PricingFormEntry } from '@/components/admin/channel/types' import type { PricingFormEntry } from '@/components/admin/channel/types'
import { mTokToPerToken, perTokenToMTok, apiIntervalsToForm, formIntervalsToAPI, findModelConflict } from '@/components/admin/channel/types' import { mTokToPerToken, perTokenToMTok, apiIntervalsToForm, formIntervalsToAPI, findModelConflict, validateIntervals } from '@/components/admin/channel/types'
import type { AdminGroup, GroupPlatform } from '@/types' import type { AdminGroup, GroupPlatform } from '@/types'
import type { Column } from '@/components/common/types' import type { Column } from '@/components/common/types'
import AppLayout from '@/components/layout/AppLayout.vue' import AppLayout from '@/components/layout/AppLayout.vue'
...@@ -922,6 +922,21 @@ async function handleSubmit() { ...@@ -922,6 +922,21 @@ async function handleSubmit() {
} }
} }
// 校验区间合法性(范围、重叠等)
for (const section of form.platforms.filter(s => s.enabled)) {
for (const entry of section.model_pricing) {
if (!entry.intervals || entry.intervals.length === 0) continue
const intervalErr = validateIntervals(entry.intervals)
if (intervalErr) {
const platformLabel = t('admin.groups.platforms.' + section.platform, section.platform)
const modelLabel = entry.models.join(', ') || '未命名'
appStore.showError(`${platformLabel} - ${modelLabel}: ${intervalErr}`)
activeTab.value = section.platform
return
}
}
}
const { group_ids, model_pricing, model_mapping } = formToAPI() const { group_ids, model_pricing, model_mapping } = formToAPI()
submitting.value = true submitting.value = true
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment