Unverified Commit c5c12d4c authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Revert "feat(gateway): 实现负载感知的账号调度优化 (#114)" (#117)

This reverts commit 8d252303.
parent 8d252303
...@@ -99,7 +99,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -99,7 +99,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream) antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream)
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, antigravityGatewayService, httpUpstream) accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, antigravityGatewayService, httpUpstream)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) concurrencyService := service.NewConcurrencyService(concurrencyCache)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
oAuthHandler := admin.NewOAuthHandler(oAuthService) oAuthHandler := admin.NewOAuthHandler(oAuthService)
...@@ -127,10 +127,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { ...@@ -127,10 +127,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
identityService := service.NewIdentityService(identityCache) identityService := service.NewIdentityService(identityCache)
timingWheelService := service.ProvideTimingWheelService() timingWheelService := service.ProvideTimingWheelService()
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService) gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService) gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler) handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
......
...@@ -3,7 +3,6 @@ package config ...@@ -3,7 +3,6 @@ package config
import ( import (
"fmt" "fmt"
"strings" "strings"
"time"
"github.com/spf13/viper" "github.com/spf13/viper"
) )
...@@ -120,37 +119,6 @@ type GatewayConfig struct { ...@@ -120,37 +119,6 @@ type GatewayConfig struct {
// ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟) // ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟)
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期 // 应大于最长 LLM 请求时间,防止请求完成前槽位过期
ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"` ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"`
// 是否记录上游错误响应体摘要(避免输出请求内容)
LogUpstreamErrorBody bool `mapstructure:"log_upstream_error_body"`
// 上游错误响应体记录最大字节数(超过会截断)
LogUpstreamErrorBodyMaxBytes int `mapstructure:"log_upstream_error_body_max_bytes"`
// API-key 账号在客户端未提供 anthropic-beta 时,是否按需自动补齐(默认关闭以保持兼容)
InjectBetaForApiKey bool `mapstructure:"inject_beta_for_apikey"`
// 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义)
FailoverOn400 bool `mapstructure:"failover_on_400"`
// Scheduling: 账号调度相关配置
Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"`
}
// GatewaySchedulingConfig accounts scheduling configuration.
type GatewaySchedulingConfig struct {
// 粘性会话排队配置
StickySessionMaxWaiting int `mapstructure:"sticky_session_max_waiting"`
StickySessionWaitTimeout time.Duration `mapstructure:"sticky_session_wait_timeout"`
// 兜底排队配置
FallbackWaitTimeout time.Duration `mapstructure:"fallback_wait_timeout"`
FallbackMaxWaiting int `mapstructure:"fallback_max_waiting"`
// 负载计算
LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
// 过期槽位清理周期(0 表示禁用)
SlotCleanupInterval time.Duration `mapstructure:"slot_cleanup_interval"`
} }
func (s *ServerConfig) Address() string { func (s *ServerConfig) Address() string {
...@@ -345,10 +313,6 @@ func setDefaults() { ...@@ -345,10 +313,6 @@ func setDefaults() {
// Gateway // Gateway
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久 viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久
viper.SetDefault("gateway.log_upstream_error_body", false)
viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048)
viper.SetDefault("gateway.inject_beta_for_apikey", false)
viper.SetDefault("gateway.failover_on_400", false)
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化) // HTTP 上游连接池配置(针对 5000+ 并发用户优化)
...@@ -359,12 +323,6 @@ func setDefaults() { ...@@ -359,12 +323,6 @@ func setDefaults() {
viper.SetDefault("gateway.max_upstream_clients", 5000) viper.SetDefault("gateway.max_upstream_clients", 5000)
viper.SetDefault("gateway.client_idle_ttl_seconds", 900) viper.SetDefault("gateway.client_idle_ttl_seconds", 900)
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 15) // 并发槽位过期时间(支持超长请求) viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 15) // 并发槽位过期时间(支持超长请求)
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 45*time.Second)
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
// TokenRefresh // TokenRefresh
viper.SetDefault("token_refresh.enabled", true) viper.SetDefault("token_refresh.enabled", true)
...@@ -453,21 +411,6 @@ func (c *Config) Validate() error { ...@@ -453,21 +411,6 @@ func (c *Config) Validate() error {
if c.Gateway.ConcurrencySlotTTLMinutes <= 0 { if c.Gateway.ConcurrencySlotTTLMinutes <= 0 {
return fmt.Errorf("gateway.concurrency_slot_ttl_minutes must be positive") return fmt.Errorf("gateway.concurrency_slot_ttl_minutes must be positive")
} }
if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 {
return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive")
}
if c.Gateway.Scheduling.StickySessionWaitTimeout <= 0 {
return fmt.Errorf("gateway.scheduling.sticky_session_wait_timeout must be positive")
}
if c.Gateway.Scheduling.FallbackWaitTimeout <= 0 {
return fmt.Errorf("gateway.scheduling.fallback_wait_timeout must be positive")
}
if c.Gateway.Scheduling.FallbackMaxWaiting <= 0 {
return fmt.Errorf("gateway.scheduling.fallback_max_waiting must be positive")
}
if c.Gateway.Scheduling.SlotCleanupInterval < 0 {
return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative")
}
return nil return nil
} }
......
package config package config
import ( import "testing"
"testing"
"time"
"github.com/spf13/viper"
)
func TestNormalizeRunMode(t *testing.T) { func TestNormalizeRunMode(t *testing.T) {
tests := []struct { tests := []struct {
...@@ -26,45 +21,3 @@ func TestNormalizeRunMode(t *testing.T) { ...@@ -26,45 +21,3 @@ func TestNormalizeRunMode(t *testing.T) {
} }
} }
} }
func TestLoadDefaultSchedulingConfig(t *testing.T) {
viper.Reset()
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 3 {
t.Fatalf("StickySessionMaxWaiting = %d, want 3", cfg.Gateway.Scheduling.StickySessionMaxWaiting)
}
if cfg.Gateway.Scheduling.StickySessionWaitTimeout != 45*time.Second {
t.Fatalf("StickySessionWaitTimeout = %v, want 45s", cfg.Gateway.Scheduling.StickySessionWaitTimeout)
}
if cfg.Gateway.Scheduling.FallbackWaitTimeout != 30*time.Second {
t.Fatalf("FallbackWaitTimeout = %v, want 30s", cfg.Gateway.Scheduling.FallbackWaitTimeout)
}
if cfg.Gateway.Scheduling.FallbackMaxWaiting != 100 {
t.Fatalf("FallbackMaxWaiting = %d, want 100", cfg.Gateway.Scheduling.FallbackMaxWaiting)
}
if !cfg.Gateway.Scheduling.LoadBatchEnabled {
t.Fatalf("LoadBatchEnabled = false, want true")
}
if cfg.Gateway.Scheduling.SlotCleanupInterval != 30*time.Second {
t.Fatalf("SlotCleanupInterval = %v, want 30s", cfg.Gateway.Scheduling.SlotCleanupInterval)
}
}
func TestLoadSchedulingConfigFromEnv(t *testing.T) {
viper.Reset()
t.Setenv("GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING", "5")
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Gateway.Scheduling.StickySessionMaxWaiting != 5 {
t.Fatalf("StickySessionMaxWaiting = %d, want 5", cfg.Gateway.Scheduling.StickySessionMaxWaiting)
}
}
...@@ -141,10 +141,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -141,10 +141,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} else if apiKey.Group != nil { } else if apiKey.Group != nil {
platform = apiKey.Group.Platform platform = apiKey.Group.Platform
} }
sessionKey := sessionHash
if platform == service.PlatformGemini && sessionHash != "" {
sessionKey = "gemini:" + sessionHash
}
if platform == service.PlatformGemini { if platform == service.PlatformGemini {
const maxAccountSwitches = 3 const maxAccountSwitches = 3
...@@ -153,7 +149,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -153,7 +149,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
lastFailoverStatus := 0 lastFailoverStatus := 0
for { for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs) account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
if err != nil { if err != nil {
if len(failedAccountIDs) == 0 { if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
...@@ -162,13 +158,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -162,13 +158,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return return
} }
account := selection.Account
// 检查预热请求拦截(在账号选择后、转发前检查) // 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
if reqStream { if reqStream {
sendMockWarmupStream(c, reqModel) sendMockWarmupStream(c, reqModel)
} else { } else {
...@@ -178,47 +170,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -178,47 +170,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
// 3. 获取账号并发槽位 // 3. 获取账号并发槽位
accountReleaseFunc := selection.ReleaseFunc accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
var accountWaitRelease func()
if !selection.Acquired {
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil { if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
log.Printf("Account concurrency acquire failed: %v", err) log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted) h.handleConcurrencyError(c, err, "account", streamStarted)
return return
} }
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 转发请求 - 根据账号平台分流 // 转发请求 - 根据账号平台分流
var result *service.ForwardResult var result *service.ForwardResult
...@@ -230,9 +187,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -230,9 +187,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
} }
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil { if err != nil {
var failoverErr *service.UpstreamFailoverError var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) { if errors.As(err, &failoverErr) {
...@@ -277,7 +231,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -277,7 +231,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for { for {
// 选择支持该模型的账号 // 选择支持该模型的账号
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs) account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
if err != nil { if err != nil {
if len(failedAccountIDs) == 0 { if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
...@@ -286,13 +240,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -286,13 +240,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return return
} }
account := selection.Account
// 检查预热请求拦截(在账号选择后、转发前检查) // 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
if reqStream { if reqStream {
sendMockWarmupStream(c, reqModel) sendMockWarmupStream(c, reqModel)
} else { } else {
...@@ -302,47 +252,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -302,47 +252,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
// 3. 获取账号并发槽位 // 3. 获取账号并发槽位
accountReleaseFunc := selection.ReleaseFunc accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
var accountWaitRelease func()
if !selection.Acquired {
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil { if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
log.Printf("Account concurrency acquire failed: %v", err) log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted) h.handleConcurrencyError(c, err, "account", streamStarted)
return return
} }
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 转发请求 - 根据账号平台分流 // 转发请求 - 根据账号平台分流
var result *service.ForwardResult var result *service.ForwardResult
...@@ -354,9 +269,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ...@@ -354,9 +269,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
} }
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil { if err != nil {
var failoverErr *service.UpstreamFailoverError var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) { if errors.As(err, &failoverErr) {
......
...@@ -83,16 +83,6 @@ func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64 ...@@ -83,16 +83,6 @@ func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64
h.concurrencyService.DecrementWaitCount(ctx, userID) h.concurrencyService.DecrementWaitCount(ctx, userID)
} }
// IncrementAccountWaitCount increments the wait count for an account
func (h *ConcurrencyHelper) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
return h.concurrencyService.IncrementAccountWaitCount(ctx, accountID, maxWait)
}
// DecrementAccountWaitCount decrements the wait count for an account
func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accountID int64) {
h.concurrencyService.DecrementAccountWaitCount(ctx, accountID)
}
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary. // AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait. // For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun. // streamStarted is updated if streaming response has begun.
...@@ -136,12 +126,7 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID ...@@ -136,12 +126,7 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests. // waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
// streamStarted pointer is updated when streaming begins (for proper error handling by caller). // streamStarted pointer is updated when streaming begins (for proper error handling by caller).
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) { func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted) ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait)
}
// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout.
func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel() defer cancel()
// Determine if ping is needed (streaming + ping format defined) // Determine if ping is needed (streaming + ping format defined)
...@@ -215,11 +200,6 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType ...@@ -215,11 +200,6 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
} }
} }
// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping).
func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) {
return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted)
}
// nextBackoff 计算下一次退避时间 // nextBackoff 计算下一次退避时间
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应 // 性能优化:使用指数退避 + 随机抖动,避免惊群效应
// current: 当前退避时间 // current: 当前退避时间
......
...@@ -197,17 +197,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -197,17 +197,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 3) select account (sticky session based on request body) // 3) select account (sticky session based on request body)
parsedReq, _ := service.ParseGatewayRequest(body) parsedReq, _ := service.ParseGatewayRequest(body)
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
sessionKey := sessionHash
if sessionHash != "" {
sessionKey = "gemini:" + sessionHash
}
const maxAccountSwitches = 3 const maxAccountSwitches = 3
switchCount := 0 switchCount := 0
failedAccountIDs := make(map[int64]struct{}) failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0 lastFailoverStatus := 0
for { for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs) account, err := h.geminiCompatService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, modelName, failedAccountIDs)
if err != nil { if err != nil {
if len(failedAccountIDs) == 0 { if len(failedAccountIDs) == 0 {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
...@@ -216,49 +212,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -216,49 +212,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
handleGeminiFailoverExhausted(c, lastFailoverStatus) handleGeminiFailoverExhausted(c, lastFailoverStatus)
return return
} }
account := selection.Account
// 4) account concurrency slot // 4) account concurrency slot
accountReleaseFunc := selection.ReleaseFunc accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, stream, &streamStarted)
var accountWaitRelease func()
if !selection.Acquired {
if selection.WaitPlan == nil {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts")
return
}
canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil { if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = geminiConcurrency.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
stream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
googleError(c, http.StatusTooManyRequests, err.Error()) googleError(c, http.StatusTooManyRequests, err.Error())
return return
} }
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 5) forward (根据平台分流) // 5) forward (根据平台分流)
var result *service.ForwardResult var result *service.ForwardResult
...@@ -270,9 +230,6 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ...@@ -270,9 +230,6 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
} }
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil { if err != nil {
var failoverErr *service.UpstreamFailoverError var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) { if errors.As(err, &failoverErr) {
......
...@@ -146,7 +146,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -146,7 +146,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
for { for {
// Select account supporting the requested model // Select account supporting the requested model
log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel) log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel)
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) account, err := h.gatewayService.SelectAccountForModelWithExclusions(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
if err != nil { if err != nil {
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err) log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
if len(failedAccountIDs) == 0 { if len(failedAccountIDs) == 0 {
...@@ -156,60 +156,21 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -156,60 +156,21 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
return return
} }
account := selection.Account
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name) log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
// 3. Acquire account concurrency slot // 3. Acquire account concurrency slot
accountReleaseFunc := selection.ReleaseFunc accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
var accountWaitRelease func()
if !selection.Acquired {
if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil { if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
log.Printf("Account concurrency acquire failed: %v", err) log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted) h.handleConcurrencyError(c, err, "account", streamStarted)
return return
} }
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionHash, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// Forward request // Forward request
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body) result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
} }
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil { if err != nil {
var failoverErr *service.UpstreamFailoverError var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) { if errors.As(err, &failoverErr) {
......
...@@ -54,9 +54,6 @@ type CustomToolSpec struct { ...@@ -54,9 +54,6 @@ type CustomToolSpec struct {
InputSchema map[string]any `json:"input_schema"` InputSchema map[string]any `json:"input_schema"`
} }
// ClaudeCustomToolSpec 兼容旧命名(MCP custom 工具规格)
type ClaudeCustomToolSpec = CustomToolSpec
// SystemBlock system prompt 数组形式的元素 // SystemBlock system prompt 数组形式的元素
type SystemBlock struct { type SystemBlock struct {
Type string `json:"type"` Type string `json:"type"`
......
...@@ -14,16 +14,13 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st ...@@ -14,16 +14,13 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
// 用于存储 tool_use id -> name 映射 // 用于存储 tool_use id -> name 映射
toolIDToName := make(map[string]string) toolIDToName := make(map[string]string)
// 检测是否启用 thinking
isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
// 只有 Gemini 模型支持 dummy thought workaround // 只有 Gemini 模型支持 dummy thought workaround
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures // Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
allowDummyThought := strings.HasPrefix(mappedModel, "gemini-") allowDummyThought := strings.HasPrefix(mappedModel, "gemini-")
// 检测是否启用 thinking
requestedThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
// 为避免 Claude 模型的 thought signature/消息块约束导致 400(上游要求 thinking 块开头等),
// 非 Gemini 模型默认不启用 thinking(除非未来支持完整签名链路)。
isThinkingEnabled := requestedThinkingEnabled && allowDummyThought
// 1. 构建 contents // 1. 构建 contents
contents, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought) contents, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
if err != nil { if err != nil {
...@@ -34,15 +31,7 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st ...@@ -34,15 +31,7 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model) systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model)
// 3. 构建 generationConfig // 3. 构建 generationConfig
reqForGen := claudeReq generationConfig := buildGenerationConfig(claudeReq)
if requestedThinkingEnabled && !allowDummyThought {
log.Printf("[Warning] Disabling thinking for non-Gemini model in antigravity transform: model=%s", mappedModel)
// shallow copy to avoid mutating caller's request
clone := *claudeReq
clone.Thinking = nil
reqForGen = &clone
}
generationConfig := buildGenerationConfig(reqForGen)
// 4. 构建 tools // 4. 构建 tools
tools := buildTools(claudeReq.Tools) tools := buildTools(claudeReq.Tools)
...@@ -161,7 +150,6 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT ...@@ -161,7 +150,6 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
parts = append([]GeminiPart{{ parts = append([]GeminiPart{{
Text: "Thinking...", Text: "Thinking...",
Thought: true, Thought: true,
ThoughtSignature: dummyThoughtSignature,
}}, parts...) }}, parts...)
} }
} }
...@@ -183,34 +171,6 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT ...@@ -183,34 +171,6 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
// 参考: https://ai.google.dev/gemini-api/docs/thought-signatures // 参考: https://ai.google.dev/gemini-api/docs/thought-signatures
const dummyThoughtSignature = "skip_thought_signature_validator" const dummyThoughtSignature = "skip_thought_signature_validator"
// isValidThoughtSignature 验证 thought signature 是否有效
// Claude API 要求 signature 必须是 base64 编码的字符串,长度至少 32 字节
func isValidThoughtSignature(signature string) bool {
// 空字符串无效
if signature == "" {
return false
}
// signature 应该是 base64 编码,长度至少 40 个字符(约 30 字节)
// 参考 Claude API 文档和实际观察到的有效 signature
if len(signature) < 40 {
log.Printf("[Debug] Signature too short: len=%d", len(signature))
return false
}
// 检查是否是有效的 base64 字符
// base64 字符集: A-Z, a-z, 0-9, +, /, =
for i, c := range signature {
if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') &&
(c < '0' || c > '9') && c != '+' && c != '/' && c != '=' {
log.Printf("[Debug] Invalid base64 character at position %d: %c (code=%d)", i, c, c)
return false
}
}
return true
}
// buildParts 构建消息的 parts // buildParts 构建消息的 parts
// allowDummyThought: 只有 Gemini 模型支持 dummy thought signature // allowDummyThought: 只有 Gemini 模型支持 dummy thought signature
func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, error) { func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, error) {
...@@ -239,30 +199,22 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu ...@@ -239,30 +199,22 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
} }
case "thinking": case "thinking":
if allowDummyThought { part := GeminiPart{
// Gemini 模型可以使用 dummy signature
parts = append(parts, GeminiPart{
Text: block.Thinking, Text: block.Thinking,
Thought: true, Thought: true,
ThoughtSignature: dummyThoughtSignature,
})
continue
} }
// 保留原有 signature(Claude 模型需要有效的 signature)
// Claude 模型:仅在提供有效 signature 时保留 thinking block;否则跳过以避免上游校验失败。 if block.Signature != "" {
signature := strings.TrimSpace(block.Signature) part.ThoughtSignature = block.Signature
if signature == "" || signature == dummyThoughtSignature { } else if !allowDummyThought {
log.Printf("[Warning] Skipping thinking block for Claude model (missing or dummy signature)") // Claude 模型需要有效 signature,跳过无 signature 的 thinking block
log.Printf("Warning: skipping thinking block without signature for Claude model")
continue continue
} else {
// Gemini 模型使用 dummy signature
part.ThoughtSignature = dummyThoughtSignature
} }
if !isValidThoughtSignature(signature) { parts = append(parts, part)
log.Printf("[Debug] Thinking signature may be invalid (passing through anyway): len=%d", len(signature))
}
parts = append(parts, GeminiPart{
Text: block.Thinking,
Thought: true,
ThoughtSignature: signature,
})
case "image": case "image":
if block.Source != nil && block.Source.Type == "base64" { if block.Source != nil && block.Source.Type == "base64" {
...@@ -287,9 +239,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu ...@@ -287,9 +239,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
ID: block.ID, ID: block.ID,
}, },
} }
// 只有 Gemini 模型使用 dummy signature // 保留原有 signature,或对 Gemini 模型使用 dummy signature
// Claude 模型不设置 signature(避免验证问题) if block.Signature != "" {
if allowDummyThought { part.ThoughtSignature = block.Signature
} else if allowDummyThought {
part.ThoughtSignature = dummyThoughtSignature part.ThoughtSignature = dummyThoughtSignature
} }
parts = append(parts, part) parts = append(parts, part)
...@@ -433,9 +386,9 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { ...@@ -433,9 +386,9 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
// 普通工具 // 普通工具
var funcDecls []GeminiFunctionDecl var funcDecls []GeminiFunctionDecl
for i, tool := range tools { for _, tool := range tools {
// 跳过无效工具名称 // 跳过无效工具名称
if strings.TrimSpace(tool.Name) == "" { if tool.Name == "" {
log.Printf("Warning: skipping tool with empty name") log.Printf("Warning: skipping tool with empty name")
continue continue
} }
...@@ -444,18 +397,10 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { ...@@ -444,18 +397,10 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
var inputSchema map[string]any var inputSchema map[string]any
// 检查是否为 custom 类型工具 (MCP) // 检查是否为 custom 类型工具 (MCP)
if tool.Type == "custom" { if tool.Type == "custom" && tool.Custom != nil {
if tool.Custom == nil || tool.Custom.InputSchema == nil { // Custom 格式: 从 custom 字段获取 description 和 input_schema
log.Printf("[Warning] Skipping invalid custom tool '%s': missing custom spec or input_schema", tool.Name)
continue
}
description = tool.Custom.Description description = tool.Custom.Description
inputSchema = tool.Custom.InputSchema inputSchema = tool.Custom.InputSchema
// 调试日志:记录 custom 工具的 schema
if schemaJSON, err := json.Marshal(inputSchema); err == nil {
log.Printf("[Debug] Tool[%d] '%s' (custom) original schema: %s", i, tool.Name, string(schemaJSON))
}
} else { } else {
// 标准格式: 从顶层字段获取 // 标准格式: 从顶层字段获取
description = tool.Description description = tool.Description
...@@ -464,6 +409,7 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { ...@@ -464,6 +409,7 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
// 清理 JSON Schema // 清理 JSON Schema
params := cleanJSONSchema(inputSchema) params := cleanJSONSchema(inputSchema)
// 为 nil schema 提供默认值 // 为 nil schema 提供默认值
if params == nil { if params == nil {
params = map[string]any{ params = map[string]any{
...@@ -472,11 +418,6 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { ...@@ -472,11 +418,6 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
} }
} }
// 调试日志:记录清理后的 schema
if paramsJSON, err := json.Marshal(params); err == nil {
log.Printf("[Debug] Tool[%d] '%s' cleaned schema: %s", i, tool.Name, string(paramsJSON))
}
funcDecls = append(funcDecls, GeminiFunctionDecl{ funcDecls = append(funcDecls, GeminiFunctionDecl{
Name: tool.Name, Name: tool.Name,
Description: description, Description: description,
...@@ -538,54 +479,24 @@ func cleanJSONSchema(schema map[string]any) map[string]any { ...@@ -538,54 +479,24 @@ func cleanJSONSchema(schema map[string]any) map[string]any {
} }
// excludedSchemaKeys 不支持的 schema 字段 // excludedSchemaKeys 不支持的 schema 字段
// 基于 Claude API (Vertex AI) 的实际支持情况
// 支持: type, description, enum, properties, required, additionalProperties, items
// 不支持: minItems, maxItems, minLength, maxLength, pattern, minimum, maximum 等验证字段
var excludedSchemaKeys = map[string]bool{ var excludedSchemaKeys = map[string]bool{
// 元 schema 字段
"$schema": true, "$schema": true,
"$id": true, "$id": true,
"$ref": true, "$ref": true,
"additionalProperties": true,
// 字符串验证(Gemini 不支持)
"minLength": true, "minLength": true,
"maxLength": true, "maxLength": true,
"pattern": true, "minItems": true,
"maxItems": true,
// 数字验证(Claude API 通过 Vertex AI 不支持这些字段) "uniqueItems": true,
"minimum": true, "minimum": true,
"maximum": true, "maximum": true,
"exclusiveMinimum": true, "exclusiveMinimum": true,
"exclusiveMaximum": true, "exclusiveMaximum": true,
"multipleOf": true, "pattern": true,
"format": true,
// 数组验证(Claude API 通过 Vertex AI 不支持这些字段)
"uniqueItems": true,
"minItems": true,
"maxItems": true,
// 组合 schema(Gemini 不支持)
"oneOf": true,
"anyOf": true,
"allOf": true,
"not": true,
"if": true,
"then": true,
"else": true,
"$defs": true,
"definitions": true,
// 对象验证(仅保留 properties/required/additionalProperties)
"minProperties": true,
"maxProperties": true,
"patternProperties": true,
"propertyNames": true,
"dependencies": true,
"dependentSchemas": true,
"dependentRequired": true,
// 其他不支持的字段
"default": true, "default": true,
"strict": true,
"const": true, "const": true,
"examples": true, "examples": true,
"deprecated": true, "deprecated": true,
...@@ -593,9 +504,6 @@ var excludedSchemaKeys = map[string]bool{ ...@@ -593,9 +504,6 @@ var excludedSchemaKeys = map[string]bool{
"writeOnly": true, "writeOnly": true,
"contentMediaType": true, "contentMediaType": true,
"contentEncoding": true, "contentEncoding": true,
// Claude 特有字段
"strict": true,
} }
// cleanSchemaValue 递归清理 schema 值 // cleanSchemaValue 递归清理 schema 值
...@@ -615,31 +523,6 @@ func cleanSchemaValue(value any) any { ...@@ -615,31 +523,6 @@ func cleanSchemaValue(value any) any {
continue continue
} }
// 特殊处理 format 字段:只保留 Gemini 支持的 format 值
if k == "format" {
if formatStr, ok := val.(string); ok {
// Gemini 只支持 date-time, date, time
if formatStr == "date-time" || formatStr == "date" || formatStr == "time" {
result[k] = val
}
// 其他 format 值直接跳过
}
continue
}
// 特殊处理 additionalProperties:Claude API 只支持布尔值,不支持 schema 对象
if k == "additionalProperties" {
if boolVal, ok := val.(bool); ok {
result[k] = boolVal
log.Printf("[Debug] additionalProperties is bool: %v", boolVal)
} else {
// 如果是 schema 对象,转换为 false(更安全的默认值)
result[k] = false
log.Printf("[Debug] additionalProperties is not bool (type: %T), converting to false", val)
}
continue
}
// 递归清理所有值 // 递归清理所有值
result[k] = cleanSchemaValue(val) result[k] = cleanSchemaValue(val)
} }
......
package antigravity
import (
"encoding/json"
"testing"
)
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
tests := []struct {
name string
content string
allowDummyThought bool
expectedParts int
description string
}{
{
name: "Claude model - skip thinking block without signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"}
]`,
allowDummyThought: false,
expectedParts: 2, // 只有两个text block
description: "Claude模型应该跳过无signature的thinking block",
},
{
name: "Claude model - keep thinking block with signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": "valid_sig"},
{"type": "text", "text": "World"}
]`,
allowDummyThought: false,
expectedParts: 3, // 三个block都保留
description: "Claude模型应该保留有signature的thinking block",
},
{
name: "Gemini model - use dummy signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"}
]`,
allowDummyThought: true,
expectedParts: 3, // 三个block都保留,thinking使用dummy signature
description: "Gemini模型应该为无signature的thinking block使用dummy signature",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
toolIDToName := make(map[string]string)
parts, err := buildParts(json.RawMessage(tt.content), toolIDToName, tt.allowDummyThought)
if err != nil {
t.Fatalf("buildParts() error = %v", err)
}
if len(parts) != tt.expectedParts {
t.Errorf("%s: got %d parts, want %d parts", tt.description, len(parts), tt.expectedParts)
}
})
}
}
// TestBuildTools_CustomTypeTools 测试custom类型工具转换
func TestBuildTools_CustomTypeTools(t *testing.T) {
tests := []struct {
name string
tools []ClaudeTool
expectedLen int
description string
}{
{
name: "Standard tool format",
tools: []ClaudeTool{
{
Name: "get_weather",
Description: "Get weather information",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"location": map[string]any{"type": "string"},
},
},
},
},
expectedLen: 1,
description: "标准工具格式应该正常转换",
},
{
name: "Custom type tool (MCP format)",
tools: []ClaudeTool{
{
Type: "custom",
Name: "mcp_tool",
Custom: &ClaudeCustomToolSpec{
Description: "MCP tool description",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"param": map[string]any{"type": "string"},
},
},
},
},
},
expectedLen: 1,
description: "Custom类型工具应该从Custom字段读取description和input_schema",
},
{
name: "Mixed standard and custom tools",
tools: []ClaudeTool{
{
Name: "standard_tool",
Description: "Standard tool",
InputSchema: map[string]any{"type": "object"},
},
{
Type: "custom",
Name: "custom_tool",
Custom: &ClaudeCustomToolSpec{
Description: "Custom tool",
InputSchema: map[string]any{"type": "object"},
},
},
},
expectedLen: 1, // 返回一个GeminiToolDeclaration,包含2个function declarations
description: "混合标准和custom工具应该都能正确转换",
},
{
name: "Invalid custom tool - nil Custom field",
tools: []ClaudeTool{
{
Type: "custom",
Name: "invalid_custom",
// Custom 为 nil
},
},
expectedLen: 0, // 应该被跳过
description: "Custom字段为nil的custom工具应该被跳过",
},
{
name: "Invalid custom tool - nil InputSchema",
tools: []ClaudeTool{
{
Type: "custom",
Name: "invalid_custom",
Custom: &ClaudeCustomToolSpec{
Description: "Invalid",
// InputSchema 为 nil
},
},
},
expectedLen: 0, // 应该被跳过
description: "InputSchema为nil的custom工具应该被跳过",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := buildTools(tt.tools)
if len(result) != tt.expectedLen {
t.Errorf("%s: got %d tool declarations, want %d", tt.description, len(result), tt.expectedLen)
}
// 验证function declarations存在
if len(result) > 0 && result[0].FunctionDeclarations != nil {
if len(result[0].FunctionDeclarations) != len(tt.tools) {
t.Errorf("%s: got %d function declarations, want %d",
tt.description, len(result[0].FunctionDeclarations), len(tt.tools))
}
}
})
}
}
...@@ -16,12 +16,6 @@ const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleav ...@@ -16,12 +16,6 @@ const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleav
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta) // HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
// ApiKeyBetaHeader API-key 账号建议使用的 anthropic-beta header(不包含 oauth)
const ApiKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// ApiKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不包含 oauth / claude-code)
const ApiKeyHaikuBetaHeader = BetaInterleavedThinking
// Claude Code 客户端默认请求头 // Claude Code 客户端默认请求头
var DefaultHeaders = map[string]string{ var DefaultHeaders = map[string]string{
"User-Agent": "claude-cli/2.0.62 (external, cli)", "User-Agent": "claude-cli/2.0.62 (external, cli)",
......
...@@ -2,9 +2,7 @@ package repository ...@@ -2,9 +2,7 @@ package repository
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"strconv"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
...@@ -29,8 +27,6 @@ const ( ...@@ -29,8 +27,6 @@ const (
userSlotKeyPrefix = "concurrency:user:" userSlotKeyPrefix = "concurrency:user:"
// 等待队列计数器格式: concurrency:wait:{userID} // 等待队列计数器格式: concurrency:wait:{userID}
waitQueueKeyPrefix = "concurrency:wait:" waitQueueKeyPrefix = "concurrency:wait:"
// 账号级等待队列计数器格式: wait:account:{accountID}
accountWaitKeyPrefix = "wait:account:"
// 默认槽位过期时间(分钟),可通过配置覆盖 // 默认槽位过期时间(分钟),可通过配置覆盖
defaultSlotTTLMinutes = 15 defaultSlotTTLMinutes = 15
...@@ -119,29 +115,6 @@ var ( ...@@ -119,29 +115,6 @@ var (
return 1 return 1
`) `)
// incrementAccountWaitScript - account-level wait queue count
incrementAccountWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
end
if current >= tonumber(ARGV[1]) then
return 0
end
local newVal = redis.call('INCR', KEYS[1])
-- Only set TTL on first creation to avoid refreshing zombie data
if newVal == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[2])
end
return 1
`)
// decrementWaitScript - same as before // decrementWaitScript - same as before
decrementWaitScript = redis.NewScript(` decrementWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1]) local current = redis.call('GET', KEYS[1])
...@@ -150,78 +123,22 @@ var ( ...@@ -150,78 +123,22 @@ var (
end end
return 1 return 1
`) `)
// getAccountsLoadBatchScript - batch load query (read-only)
// ARGV[1] = slot TTL (seconds, retained for compatibility)
// ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ...
getAccountsLoadBatchScript = redis.NewScript(`
local result = {}
local i = 2
while i <= #ARGV do
local accountID = ARGV[i]
local maxConcurrency = tonumber(ARGV[i + 1])
local slotKey = 'concurrency:account:' .. accountID
local currentConcurrency = redis.call('ZCARD', slotKey)
local waitKey = 'wait:account:' .. accountID
local waitingCount = redis.call('GET', waitKey)
if waitingCount == false then
waitingCount = 0
else
waitingCount = tonumber(waitingCount)
end
local loadRate = 0
if maxConcurrency > 0 then
loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency)
end
table.insert(result, accountID)
table.insert(result, currentConcurrency)
table.insert(result, waitingCount)
table.insert(result, loadRate)
i = i + 2
end
return result
`)
// cleanupExpiredSlotsScript - remove expired slots
// KEYS[1] = concurrency:account:{accountID}
// ARGV[1] = TTL (seconds)
cleanupExpiredSlotsScript = redis.NewScript(`
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
`)
) )
type concurrencyCache struct { type concurrencyCache struct {
rdb *redis.Client rdb *redis.Client
slotTTLSeconds int // 槽位过期时间(秒) slotTTLSeconds int // 槽位过期时间(秒)
waitQueueTTLSeconds int // 等待队列过期时间(秒)
} }
// NewConcurrencyCache 创建并发控制缓存 // NewConcurrencyCache 创建并发控制缓存
// slotTTLMinutes: 槽位过期时间(分钟),0 或负数使用默认值 15 分钟 // slotTTLMinutes: 槽位过期时间(分钟),0 或负数使用默认值 15 分钟
// waitQueueTTLSeconds: 等待队列过期时间(秒),0 或负数使用 slot TTL func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int) service.ConcurrencyCache {
func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int, waitQueueTTLSeconds int) service.ConcurrencyCache {
if slotTTLMinutes <= 0 { if slotTTLMinutes <= 0 {
slotTTLMinutes = defaultSlotTTLMinutes slotTTLMinutes = defaultSlotTTLMinutes
} }
if waitQueueTTLSeconds <= 0 {
waitQueueTTLSeconds = slotTTLMinutes * 60
}
return &concurrencyCache{ return &concurrencyCache{
rdb: rdb, rdb: rdb,
slotTTLSeconds: slotTTLMinutes * 60, slotTTLSeconds: slotTTLMinutes * 60,
waitQueueTTLSeconds: waitQueueTTLSeconds,
} }
} }
...@@ -238,10 +155,6 @@ func waitQueueKey(userID int64) string { ...@@ -238,10 +155,6 @@ func waitQueueKey(userID int64) string {
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID) return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
} }
func accountWaitKey(accountID int64) string {
return fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
}
// Account slot operations // Account slot operations
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
...@@ -312,75 +225,3 @@ func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) ...@@ -312,75 +225,3 @@ func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result() _, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err return err
} }
// Account wait queue operations
func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
key := accountWaitKey(accountID)
result, err := incrementAccountWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
key := accountWaitKey(accountID)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}
func (c *concurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
key := accountWaitKey(accountID)
val, err := c.rdb.Get(ctx, key).Int()
if err != nil && !errors.Is(err, redis.Nil) {
return 0, err
}
if errors.Is(err, redis.Nil) {
return 0, nil
}
return val, nil
}
func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
if len(accounts) == 0 {
return map[int64]*service.AccountLoadInfo{}, nil
}
args := []any{c.slotTTLSeconds}
for _, acc := range accounts {
args = append(args, acc.ID, acc.MaxConcurrency)
}
result, err := getAccountsLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice()
if err != nil {
return nil, err
}
loadMap := make(map[int64]*service.AccountLoadInfo)
for i := 0; i < len(result); i += 4 {
if i+3 >= len(result) {
break
}
accountID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64)
currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1]))
waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2]))
loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3]))
loadMap[accountID] = &service.AccountLoadInfo{
AccountID: accountID,
CurrentConcurrency: currentConcurrency,
WaitingCount: waitingCount,
LoadRate: loadRate,
}
}
return loadMap, nil
}
func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
key := accountSlotKey(accountID)
_, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result()
return err
}
...@@ -22,7 +22,7 @@ func BenchmarkAccountConcurrency(b *testing.B) { ...@@ -22,7 +22,7 @@ func BenchmarkAccountConcurrency(b *testing.B) {
_ = rdb.Close() _ = rdb.Close()
}() }()
cache, _ := NewConcurrencyCache(rdb, benchSlotTTLMinutes, int(benchSlotTTL.Seconds())).(*concurrencyCache) cache, _ := NewConcurrencyCache(rdb, benchSlotTTLMinutes).(*concurrencyCache)
ctx := context.Background() ctx := context.Background()
for _, size := range []int{10, 100, 1000} { for _, size := range []int{10, 100, 1000} {
......
...@@ -27,7 +27,7 @@ type ConcurrencyCacheSuite struct { ...@@ -27,7 +27,7 @@ type ConcurrencyCacheSuite struct {
func (s *ConcurrencyCacheSuite) SetupTest() { func (s *ConcurrencyCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest() s.IntegrationRedisSuite.SetupTest()
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds())) s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes)
} }
func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() { func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
...@@ -218,48 +218,6 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() { ...@@ -218,48 +218,6 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count") require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
} }
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {
accountID := int64(30)
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
ok, err := s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
require.NoError(s.T(), err, "IncrementAccountWaitCount 1")
require.True(s.T(), ok)
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
require.NoError(s.T(), err, "IncrementAccountWaitCount 2")
require.True(s.T(), ok)
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
require.NoError(s.T(), err, "IncrementAccountWaitCount 3")
require.False(s.T(), ok, "expected account wait increment over max to fail")
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
require.NoError(s.T(), err, "TTL account waitKey")
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.Equal(s.T(), 1, val, "expected account wait count 1")
}
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_DecrementNoNegative() {
accountID := int64(301)
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount on non-existent key")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative account wait count after decrement on empty")
}
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() { func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
// When no slots exist, GetAccountConcurrency should return 0 // When no slots exist, GetAccountConcurrency should return 0
cur, err := s.cache.GetAccountConcurrency(s.ctx, 999) cur, err := s.cache.GetAccountConcurrency(s.ctx, 999)
...@@ -274,139 +232,6 @@ func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() { ...@@ -274,139 +232,6 @@ func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() {
require.Equal(s.T(), 0, cur) require.Equal(s.T(), 0, cur)
} }
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch() {
s.T().Skip("TODO: Fix this test - CurrentConcurrency returns 0 instead of expected value in CI")
// Setup: Create accounts with different load states
account1 := int64(100)
account2 := int64(101)
account3 := int64(102)
// Account 1: 2/3 slots used, 1 waiting
ok, err := s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req1")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req2")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, account1, 5)
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Account 2: 1/2 slots used, 0 waiting
ok, err = s.cache.AcquireAccountSlot(s.ctx, account2, 2, "req3")
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Account 3: 0/1 slots used, 0 waiting (idle)
// Query batch load
accounts := []service.AccountWithConcurrency{
{ID: account1, MaxConcurrency: 3},
{ID: account2, MaxConcurrency: 2},
{ID: account3, MaxConcurrency: 1},
}
loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, accounts)
require.NoError(s.T(), err)
require.Len(s.T(), loadMap, 3)
// Verify account1: (2 + 1) / 3 = 100%
load1 := loadMap[account1]
require.NotNil(s.T(), load1)
require.Equal(s.T(), account1, load1.AccountID)
require.Equal(s.T(), 2, load1.CurrentConcurrency)
require.Equal(s.T(), 1, load1.WaitingCount)
require.Equal(s.T(), 100, load1.LoadRate)
// Verify account2: (1 + 0) / 2 = 50%
load2 := loadMap[account2]
require.NotNil(s.T(), load2)
require.Equal(s.T(), account2, load2.AccountID)
require.Equal(s.T(), 1, load2.CurrentConcurrency)
require.Equal(s.T(), 0, load2.WaitingCount)
require.Equal(s.T(), 50, load2.LoadRate)
// Verify account3: (0 + 0) / 1 = 0%
load3 := loadMap[account3]
require.NotNil(s.T(), load3)
require.Equal(s.T(), account3, load3.AccountID)
require.Equal(s.T(), 0, load3.CurrentConcurrency)
require.Equal(s.T(), 0, load3.WaitingCount)
require.Equal(s.T(), 0, load3.LoadRate)
}
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch_Empty() {
// Test with empty account list
loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, []service.AccountWithConcurrency{})
require.NoError(s.T(), err)
require.Empty(s.T(), loadMap)
}
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots() {
accountID := int64(200)
slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
// Acquire 3 slots
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req3")
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Verify 3 slots exist
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 3, cur)
// Manually set old timestamps for req1 and req2 (simulate expired slots)
now := time.Now().Unix()
expiredTime := now - int64(testSlotTTL.Seconds()) - 10 // 10 seconds past TTL
err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req1"}).Err()
require.NoError(s.T(), err)
err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req2"}).Err()
require.NoError(s.T(), err)
// Run cleanup
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
require.NoError(s.T(), err)
// Verify only 1 slot remains (req3)
cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 1, cur)
// Verify req3 still exists
members, err := s.rdb.ZRange(s.ctx, slotKey, 0, -1).Result()
require.NoError(s.T(), err)
require.Len(s.T(), members, 1)
require.Equal(s.T(), "req3", members[0])
}
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() {
accountID := int64(201)
// Acquire 2 fresh slots
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Run cleanup (should not remove anything)
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
require.NoError(s.T(), err)
// Verify both slots still exist
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 2, cur)
}
func TestConcurrencyCacheSuite(t *testing.T) { func TestConcurrencyCacheSuite(t *testing.T) {
suite.Run(t, new(ConcurrencyCacheSuite)) suite.Run(t, new(ConcurrencyCacheSuite))
} }
...@@ -15,14 +15,7 @@ import ( ...@@ -15,14 +15,7 @@ import (
// ProvideConcurrencyCache 创建并发控制缓存,从配置读取 TTL 参数 // ProvideConcurrencyCache 创建并发控制缓存,从配置读取 TTL 参数
// 性能优化:TTL 可配置,支持长时间运行的 LLM 请求场景 // 性能优化:TTL 可配置,支持长时间运行的 LLM 请求场景
func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.ConcurrencyCache { func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.ConcurrencyCache {
waitTTLSeconds := int(cfg.Gateway.Scheduling.StickySessionWaitTimeout.Seconds()) return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes)
if cfg.Gateway.Scheduling.FallbackWaitTimeout > cfg.Gateway.Scheduling.StickySessionWaitTimeout {
waitTTLSeconds = int(cfg.Gateway.Scheduling.FallbackWaitTimeout.Seconds())
}
if waitTTLSeconds <= 0 {
waitTTLSeconds = cfg.Gateway.ConcurrencySlotTTLMinutes * 60
}
return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes, waitTTLSeconds)
} }
// ProviderSet is the Wire provider set for all repositories // ProviderSet is the Wire provider set for all repositories
......
...@@ -358,15 +358,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -358,15 +358,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return nil, fmt.Errorf("transform request: %w", err) return nil, fmt.Errorf("transform request: %w", err)
} }
// 调试:记录转换后的请求体(仅记录前 2000 字符)
if bodyJSON, err := json.Marshal(geminiBody); err == nil {
truncated := string(bodyJSON)
if len(truncated) > 2000 {
truncated = truncated[:2000] + "..."
}
log.Printf("[Debug] Transformed Gemini request: %s", truncated)
}
// 构建上游 action // 构建上游 action
action := "generateContent" action := "generateContent"
if claudeReq.Stream { if claudeReq.Stream {
......
...@@ -18,11 +18,6 @@ type ConcurrencyCache interface { ...@@ -18,11 +18,6 @@ type ConcurrencyCache interface {
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
// 账号等待队列(账号级)
IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error)
DecrementAccountWaitCount(ctx context.Context, accountID int64) error
GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error)
// 用户槽位管理 // 用户槽位管理
// 键格式: concurrency:user:{userID}(有序集合,成员为 requestID) // 键格式: concurrency:user:{userID}(有序集合,成员为 requestID)
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
...@@ -32,12 +27,6 @@ type ConcurrencyCache interface { ...@@ -32,12 +27,6 @@ type ConcurrencyCache interface {
// 等待队列计数(只在首次创建时设置 TTL) // 等待队列计数(只在首次创建时设置 TTL)
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
DecrementWaitCount(ctx context.Context, userID int64) error DecrementWaitCount(ctx context.Context, userID int64) error
// 批量负载查询(只读)
GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error)
// 清理过期槽位(后台任务)
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
} }
// generateRequestID generates a unique request ID for concurrency slot tracking // generateRequestID generates a unique request ID for concurrency slot tracking
...@@ -72,18 +61,6 @@ type AcquireResult struct { ...@@ -72,18 +61,6 @@ type AcquireResult struct {
ReleaseFunc func() // Must be called when done (typically via defer) ReleaseFunc func() // Must be called when done (typically via defer)
} }
type AccountWithConcurrency struct {
ID int64
MaxConcurrency int
}
type AccountLoadInfo struct {
AccountID int64
CurrentConcurrency int
WaitingCount int
LoadRate int // 0-100+ (percent)
}
// AcquireAccountSlot attempts to acquire a concurrency slot for an account. // AcquireAccountSlot attempts to acquire a concurrency slot for an account.
// If the account is at max concurrency, it waits until a slot is available or timeout. // If the account is at max concurrency, it waits until a slot is available or timeout.
// Returns a release function that MUST be called when the request completes. // Returns a release function that MUST be called when the request completes.
...@@ -200,42 +177,6 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6 ...@@ -200,42 +177,6 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6
} }
} }
// IncrementAccountWaitCount increments the wait queue counter for an account.
func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
if s.cache == nil {
return true, nil
}
result, err := s.cache.IncrementAccountWaitCount(ctx, accountID, maxWait)
if err != nil {
log.Printf("Warning: increment wait count failed for account %d: %v", accountID, err)
return true, nil
}
return result, nil
}
// DecrementAccountWaitCount decrements the wait queue counter for an account.
func (s *ConcurrencyService) DecrementAccountWaitCount(ctx context.Context, accountID int64) {
if s.cache == nil {
return
}
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.cache.DecrementAccountWaitCount(bgCtx, accountID); err != nil {
log.Printf("Warning: decrement wait count failed for account %d: %v", accountID, err)
}
}
// GetAccountWaitingCount gets current wait queue count for an account.
func (s *ConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
if s.cache == nil {
return 0, nil
}
return s.cache.GetAccountWaitingCount(ctx, accountID)
}
// CalculateMaxWait calculates the maximum wait queue size for a user // CalculateMaxWait calculates the maximum wait queue size for a user
// maxWait = userConcurrency + defaultExtraWaitSlots // maxWait = userConcurrency + defaultExtraWaitSlots
func CalculateMaxWait(userConcurrency int) int { func CalculateMaxWait(userConcurrency int) int {
...@@ -245,57 +186,6 @@ func CalculateMaxWait(userConcurrency int) int { ...@@ -245,57 +186,6 @@ func CalculateMaxWait(userConcurrency int) int {
return userConcurrency + defaultExtraWaitSlots return userConcurrency + defaultExtraWaitSlots
} }
// GetAccountsLoadBatch returns load info for multiple accounts.
func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
if s.cache == nil {
return map[int64]*AccountLoadInfo{}, nil
}
return s.cache.GetAccountsLoadBatch(ctx, accounts)
}
// CleanupExpiredAccountSlots removes expired slots for one account (background task).
func (s *ConcurrencyService) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
if s.cache == nil {
return nil
}
return s.cache.CleanupExpiredAccountSlots(ctx, accountID)
}
// StartSlotCleanupWorker starts a background cleanup worker for expired account slots.
func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepository, interval time.Duration) {
if s == nil || s.cache == nil || accountRepo == nil || interval <= 0 {
return
}
runCleanup := func() {
listCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
accounts, err := accountRepo.ListSchedulable(listCtx)
cancel()
if err != nil {
log.Printf("Warning: list schedulable accounts failed: %v", err)
return
}
for _, account := range accounts {
accountCtx, accountCancel := context.WithTimeout(context.Background(), 2*time.Second)
err := s.cache.CleanupExpiredAccountSlots(accountCtx, account.ID)
accountCancel()
if err != nil {
log.Printf("Warning: cleanup expired slots failed for account %d: %v", account.ID, err)
}
}
}
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
runCleanup()
for range ticker.C {
runCleanup()
}
}()
}
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts // GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
// Returns a map of accountID -> current concurrency count // Returns a map of accountID -> current concurrency count
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
......
...@@ -261,34 +261,6 @@ func TestGatewayService_SelectAccountForModelWithPlatform_PriorityAndLastUsed(t ...@@ -261,34 +261,6 @@ func TestGatewayService_SelectAccountForModelWithPlatform_PriorityAndLastUsed(t
require.Equal(t, int64(2), acc.ID, "同优先级应选择最久未用的账户") require.Equal(t, int64(2), acc.ID, "同优先级应选择最久未用的账户")
} }
func TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "同优先级且未使用时应优先选择OAuth账户")
}
// TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts 测试无可用账户 // TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts 测试无可用账户
func TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts(t *testing.T) { func TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts(t *testing.T) {
ctx := context.Background() ctx := context.Background()
...@@ -604,32 +576,6 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) { ...@@ -604,32 +576,6 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
ctx := context.Background() ctx := context.Background()
t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "同优先级且未使用时应优先选择OAuth账户")
})
t.Run("混合调度-包含启用mixed_scheduling的antigravity账户", func(t *testing.T) { t.Run("混合调度-包含启用mixed_scheduling的antigravity账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{ repo := &mockAccountRepoForPlatform{
accounts: []Account{ accounts: []Account{
...@@ -837,160 +783,3 @@ func TestAccount_IsMixedSchedulingEnabled(t *testing.T) { ...@@ -837,160 +783,3 @@ func TestAccount_IsMixedSchedulingEnabled(t *testing.T) {
}) })
} }
} }
// mockConcurrencyService for testing
type mockConcurrencyService struct {
accountLoads map[int64]*AccountLoadInfo
accountWaitCounts map[int64]int
acquireResults map[int64]bool
}
func (m *mockConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
if m.accountLoads == nil {
return map[int64]*AccountLoadInfo{}, nil
}
result := make(map[int64]*AccountLoadInfo)
for _, acc := range accounts {
if load, ok := m.accountLoads[acc.ID]; ok {
result[acc.ID] = load
} else {
result[acc.ID] = &AccountLoadInfo{
AccountID: acc.ID,
CurrentConcurrency: 0,
WaitingCount: 0,
LoadRate: 0,
}
}
}
return result, nil
}
func (m *mockConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
if m.accountWaitCounts == nil {
return 0, nil
}
return m.accountWaitCounts[accountID], nil
}
// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection
func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
ctx := context.Background()
t.Run("禁用负载批量查询-降级到传统选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil, // No concurrency service
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(1), result.Account.ID, "应选择优先级最高的账号")
})
t.Run("无ConcurrencyService-降级到传统选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "应选择优先级最高的账号")
})
t.Run("排除账号-不选择被排除的账号", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
excludedIDs := map[int64]struct{}{1: {}}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "不应选择被排除的账号")
})
t.Run("无可用账号-返回错误", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{},
accountsByID: map[int64]*Account{},
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
require.Error(t, err)
require.Nil(t, result)
require.Contains(t, err.Error(), "no available accounts")
})
}
...@@ -13,14 +13,12 @@ import ( ...@@ -13,14 +13,12 @@ import (
"log" "log"
"net/http" "net/http"
"regexp" "regexp"
"sort"
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -68,20 +66,6 @@ type GatewayCache interface { ...@@ -68,20 +66,6 @@ type GatewayCache interface {
RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error
} }
type AccountWaitPlan struct {
AccountID int64
MaxConcurrency int
Timeout time.Duration
MaxWaiting int
}
type AccountSelectionResult struct {
Account *Account
Acquired bool
ReleaseFunc func()
WaitPlan *AccountWaitPlan // nil means no wait allowed
}
// ClaudeUsage 表示Claude API返回的usage信息 // ClaudeUsage 表示Claude API返回的usage信息
type ClaudeUsage struct { type ClaudeUsage struct {
InputTokens int `json:"input_tokens"` InputTokens int `json:"input_tokens"`
...@@ -124,7 +108,6 @@ type GatewayService struct { ...@@ -124,7 +108,6 @@ type GatewayService struct {
identityService *IdentityService identityService *IdentityService
httpUpstream HTTPUpstream httpUpstream HTTPUpstream
deferredService *DeferredService deferredService *DeferredService
concurrencyService *ConcurrencyService
} }
// NewGatewayService creates a new GatewayService // NewGatewayService creates a new GatewayService
...@@ -136,7 +119,6 @@ func NewGatewayService( ...@@ -136,7 +119,6 @@ func NewGatewayService(
userSubRepo UserSubscriptionRepository, userSubRepo UserSubscriptionRepository,
cache GatewayCache, cache GatewayCache,
cfg *config.Config, cfg *config.Config,
concurrencyService *ConcurrencyService,
billingService *BillingService, billingService *BillingService,
rateLimitService *RateLimitService, rateLimitService *RateLimitService,
billingCacheService *BillingCacheService, billingCacheService *BillingCacheService,
...@@ -152,7 +134,6 @@ func NewGatewayService( ...@@ -152,7 +134,6 @@ func NewGatewayService(
userSubRepo: userSubRepo, userSubRepo: userSubRepo,
cache: cache, cache: cache,
cfg: cfg, cfg: cfg,
concurrencyService: concurrencyService,
billingService: billingService, billingService: billingService,
rateLimitService: rateLimitService, rateLimitService: rateLimitService,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
...@@ -202,14 +183,6 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { ...@@ -202,14 +183,6 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
return "" return ""
} }
// BindStickySession sets session -> account binding with standard TTL.
func (s *GatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error {
if sessionHash == "" || accountID <= 0 {
return nil
}
return s.cache.SetSessionAccountID(ctx, sessionHash, accountID, stickySessionTTL)
}
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string { func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
if parsed == nil { if parsed == nil {
return "" return ""
...@@ -359,354 +332,8 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context ...@@ -359,354 +332,8 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
} }
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
cfg := s.schedulingConfig()
var stickyAccountID int64
if sessionHash != "" && s.cache != nil {
if accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash); err == nil {
stickyAccountID = accountID
}
}
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
if err != nil {
return nil, err
}
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err == nil && result.Acquired {
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
}
platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID)
if err != nil {
return nil, err
}
preferOAuth := platform == PlatformGemini
accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err != nil {
return nil, err
}
if len(accounts) == 0 {
return nil, errors.New("no available accounts")
}
isExcluded := func(accountID int64) bool {
if excludedIDs == nil {
return false
}
_, excluded := excludedIDs[accountID]
return excluded
}
// ============ Layer 1: 粘性会话优先 ============
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err == nil && s.isAccountAllowedForPlatform(account, platform, useMixed) &&
account.IsSchedulable() &&
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
_ = s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
}
}
// ============ Layer 2: 负载感知选择 ============
candidates := make([]*Account, 0, len(accounts))
for i := range accounts {
acc := &accounts[i]
if isExcluded(acc.ID) {
continue
}
if !s.isAccountAllowedForPlatform(acc, platform, useMixed) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
candidates = append(candidates, acc)
}
if len(candidates) == 0 {
return nil, errors.New("no available accounts")
}
accountLoads := make([]AccountWithConcurrency, 0, len(candidates))
for _, acc := range candidates {
accountLoads = append(accountLoads, AccountWithConcurrency{
ID: acc.ID,
MaxConcurrency: acc.Concurrency,
})
}
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil {
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, sessionHash, preferOAuth); ok {
return result, nil
}
} else {
type accountWithLoad struct {
account *Account
loadInfo *AccountLoadInfo
}
var available []accountWithLoad
for _, acc := range candidates {
loadInfo := loadMap[acc.ID]
if loadInfo == nil {
loadInfo = &AccountLoadInfo{AccountID: acc.ID}
}
if loadInfo.LoadRate < 100 {
available = append(available, accountWithLoad{
account: acc,
loadInfo: loadInfo,
})
}
}
if len(available) > 0 {
sort.SliceStable(available, func(i, j int) bool {
a, b := available[i], available[j]
if a.account.Priority != b.account.Priority {
return a.account.Priority < b.account.Priority
}
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
}
switch {
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
return true
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
return false
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
if preferOAuth && a.account.Type != b.account.Type {
return a.account.Type == AccountTypeOAuth
}
return false
default:
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
for _, item := range available {
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: item.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
}
}
}
// ============ Layer 3: 兜底排队 ============
sortAccountsByPriorityAndLastUsed(candidates, preferOAuth)
for _, acc := range candidates {
return &AccountSelectionResult{
Account: acc,
WaitPlan: &AccountWaitPlan{
AccountID: acc.ID,
MaxConcurrency: acc.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
}
return nil, errors.New("no available accounts")
}
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
ordered := append([]*Account(nil), candidates...)
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
for _, acc := range ordered {
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: acc,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, true
}
}
return nil, false
}
func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
if s.cfg != nil {
return s.cfg.Gateway.Scheduling
}
return config.GatewaySchedulingConfig{
StickySessionMaxWaiting: 3,
StickySessionWaitTimeout: 45 * time.Second,
FallbackWaitTimeout: 30 * time.Second,
FallbackMaxWaiting: 100,
LoadBatchEnabled: true,
SlotCleanupInterval: 30 * time.Second,
}
}
func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) {
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" {
return forcePlatform, true, nil
}
if groupID != nil {
group, err := s.groupRepo.GetByID(ctx, *groupID)
if err != nil {
return "", false, fmt.Errorf("get group failed: %w", err)
}
return group.Platform, false, nil
}
return PlatformAnthropic, false, nil
}
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform
if useMixed {
platforms := []string{platform, PlatformAntigravity}
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
}
if err != nil {
return nil, useMixed, err
}
filtered := make([]Account, 0, len(accounts))
for _, acc := range accounts {
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
}
filtered = append(filtered, acc)
}
return filtered, useMixed, nil
}
var accounts []Account
var err error
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
} else if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
if err == nil && len(accounts) == 0 && hasForcePlatform {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
}
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
}
if err != nil {
return nil, useMixed, err
}
return accounts, useMixed, nil
}
func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool {
if account == nil {
return false
}
if useMixed {
if account.Platform == platform {
return true
}
return account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()
}
return account.Platform == platform
}
func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
if s.concurrencyService == nil {
return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil
}
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
}
func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
sort.SliceStable(accounts, func(i, j int) bool {
a, b := accounts[i], accounts[j]
if a.Priority != b.Priority {
return a.Priority < b.Priority
}
switch {
case a.LastUsedAt == nil && b.LastUsedAt != nil:
return true
case a.LastUsedAt != nil && b.LastUsedAt == nil:
return false
case a.LastUsedAt == nil && b.LastUsedAt == nil:
if preferOAuth && a.Type != b.Type {
return a.Type == AccountTypeOAuth
}
return false
default:
return a.LastUsedAt.Before(*b.LastUsedAt)
}
})
}
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离) // selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) { func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
preferOAuth := platform == PlatformGemini
// 1. 查询粘性会话 // 1. 查询粘性会话
if sessionHash != "" { if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
...@@ -762,9 +389,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ...@@ -762,9 +389,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
case acc.LastUsedAt != nil && selected.LastUsedAt == nil: case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred) // keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil: case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
if preferOAuth && acc.Type != selected.Type && acc.Type == AccountTypeOAuth { // keep selected (both never used)
selected = acc
}
default: default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) { if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc selected = acc
...@@ -794,7 +419,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ...@@ -794,7 +419,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
// 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户 // 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户
func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) { func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) {
platforms := []string{nativePlatform, PlatformAntigravity} platforms := []string{nativePlatform, PlatformAntigravity}
preferOAuth := nativePlatform == PlatformGemini
// 1. 查询粘性会话 // 1. 查询粘性会话
if sessionHash != "" { if sessionHash != "" {
...@@ -854,9 +478,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ...@@ -854,9 +478,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
case acc.LastUsedAt != nil && selected.LastUsedAt == nil: case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred) // keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil: case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
if preferOAuth && acc.Platform == PlatformGemini && selected.Platform == PlatformGemini && acc.Type != selected.Type && acc.Type == AccountTypeOAuth { // keep selected (both never used)
selected = acc
}
default: default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) { if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc selected = acc
...@@ -1062,30 +684,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -1062,30 +684,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 处理错误响应(不可重试的错误) // 处理错误响应(不可重试的错误)
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
// 可选:对部分 400 触发 failover(默认关闭以保持语义)
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
respBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
// ReadAll failed, fall back to normal error handling without consuming the stream
return s.handleErrorResponse(ctx, resp, c, account)
}
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
if s.shouldFailoverOn400(respBody) {
if s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"Account %d: 400 error, attempting failover: %s",
account.ID,
truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
} else {
log.Printf("Account %d: 400 error, attempting failover", account.ID)
}
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
}
return s.handleErrorResponse(ctx, resp, c, account) return s.handleErrorResponse(ctx, resp, c, account)
} }
...@@ -1188,13 +786,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex ...@@ -1188,13 +786,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 处理anthropic-beta header(OAuth账号需要特殊处理) // 处理anthropic-beta header(OAuth账号需要特殊处理)
if tokenType == "oauth" { if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
} }
return req, nil return req, nil
...@@ -1247,83 +838,6 @@ func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) ...@@ -1247,83 +838,6 @@ func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string)
return claude.DefaultBetaHeader return claude.DefaultBetaHeader
} }
func requestNeedsBetaFeatures(body []byte) bool {
tools := gjson.GetBytes(body, "tools")
if tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 {
return true
}
if strings.EqualFold(gjson.GetBytes(body, "thinking.type").String(), "enabled") {
return true
}
return false
}
func defaultApiKeyBetaHeader(body []byte) string {
modelID := gjson.GetBytes(body, "model").String()
if strings.Contains(strings.ToLower(modelID), "haiku") {
return claude.ApiKeyHaikuBetaHeader
}
return claude.ApiKeyBetaHeader
}
func truncateForLog(b []byte, maxBytes int) string {
if maxBytes <= 0 {
maxBytes = 2048
}
if len(b) > maxBytes {
b = b[:maxBytes]
}
s := string(b)
// 保持一行,避免污染日志格式
s = strings.ReplaceAll(s, "\n", "\\n")
s = strings.ReplaceAll(s, "\r", "\\r")
return s
}
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
// 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。
// 默认保守:无法识别则不切换。
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
if msg == "" {
return false
}
// 缺少/错误的 beta header:换账号/链路可能成功(尤其是混合调度时)。
// 更精确匹配 beta 相关的兼容性问题,避免误触发切换。
if strings.Contains(msg, "anthropic-beta") ||
strings.Contains(msg, "beta feature") ||
strings.Contains(msg, "requires beta") {
return true
}
// thinking/tool streaming 等兼容性约束(常见于中间转换链路)
if strings.Contains(msg, "thinking") || strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") {
return true
}
if strings.Contains(msg, "tool_use") || strings.Contains(msg, "tool_result") || strings.Contains(msg, "tools") {
return true
}
return false
}
func extractUpstreamErrorMessage(body []byte) string {
// Claude 风格:{"type":"error","error":{"type":"...","message":"..."}}
if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" {
inner := strings.TrimSpace(m)
// 有些上游会把完整 JSON 作为字符串塞进 message
if strings.HasPrefix(inner, "{") {
if innerMsg := gjson.Get(inner, "error.message").String(); strings.TrimSpace(innerMsg) != "" {
return innerMsg
}
}
return m
}
// 兜底:尝试顶层 message
return gjson.GetBytes(body, "message").String()
}
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
...@@ -1336,16 +850,6 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res ...@@ -1336,16 +850,6 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
switch resp.StatusCode { switch resp.StatusCode {
case 400: case 400:
// 仅记录上游错误摘要(避免输出请求内容);需要时可通过配置打开
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"Upstream 400 error (account=%d platform=%s type=%s): %s",
account.ID,
account.Platform,
account.Type,
truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
}
c.Data(http.StatusBadRequest, "application/json", body) c.Data(http.StatusBadRequest, "application/json", body)
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
case 401: case 401:
...@@ -1825,18 +1329,6 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, ...@@ -1825,18 +1329,6 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
// 标记账号状态(429/529等) // 标记账号状态(429/529等)
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// 记录上游错误摘要便于排障(不回显请求内容)
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"count_tokens upstream error %d (account=%d platform=%s type=%s): %s",
resp.StatusCode,
account.ID,
account.Platform,
account.Type,
truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
}
// 返回简化的错误响应 // 返回简化的错误响应
errMsg := "Upstream request failed" errMsg := "Upstream request failed"
switch resp.StatusCode { switch resp.StatusCode {
...@@ -1917,13 +1409,6 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con ...@@ -1917,13 +1409,6 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:处理 anthropic-beta header // OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" { if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
} }
return req, nil return req, nil
......
...@@ -2278,13 +2278,11 @@ func convertClaudeToolsToGeminiTools(tools any) []any { ...@@ -2278,13 +2278,11 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
"properties": map[string]any{}, "properties": map[string]any{},
} }
} }
// 清理 JSON Schema
cleanedParams := cleanToolSchema(params)
funcDecls = append(funcDecls, map[string]any{ funcDecls = append(funcDecls, map[string]any{
"name": name, "name": name,
"description": desc, "description": desc,
"parameters": cleanedParams, "parameters": params,
}) })
} }
...@@ -2298,41 +2296,6 @@ func convertClaudeToolsToGeminiTools(tools any) []any { ...@@ -2298,41 +2296,6 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
} }
} }
// cleanToolSchema 清理工具的 JSON Schema,移除 Gemini 不支持的字段
func cleanToolSchema(schema any) any {
if schema == nil {
return nil
}
switch v := schema.(type) {
case map[string]any:
cleaned := make(map[string]any)
for key, value := range v {
// 跳过不支持的字段
if key == "$schema" || key == "$id" || key == "$ref" ||
key == "additionalProperties" || key == "minLength" ||
key == "maxLength" || key == "minItems" || key == "maxItems" {
continue
}
// 递归清理嵌套对象
cleaned[key] = cleanToolSchema(value)
}
// 规范化 type 字段为大写
if typeVal, ok := cleaned["type"].(string); ok {
cleaned["type"] = strings.ToUpper(typeVal)
}
return cleaned
case []any:
cleaned := make([]any, len(v))
for i, item := range v {
cleaned[i] = cleanToolSchema(item)
}
return cleaned
default:
return v
}
}
func convertClaudeGenerationConfig(req map[string]any) map[string]any { func convertClaudeGenerationConfig(req map[string]any) map[string]any {
out := make(map[string]any) out := make(map[string]any)
if mt, ok := asInt(req["max_tokens"]); ok && mt > 0 { if mt, ok := asInt(req["max_tokens"]); ok && mt > 0 {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment