Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
c5c12d4c
Unverified
Commit
c5c12d4c
authored
Dec 31, 2025
by
Wesley Liddick
Committed by
GitHub
Jan 01, 2026
Browse files
Revert "feat(gateway): 实现负载感知的账号调度优化 (#114)" (#117)
This reverts commit
8d252303
.
parent
8d252303
Changes
29
Expand all
Show whitespace changes
Inline
Side-by-side
backend/cmd/server/wire_gen.go
View file @
c5c12d4c
...
@@ -99,7 +99,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
...
@@ -99,7 +99,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
antigravityGatewayService
:=
service
.
NewAntigravityGatewayService
(
accountRepository
,
gatewayCache
,
antigravityTokenProvider
,
rateLimitService
,
httpUpstream
)
antigravityGatewayService
:=
service
.
NewAntigravityGatewayService
(
accountRepository
,
gatewayCache
,
antigravityTokenProvider
,
rateLimitService
,
httpUpstream
)
accountTestService
:=
service
.
NewAccountTestService
(
accountRepository
,
oAuthService
,
openAIOAuthService
,
geminiTokenProvider
,
antigravityGatewayService
,
httpUpstream
)
accountTestService
:=
service
.
NewAccountTestService
(
accountRepository
,
oAuthService
,
openAIOAuthService
,
geminiTokenProvider
,
antigravityGatewayService
,
httpUpstream
)
concurrencyCache
:=
repository
.
ProvideConcurrencyCache
(
redisClient
,
configConfig
)
concurrencyCache
:=
repository
.
ProvideConcurrencyCache
(
redisClient
,
configConfig
)
concurrencyService
:=
service
.
Provide
ConcurrencyService
(
concurrencyCache
,
accountRepository
,
configConfig
)
concurrencyService
:=
service
.
New
ConcurrencyService
(
concurrencyCache
)
crsSyncService
:=
service
.
NewCRSSyncService
(
accountRepository
,
proxyRepository
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
)
crsSyncService
:=
service
.
NewCRSSyncService
(
accountRepository
,
proxyRepository
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
)
accountHandler
:=
admin
.
NewAccountHandler
(
adminService
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
rateLimitService
,
accountUsageService
,
accountTestService
,
concurrencyService
,
crsSyncService
)
accountHandler
:=
admin
.
NewAccountHandler
(
adminService
,
oAuthService
,
openAIOAuthService
,
geminiOAuthService
,
rateLimitService
,
accountUsageService
,
accountTestService
,
concurrencyService
,
crsSyncService
)
oAuthHandler
:=
admin
.
NewOAuthHandler
(
oAuthService
)
oAuthHandler
:=
admin
.
NewOAuthHandler
(
oAuthService
)
...
@@ -127,10 +127,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
...
@@ -127,10 +127,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
identityService
:=
service
.
NewIdentityService
(
identityCache
)
identityService
:=
service
.
NewIdentityService
(
identityCache
)
timingWheelService
:=
service
.
ProvideTimingWheelService
()
timingWheelService
:=
service
.
ProvideTimingWheelService
()
deferredService
:=
service
.
ProvideDeferredService
(
accountRepository
,
timingWheelService
)
deferredService
:=
service
.
ProvideDeferredService
(
accountRepository
,
timingWheelService
)
gatewayService
:=
service
.
NewGatewayService
(
accountRepository
,
groupRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
gatewayCache
,
configConfig
,
concurrencyService
,
billingService
,
rateLimitService
,
billingCacheService
,
identityService
,
httpUpstream
,
deferredService
)
gatewayService
:=
service
.
NewGatewayService
(
accountRepository
,
groupRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
gatewayCache
,
configConfig
,
billingService
,
rateLimitService
,
billingCacheService
,
identityService
,
httpUpstream
,
deferredService
)
geminiMessagesCompatService
:=
service
.
NewGeminiMessagesCompatService
(
accountRepository
,
groupRepository
,
gatewayCache
,
geminiTokenProvider
,
rateLimitService
,
httpUpstream
,
antigravityGatewayService
)
geminiMessagesCompatService
:=
service
.
NewGeminiMessagesCompatService
(
accountRepository
,
groupRepository
,
gatewayCache
,
geminiTokenProvider
,
rateLimitService
,
httpUpstream
,
antigravityGatewayService
)
gatewayHandler
:=
handler
.
NewGatewayHandler
(
gatewayService
,
geminiMessagesCompatService
,
antigravityGatewayService
,
userService
,
concurrencyService
,
billingCacheService
)
gatewayHandler
:=
handler
.
NewGatewayHandler
(
gatewayService
,
geminiMessagesCompatService
,
antigravityGatewayService
,
userService
,
concurrencyService
,
billingCacheService
)
openAIGatewayService
:=
service
.
NewOpenAIGatewayService
(
accountRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
gatewayCache
,
configConfig
,
concurrencyService
,
billingService
,
rateLimitService
,
billingCacheService
,
httpUpstream
,
deferredService
)
openAIGatewayService
:=
service
.
NewOpenAIGatewayService
(
accountRepository
,
usageLogRepository
,
userRepository
,
userSubscriptionRepository
,
gatewayCache
,
configConfig
,
billingService
,
rateLimitService
,
billingCacheService
,
httpUpstream
,
deferredService
)
openAIGatewayHandler
:=
handler
.
NewOpenAIGatewayHandler
(
openAIGatewayService
,
concurrencyService
,
billingCacheService
)
openAIGatewayHandler
:=
handler
.
NewOpenAIGatewayHandler
(
openAIGatewayService
,
concurrencyService
,
billingCacheService
)
handlerSettingHandler
:=
handler
.
ProvideSettingHandler
(
settingService
,
buildInfo
)
handlerSettingHandler
:=
handler
.
ProvideSettingHandler
(
settingService
,
buildInfo
)
handlers
:=
handler
.
ProvideHandlers
(
authHandler
,
userHandler
,
apiKeyHandler
,
usageHandler
,
redeemHandler
,
subscriptionHandler
,
adminHandlers
,
gatewayHandler
,
openAIGatewayHandler
,
handlerSettingHandler
)
handlers
:=
handler
.
ProvideHandlers
(
authHandler
,
userHandler
,
apiKeyHandler
,
usageHandler
,
redeemHandler
,
subscriptionHandler
,
adminHandlers
,
gatewayHandler
,
openAIGatewayHandler
,
handlerSettingHandler
)
...
...
backend/internal/config/config.go
View file @
c5c12d4c
...
@@ -3,7 +3,6 @@ package config
...
@@ -3,7 +3,6 @@ package config
import
(
import
(
"fmt"
"fmt"
"strings"
"strings"
"time"
"github.com/spf13/viper"
"github.com/spf13/viper"
)
)
...
@@ -120,37 +119,6 @@ type GatewayConfig struct {
...
@@ -120,37 +119,6 @@ type GatewayConfig struct {
// ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟)
// ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟)
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期
ConcurrencySlotTTLMinutes
int
`mapstructure:"concurrency_slot_ttl_minutes"`
ConcurrencySlotTTLMinutes
int
`mapstructure:"concurrency_slot_ttl_minutes"`
// 是否记录上游错误响应体摘要(避免输出请求内容)
LogUpstreamErrorBody
bool
`mapstructure:"log_upstream_error_body"`
// 上游错误响应体记录最大字节数(超过会截断)
LogUpstreamErrorBodyMaxBytes
int
`mapstructure:"log_upstream_error_body_max_bytes"`
// API-key 账号在客户端未提供 anthropic-beta 时,是否按需自动补齐(默认关闭以保持兼容)
InjectBetaForApiKey
bool
`mapstructure:"inject_beta_for_apikey"`
// 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义)
FailoverOn400
bool
`mapstructure:"failover_on_400"`
// Scheduling: 账号调度相关配置
Scheduling
GatewaySchedulingConfig
`mapstructure:"scheduling"`
}
// GatewaySchedulingConfig accounts scheduling configuration.
type
GatewaySchedulingConfig
struct
{
// 粘性会话排队配置
StickySessionMaxWaiting
int
`mapstructure:"sticky_session_max_waiting"`
StickySessionWaitTimeout
time
.
Duration
`mapstructure:"sticky_session_wait_timeout"`
// 兜底排队配置
FallbackWaitTimeout
time
.
Duration
`mapstructure:"fallback_wait_timeout"`
FallbackMaxWaiting
int
`mapstructure:"fallback_max_waiting"`
// 负载计算
LoadBatchEnabled
bool
`mapstructure:"load_batch_enabled"`
// 过期槽位清理周期(0 表示禁用)
SlotCleanupInterval
time
.
Duration
`mapstructure:"slot_cleanup_interval"`
}
}
func
(
s
*
ServerConfig
)
Address
()
string
{
func
(
s
*
ServerConfig
)
Address
()
string
{
...
@@ -345,10 +313,6 @@ func setDefaults() {
...
@@ -345,10 +313,6 @@ func setDefaults() {
// Gateway
// Gateway
viper
.
SetDefault
(
"gateway.response_header_timeout"
,
300
)
// 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久
viper
.
SetDefault
(
"gateway.response_header_timeout"
,
300
)
// 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久
viper
.
SetDefault
(
"gateway.log_upstream_error_body"
,
false
)
viper
.
SetDefault
(
"gateway.log_upstream_error_body_max_bytes"
,
2048
)
viper
.
SetDefault
(
"gateway.inject_beta_for_apikey"
,
false
)
viper
.
SetDefault
(
"gateway.failover_on_400"
,
false
)
viper
.
SetDefault
(
"gateway.max_body_size"
,
int64
(
100
*
1024
*
1024
))
viper
.
SetDefault
(
"gateway.max_body_size"
,
int64
(
100
*
1024
*
1024
))
viper
.
SetDefault
(
"gateway.connection_pool_isolation"
,
ConnectionPoolIsolationAccountProxy
)
viper
.
SetDefault
(
"gateway.connection_pool_isolation"
,
ConnectionPoolIsolationAccountProxy
)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
...
@@ -359,12 +323,6 @@ func setDefaults() {
...
@@ -359,12 +323,6 @@ func setDefaults() {
viper
.
SetDefault
(
"gateway.max_upstream_clients"
,
5000
)
viper
.
SetDefault
(
"gateway.max_upstream_clients"
,
5000
)
viper
.
SetDefault
(
"gateway.client_idle_ttl_seconds"
,
900
)
viper
.
SetDefault
(
"gateway.client_idle_ttl_seconds"
,
900
)
viper
.
SetDefault
(
"gateway.concurrency_slot_ttl_minutes"
,
15
)
// 并发槽位过期时间(支持超长请求)
viper
.
SetDefault
(
"gateway.concurrency_slot_ttl_minutes"
,
15
)
// 并发槽位过期时间(支持超长请求)
viper
.
SetDefault
(
"gateway.scheduling.sticky_session_max_waiting"
,
3
)
viper
.
SetDefault
(
"gateway.scheduling.sticky_session_wait_timeout"
,
45
*
time
.
Second
)
viper
.
SetDefault
(
"gateway.scheduling.fallback_wait_timeout"
,
30
*
time
.
Second
)
viper
.
SetDefault
(
"gateway.scheduling.fallback_max_waiting"
,
100
)
viper
.
SetDefault
(
"gateway.scheduling.load_batch_enabled"
,
true
)
viper
.
SetDefault
(
"gateway.scheduling.slot_cleanup_interval"
,
30
*
time
.
Second
)
// TokenRefresh
// TokenRefresh
viper
.
SetDefault
(
"token_refresh.enabled"
,
true
)
viper
.
SetDefault
(
"token_refresh.enabled"
,
true
)
...
@@ -453,21 +411,6 @@ func (c *Config) Validate() error {
...
@@ -453,21 +411,6 @@ func (c *Config) Validate() error {
if
c
.
Gateway
.
ConcurrencySlotTTLMinutes
<=
0
{
if
c
.
Gateway
.
ConcurrencySlotTTLMinutes
<=
0
{
return
fmt
.
Errorf
(
"gateway.concurrency_slot_ttl_minutes must be positive"
)
return
fmt
.
Errorf
(
"gateway.concurrency_slot_ttl_minutes must be positive"
)
}
}
if
c
.
Gateway
.
Scheduling
.
StickySessionMaxWaiting
<=
0
{
return
fmt
.
Errorf
(
"gateway.scheduling.sticky_session_max_waiting must be positive"
)
}
if
c
.
Gateway
.
Scheduling
.
StickySessionWaitTimeout
<=
0
{
return
fmt
.
Errorf
(
"gateway.scheduling.sticky_session_wait_timeout must be positive"
)
}
if
c
.
Gateway
.
Scheduling
.
FallbackWaitTimeout
<=
0
{
return
fmt
.
Errorf
(
"gateway.scheduling.fallback_wait_timeout must be positive"
)
}
if
c
.
Gateway
.
Scheduling
.
FallbackMaxWaiting
<=
0
{
return
fmt
.
Errorf
(
"gateway.scheduling.fallback_max_waiting must be positive"
)
}
if
c
.
Gateway
.
Scheduling
.
SlotCleanupInterval
<
0
{
return
fmt
.
Errorf
(
"gateway.scheduling.slot_cleanup_interval must be non-negative"
)
}
return
nil
return
nil
}
}
...
...
backend/internal/config/config_test.go
View file @
c5c12d4c
package
config
package
config
import
(
import
"testing"
"testing"
"time"
"github.com/spf13/viper"
)
func
TestNormalizeRunMode
(
t
*
testing
.
T
)
{
func
TestNormalizeRunMode
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
tests
:=
[]
struct
{
...
@@ -26,45 +21,3 @@ func TestNormalizeRunMode(t *testing.T) {
...
@@ -26,45 +21,3 @@ func TestNormalizeRunMode(t *testing.T) {
}
}
}
}
}
}
func
TestLoadDefaultSchedulingConfig
(
t
*
testing
.
T
)
{
viper
.
Reset
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
t
.
Fatalf
(
"Load() error: %v"
,
err
)
}
if
cfg
.
Gateway
.
Scheduling
.
StickySessionMaxWaiting
!=
3
{
t
.
Fatalf
(
"StickySessionMaxWaiting = %d, want 3"
,
cfg
.
Gateway
.
Scheduling
.
StickySessionMaxWaiting
)
}
if
cfg
.
Gateway
.
Scheduling
.
StickySessionWaitTimeout
!=
45
*
time
.
Second
{
t
.
Fatalf
(
"StickySessionWaitTimeout = %v, want 45s"
,
cfg
.
Gateway
.
Scheduling
.
StickySessionWaitTimeout
)
}
if
cfg
.
Gateway
.
Scheduling
.
FallbackWaitTimeout
!=
30
*
time
.
Second
{
t
.
Fatalf
(
"FallbackWaitTimeout = %v, want 30s"
,
cfg
.
Gateway
.
Scheduling
.
FallbackWaitTimeout
)
}
if
cfg
.
Gateway
.
Scheduling
.
FallbackMaxWaiting
!=
100
{
t
.
Fatalf
(
"FallbackMaxWaiting = %d, want 100"
,
cfg
.
Gateway
.
Scheduling
.
FallbackMaxWaiting
)
}
if
!
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
{
t
.
Fatalf
(
"LoadBatchEnabled = false, want true"
)
}
if
cfg
.
Gateway
.
Scheduling
.
SlotCleanupInterval
!=
30
*
time
.
Second
{
t
.
Fatalf
(
"SlotCleanupInterval = %v, want 30s"
,
cfg
.
Gateway
.
Scheduling
.
SlotCleanupInterval
)
}
}
func
TestLoadSchedulingConfigFromEnv
(
t
*
testing
.
T
)
{
viper
.
Reset
()
t
.
Setenv
(
"GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING"
,
"5"
)
cfg
,
err
:=
Load
()
if
err
!=
nil
{
t
.
Fatalf
(
"Load() error: %v"
,
err
)
}
if
cfg
.
Gateway
.
Scheduling
.
StickySessionMaxWaiting
!=
5
{
t
.
Fatalf
(
"StickySessionMaxWaiting = %d, want 5"
,
cfg
.
Gateway
.
Scheduling
.
StickySessionMaxWaiting
)
}
}
backend/internal/handler/gateway_handler.go
View file @
c5c12d4c
...
@@ -141,10 +141,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -141,10 +141,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
else
if
apiKey
.
Group
!=
nil
{
}
else
if
apiKey
.
Group
!=
nil
{
platform
=
apiKey
.
Group
.
Platform
platform
=
apiKey
.
Group
.
Platform
}
}
sessionKey
:=
sessionHash
if
platform
==
service
.
PlatformGemini
&&
sessionHash
!=
""
{
sessionKey
=
"gemini:"
+
sessionHash
}
if
platform
==
service
.
PlatformGemini
{
if
platform
==
service
.
PlatformGemini
{
const
maxAccountSwitches
=
3
const
maxAccountSwitches
=
3
...
@@ -153,7 +149,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -153,7 +149,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
lastFailoverStatus
:=
0
lastFailoverStatus
:=
0
for
{
for
{
selection
,
err
:=
h
.
g
ateway
Service
.
SelectAccount
WithLoadAwarenes
s
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
session
Key
,
reqModel
,
failedAccountIDs
)
account
,
err
:=
h
.
g
eminiCompat
Service
.
SelectAccount
ForModelWithExclusion
s
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
session
Hash
,
reqModel
,
failedAccountIDs
)
if
err
!=
nil
{
if
err
!=
nil
{
if
len
(
failedAccountIDs
)
==
0
{
if
len
(
failedAccountIDs
)
==
0
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
...
@@ -162,13 +158,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -162,13 +158,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h
.
handleFailoverExhausted
(
c
,
lastFailoverStatus
,
streamStarted
)
h
.
handleFailoverExhausted
(
c
,
lastFailoverStatus
,
streamStarted
)
return
return
}
}
account
:=
selection
.
Account
// 检查预热请求拦截(在账号选择后、转发前检查)
// 检查预热请求拦截(在账号选择后、转发前检查)
if
account
.
IsInterceptWarmupEnabled
()
&&
isWarmupRequest
(
body
)
{
if
account
.
IsInterceptWarmupEnabled
()
&&
isWarmupRequest
(
body
)
{
if
selection
.
Acquired
&&
selection
.
ReleaseFunc
!=
nil
{
selection
.
ReleaseFunc
()
}
if
reqStream
{
if
reqStream
{
sendMockWarmupStream
(
c
,
reqModel
)
sendMockWarmupStream
(
c
,
reqModel
)
}
else
{
}
else
{
...
@@ -178,47 +170,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -178,47 +170,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
}
// 3. 获取账号并发槽位
// 3. 获取账号并发槽位
accountReleaseFunc
:=
selection
.
ReleaseFunc
accountReleaseFunc
,
err
:=
h
.
concurrencyHelper
.
AcquireAccountSlotWithWait
(
c
,
account
.
ID
,
account
.
Concurrency
,
reqStream
,
&
streamStarted
)
var
accountWaitRelease
func
()
if
!
selection
.
Acquired
{
if
selection
.
WaitPlan
==
nil
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts"
,
streamStarted
)
return
}
canWait
,
err
:=
h
.
concurrencyHelper
.
IncrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
,
selection
.
WaitPlan
.
MaxWaiting
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Printf
(
"Increment account wait count failed: %v"
,
err
)
}
else
if
!
canWait
{
log
.
Printf
(
"Account wait queue full: account=%d"
,
account
.
ID
)
h
.
handleStreamingAwareError
(
c
,
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
"Too many pending requests, please retry later"
,
streamStarted
)
return
}
else
{
// Only set release function if increment succeeded
accountWaitRelease
=
func
()
{
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
}
}
accountReleaseFunc
,
err
=
h
.
concurrencyHelper
.
AcquireAccountSlotWithWaitTimeout
(
c
,
account
.
ID
,
selection
.
WaitPlan
.
MaxConcurrency
,
selection
.
WaitPlan
.
Timeout
,
reqStream
,
&
streamStarted
,
)
if
err
!=
nil
{
if
accountWaitRelease
!=
nil
{
accountWaitRelease
()
}
log
.
Printf
(
"Account concurrency acquire failed: %v"
,
err
)
log
.
Printf
(
"Account concurrency acquire failed: %v"
,
err
)
h
.
handleConcurrencyError
(
c
,
err
,
"account"
,
streamStarted
)
h
.
handleConcurrencyError
(
c
,
err
,
"account"
,
streamStarted
)
return
return
}
}
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
sessionKey
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"Bind sticky session failed: %v"
,
err
)
}
}
// 转发请求 - 根据账号平台分流
// 转发请求 - 根据账号平台分流
var
result
*
service
.
ForwardResult
var
result
*
service
.
ForwardResult
...
@@ -230,9 +187,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -230,9 +187,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if
accountReleaseFunc
!=
nil
{
if
accountReleaseFunc
!=
nil
{
accountReleaseFunc
()
accountReleaseFunc
()
}
}
if
accountWaitRelease
!=
nil
{
accountWaitRelease
()
}
if
err
!=
nil
{
if
err
!=
nil
{
var
failoverErr
*
service
.
UpstreamFailoverError
var
failoverErr
*
service
.
UpstreamFailoverError
if
errors
.
As
(
err
,
&
failoverErr
)
{
if
errors
.
As
(
err
,
&
failoverErr
)
{
...
@@ -277,7 +231,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -277,7 +231,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for
{
for
{
// 选择支持该模型的账号
// 选择支持该模型的账号
selection
,
err
:=
h
.
gatewayService
.
SelectAccount
WithLoadAwarenes
s
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
session
Key
,
reqModel
,
failedAccountIDs
)
account
,
err
:=
h
.
gatewayService
.
SelectAccount
ForModelWithExclusion
s
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
session
Hash
,
reqModel
,
failedAccountIDs
)
if
err
!=
nil
{
if
err
!=
nil
{
if
len
(
failedAccountIDs
)
==
0
{
if
len
(
failedAccountIDs
)
==
0
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
...
@@ -286,13 +240,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -286,13 +240,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h
.
handleFailoverExhausted
(
c
,
lastFailoverStatus
,
streamStarted
)
h
.
handleFailoverExhausted
(
c
,
lastFailoverStatus
,
streamStarted
)
return
return
}
}
account
:=
selection
.
Account
// 检查预热请求拦截(在账号选择后、转发前检查)
// 检查预热请求拦截(在账号选择后、转发前检查)
if
account
.
IsInterceptWarmupEnabled
()
&&
isWarmupRequest
(
body
)
{
if
account
.
IsInterceptWarmupEnabled
()
&&
isWarmupRequest
(
body
)
{
if
selection
.
Acquired
&&
selection
.
ReleaseFunc
!=
nil
{
selection
.
ReleaseFunc
()
}
if
reqStream
{
if
reqStream
{
sendMockWarmupStream
(
c
,
reqModel
)
sendMockWarmupStream
(
c
,
reqModel
)
}
else
{
}
else
{
...
@@ -302,47 +252,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -302,47 +252,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
}
// 3. 获取账号并发槽位
// 3. 获取账号并发槽位
accountReleaseFunc
:=
selection
.
ReleaseFunc
accountReleaseFunc
,
err
:=
h
.
concurrencyHelper
.
AcquireAccountSlotWithWait
(
c
,
account
.
ID
,
account
.
Concurrency
,
reqStream
,
&
streamStarted
)
var
accountWaitRelease
func
()
if
!
selection
.
Acquired
{
if
selection
.
WaitPlan
==
nil
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts"
,
streamStarted
)
return
}
canWait
,
err
:=
h
.
concurrencyHelper
.
IncrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
,
selection
.
WaitPlan
.
MaxWaiting
)
if
err
!=
nil
{
log
.
Printf
(
"Increment account wait count failed: %v"
,
err
)
}
else
if
!
canWait
{
log
.
Printf
(
"Account wait queue full: account=%d"
,
account
.
ID
)
h
.
handleStreamingAwareError
(
c
,
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
"Too many pending requests, please retry later"
,
streamStarted
)
return
}
else
{
// Only set release function if increment succeeded
accountWaitRelease
=
func
()
{
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
}
}
accountReleaseFunc
,
err
=
h
.
concurrencyHelper
.
AcquireAccountSlotWithWaitTimeout
(
c
,
account
.
ID
,
selection
.
WaitPlan
.
MaxConcurrency
,
selection
.
WaitPlan
.
Timeout
,
reqStream
,
&
streamStarted
,
)
if
err
!=
nil
{
if
err
!=
nil
{
if
accountWaitRelease
!=
nil
{
accountWaitRelease
()
}
log
.
Printf
(
"Account concurrency acquire failed: %v"
,
err
)
log
.
Printf
(
"Account concurrency acquire failed: %v"
,
err
)
h
.
handleConcurrencyError
(
c
,
err
,
"account"
,
streamStarted
)
h
.
handleConcurrencyError
(
c
,
err
,
"account"
,
streamStarted
)
return
return
}
}
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
sessionKey
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"Bind sticky session failed: %v"
,
err
)
}
}
// 转发请求 - 根据账号平台分流
// 转发请求 - 根据账号平台分流
var
result
*
service
.
ForwardResult
var
result
*
service
.
ForwardResult
...
@@ -354,9 +269,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -354,9 +269,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if
accountReleaseFunc
!=
nil
{
if
accountReleaseFunc
!=
nil
{
accountReleaseFunc
()
accountReleaseFunc
()
}
}
if
accountWaitRelease
!=
nil
{
accountWaitRelease
()
}
if
err
!=
nil
{
if
err
!=
nil
{
var
failoverErr
*
service
.
UpstreamFailoverError
var
failoverErr
*
service
.
UpstreamFailoverError
if
errors
.
As
(
err
,
&
failoverErr
)
{
if
errors
.
As
(
err
,
&
failoverErr
)
{
...
...
backend/internal/handler/gateway_helper.go
View file @
c5c12d4c
...
@@ -83,16 +83,6 @@ func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64
...
@@ -83,16 +83,6 @@ func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64
h
.
concurrencyService
.
DecrementWaitCount
(
ctx
,
userID
)
h
.
concurrencyService
.
DecrementWaitCount
(
ctx
,
userID
)
}
}
// IncrementAccountWaitCount increments the wait count for an account
func
(
h
*
ConcurrencyHelper
)
IncrementAccountWaitCount
(
ctx
context
.
Context
,
accountID
int64
,
maxWait
int
)
(
bool
,
error
)
{
return
h
.
concurrencyService
.
IncrementAccountWaitCount
(
ctx
,
accountID
,
maxWait
)
}
// DecrementAccountWaitCount decrements the wait count for an account
func
(
h
*
ConcurrencyHelper
)
DecrementAccountWaitCount
(
ctx
context
.
Context
,
accountID
int64
)
{
h
.
concurrencyService
.
DecrementAccountWaitCount
(
ctx
,
accountID
)
}
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait.
// For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun.
// streamStarted is updated if streaming response has begun.
...
@@ -136,12 +126,7 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
...
@@ -136,12 +126,7 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
func
(
h
*
ConcurrencyHelper
)
waitForSlotWithPing
(
c
*
gin
.
Context
,
slotType
string
,
id
int64
,
maxConcurrency
int
,
isStream
bool
,
streamStarted
*
bool
)
(
func
(),
error
)
{
func
(
h
*
ConcurrencyHelper
)
waitForSlotWithPing
(
c
*
gin
.
Context
,
slotType
string
,
id
int64
,
maxConcurrency
int
,
isStream
bool
,
streamStarted
*
bool
)
(
func
(),
error
)
{
return
h
.
waitForSlotWithPingTimeout
(
c
,
slotType
,
id
,
maxConcurrency
,
maxConcurrencyWait
,
isStream
,
streamStarted
)
ctx
,
cancel
:=
context
.
WithTimeout
(
c
.
Request
.
Context
(),
maxConcurrencyWait
)
}
// waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout.
func
(
h
*
ConcurrencyHelper
)
waitForSlotWithPingTimeout
(
c
*
gin
.
Context
,
slotType
string
,
id
int64
,
maxConcurrency
int
,
timeout
time
.
Duration
,
isStream
bool
,
streamStarted
*
bool
)
(
func
(),
error
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
c
.
Request
.
Context
(),
timeout
)
defer
cancel
()
defer
cancel
()
// Determine if ping is needed (streaming + ping format defined)
// Determine if ping is needed (streaming + ping format defined)
...
@@ -215,11 +200,6 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
...
@@ -215,11 +200,6 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
}
}
}
}
// AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping).
func
(
h
*
ConcurrencyHelper
)
AcquireAccountSlotWithWaitTimeout
(
c
*
gin
.
Context
,
accountID
int64
,
maxConcurrency
int
,
timeout
time
.
Duration
,
isStream
bool
,
streamStarted
*
bool
)
(
func
(),
error
)
{
return
h
.
waitForSlotWithPingTimeout
(
c
,
"account"
,
accountID
,
maxConcurrency
,
timeout
,
isStream
,
streamStarted
)
}
// nextBackoff 计算下一次退避时间
// nextBackoff 计算下一次退避时间
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应
// 性能优化:使用指数退避 + 随机抖动,避免惊群效应
// current: 当前退避时间
// current: 当前退避时间
...
...
backend/internal/handler/gemini_v1beta_handler.go
View file @
c5c12d4c
...
@@ -197,17 +197,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
...
@@ -197,17 +197,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 3) select account (sticky session based on request body)
// 3) select account (sticky session based on request body)
parsedReq
,
_
:=
service
.
ParseGatewayRequest
(
body
)
parsedReq
,
_
:=
service
.
ParseGatewayRequest
(
body
)
sessionHash
:=
h
.
gatewayService
.
GenerateSessionHash
(
parsedReq
)
sessionHash
:=
h
.
gatewayService
.
GenerateSessionHash
(
parsedReq
)
sessionKey
:=
sessionHash
if
sessionHash
!=
""
{
sessionKey
=
"gemini:"
+
sessionHash
}
const
maxAccountSwitches
=
3
const
maxAccountSwitches
=
3
switchCount
:=
0
switchCount
:=
0
failedAccountIDs
:=
make
(
map
[
int64
]
struct
{})
failedAccountIDs
:=
make
(
map
[
int64
]
struct
{})
lastFailoverStatus
:=
0
lastFailoverStatus
:=
0
for
{
for
{
selection
,
err
:=
h
.
g
ateway
Service
.
SelectAccount
WithLoadAwarenes
s
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
session
Key
,
modelName
,
failedAccountIDs
)
account
,
err
:=
h
.
g
eminiCompat
Service
.
SelectAccount
ForModelWithExclusion
s
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
session
Hash
,
modelName
,
failedAccountIDs
)
if
err
!=
nil
{
if
err
!=
nil
{
if
len
(
failedAccountIDs
)
==
0
{
if
len
(
failedAccountIDs
)
==
0
{
googleError
(
c
,
http
.
StatusServiceUnavailable
,
"No available Gemini accounts: "
+
err
.
Error
())
googleError
(
c
,
http
.
StatusServiceUnavailable
,
"No available Gemini accounts: "
+
err
.
Error
())
...
@@ -216,49 +212,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
...
@@ -216,49 +212,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
handleGeminiFailoverExhausted
(
c
,
lastFailoverStatus
)
handleGeminiFailoverExhausted
(
c
,
lastFailoverStatus
)
return
return
}
}
account
:=
selection
.
Account
// 4) account concurrency slot
// 4) account concurrency slot
accountReleaseFunc
:=
selection
.
ReleaseFunc
accountReleaseFunc
,
err
:=
geminiConcurrency
.
AcquireAccountSlotWithWait
(
c
,
account
.
ID
,
account
.
Concurrency
,
stream
,
&
streamStarted
)
var
accountWaitRelease
func
()
if
!
selection
.
Acquired
{
if
selection
.
WaitPlan
==
nil
{
googleError
(
c
,
http
.
StatusServiceUnavailable
,
"No available Gemini accounts"
)
return
}
canWait
,
err
:=
geminiConcurrency
.
IncrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
,
selection
.
WaitPlan
.
MaxWaiting
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Printf
(
"Increment account wait count failed: %v"
,
err
)
}
else
if
!
canWait
{
log
.
Printf
(
"Account wait queue full: account=%d"
,
account
.
ID
)
googleError
(
c
,
http
.
StatusTooManyRequests
,
"Too many pending requests, please retry later"
)
return
}
else
{
// Only set release function if increment succeeded
accountWaitRelease
=
func
()
{
geminiConcurrency
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
}
}
accountReleaseFunc
,
err
=
geminiConcurrency
.
AcquireAccountSlotWithWaitTimeout
(
c
,
account
.
ID
,
selection
.
WaitPlan
.
MaxConcurrency
,
selection
.
WaitPlan
.
Timeout
,
stream
,
&
streamStarted
,
)
if
err
!=
nil
{
if
accountWaitRelease
!=
nil
{
accountWaitRelease
()
}
googleError
(
c
,
http
.
StatusTooManyRequests
,
err
.
Error
())
googleError
(
c
,
http
.
StatusTooManyRequests
,
err
.
Error
())
return
return
}
}
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
sessionKey
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"Bind sticky session failed: %v"
,
err
)
}
}
// 5) forward (根据平台分流)
// 5) forward (根据平台分流)
var
result
*
service
.
ForwardResult
var
result
*
service
.
ForwardResult
...
@@ -270,9 +230,6 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
...
@@ -270,9 +230,6 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if
accountReleaseFunc
!=
nil
{
if
accountReleaseFunc
!=
nil
{
accountReleaseFunc
()
accountReleaseFunc
()
}
}
if
accountWaitRelease
!=
nil
{
accountWaitRelease
()
}
if
err
!=
nil
{
if
err
!=
nil
{
var
failoverErr
*
service
.
UpstreamFailoverError
var
failoverErr
*
service
.
UpstreamFailoverError
if
errors
.
As
(
err
,
&
failoverErr
)
{
if
errors
.
As
(
err
,
&
failoverErr
)
{
...
...
backend/internal/handler/openai_gateway_handler.go
View file @
c5c12d4c
...
@@ -146,7 +146,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
...
@@ -146,7 +146,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
for
{
for
{
// Select account supporting the requested model
// Select account supporting the requested model
log
.
Printf
(
"[OpenAI Handler] Selecting account: groupID=%v model=%s"
,
apiKey
.
GroupID
,
reqModel
)
log
.
Printf
(
"[OpenAI Handler] Selecting account: groupID=%v model=%s"
,
apiKey
.
GroupID
,
reqModel
)
selection
,
err
:=
h
.
gatewayService
.
SelectAccount
WithLoadAwarenes
s
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionHash
,
reqModel
,
failedAccountIDs
)
account
,
err
:=
h
.
gatewayService
.
SelectAccount
ForModelWithExclusion
s
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionHash
,
reqModel
,
failedAccountIDs
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Printf
(
"[OpenAI Handler] SelectAccount failed: %v"
,
err
)
log
.
Printf
(
"[OpenAI Handler] SelectAccount failed: %v"
,
err
)
if
len
(
failedAccountIDs
)
==
0
{
if
len
(
failedAccountIDs
)
==
0
{
...
@@ -156,60 +156,21 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
...
@@ -156,60 +156,21 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h
.
handleFailoverExhausted
(
c
,
lastFailoverStatus
,
streamStarted
)
h
.
handleFailoverExhausted
(
c
,
lastFailoverStatus
,
streamStarted
)
return
return
}
}
account
:=
selection
.
Account
log
.
Printf
(
"[OpenAI Handler] Selected account: id=%d name=%s"
,
account
.
ID
,
account
.
Name
)
log
.
Printf
(
"[OpenAI Handler] Selected account: id=%d name=%s"
,
account
.
ID
,
account
.
Name
)
// 3. Acquire account concurrency slot
// 3. Acquire account concurrency slot
accountReleaseFunc
:=
selection
.
ReleaseFunc
accountReleaseFunc
,
err
:=
h
.
concurrencyHelper
.
AcquireAccountSlotWithWait
(
c
,
account
.
ID
,
account
.
Concurrency
,
reqStream
,
&
streamStarted
)
var
accountWaitRelease
func
()
if
!
selection
.
Acquired
{
if
selection
.
WaitPlan
==
nil
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts"
,
streamStarted
)
return
}
canWait
,
err
:=
h
.
concurrencyHelper
.
IncrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
,
selection
.
WaitPlan
.
MaxWaiting
)
if
err
!=
nil
{
log
.
Printf
(
"Increment account wait count failed: %v"
,
err
)
}
else
if
!
canWait
{
log
.
Printf
(
"Account wait queue full: account=%d"
,
account
.
ID
)
h
.
handleStreamingAwareError
(
c
,
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
"Too many pending requests, please retry later"
,
streamStarted
)
return
}
else
{
// Only set release function if increment succeeded
accountWaitRelease
=
func
()
{
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
}
}
accountReleaseFunc
,
err
=
h
.
concurrencyHelper
.
AcquireAccountSlotWithWaitTimeout
(
c
,
account
.
ID
,
selection
.
WaitPlan
.
MaxConcurrency
,
selection
.
WaitPlan
.
Timeout
,
reqStream
,
&
streamStarted
,
)
if
err
!=
nil
{
if
err
!=
nil
{
if
accountWaitRelease
!=
nil
{
accountWaitRelease
()
}
log
.
Printf
(
"Account concurrency acquire failed: %v"
,
err
)
log
.
Printf
(
"Account concurrency acquire failed: %v"
,
err
)
h
.
handleConcurrencyError
(
c
,
err
,
"account"
,
streamStarted
)
h
.
handleConcurrencyError
(
c
,
err
,
"account"
,
streamStarted
)
return
return
}
}
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
sessionHash
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"Bind sticky session failed: %v"
,
err
)
}
}
// Forward request
// Forward request
result
,
err
:=
h
.
gatewayService
.
Forward
(
c
.
Request
.
Context
(),
c
,
account
,
body
)
result
,
err
:=
h
.
gatewayService
.
Forward
(
c
.
Request
.
Context
(),
c
,
account
,
body
)
if
accountReleaseFunc
!=
nil
{
if
accountReleaseFunc
!=
nil
{
accountReleaseFunc
()
accountReleaseFunc
()
}
}
if
accountWaitRelease
!=
nil
{
accountWaitRelease
()
}
if
err
!=
nil
{
if
err
!=
nil
{
var
failoverErr
*
service
.
UpstreamFailoverError
var
failoverErr
*
service
.
UpstreamFailoverError
if
errors
.
As
(
err
,
&
failoverErr
)
{
if
errors
.
As
(
err
,
&
failoverErr
)
{
...
...
backend/internal/pkg/antigravity/claude_types.go
View file @
c5c12d4c
...
@@ -54,9 +54,6 @@ type CustomToolSpec struct {
...
@@ -54,9 +54,6 @@ type CustomToolSpec struct {
InputSchema
map
[
string
]
any
`json:"input_schema"`
InputSchema
map
[
string
]
any
`json:"input_schema"`
}
}
// ClaudeCustomToolSpec 兼容旧命名(MCP custom 工具规格)
type
ClaudeCustomToolSpec
=
CustomToolSpec
// SystemBlock system prompt 数组形式的元素
// SystemBlock system prompt 数组形式的元素
type
SystemBlock
struct
{
type
SystemBlock
struct
{
Type
string
`json:"type"`
Type
string
`json:"type"`
...
...
backend/internal/pkg/antigravity/request_transformer.go
View file @
c5c12d4c
...
@@ -14,16 +14,13 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
...
@@ -14,16 +14,13 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
// 用于存储 tool_use id -> name 映射
// 用于存储 tool_use id -> name 映射
toolIDToName
:=
make
(
map
[
string
]
string
)
toolIDToName
:=
make
(
map
[
string
]
string
)
// 检测是否启用 thinking
isThinkingEnabled
:=
claudeReq
.
Thinking
!=
nil
&&
claudeReq
.
Thinking
.
Type
==
"enabled"
// 只有 Gemini 模型支持 dummy thought workaround
// 只有 Gemini 模型支持 dummy thought workaround
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
allowDummyThought
:=
strings
.
HasPrefix
(
mappedModel
,
"gemini-"
)
allowDummyThought
:=
strings
.
HasPrefix
(
mappedModel
,
"gemini-"
)
// 检测是否启用 thinking
requestedThinkingEnabled
:=
claudeReq
.
Thinking
!=
nil
&&
claudeReq
.
Thinking
.
Type
==
"enabled"
// 为避免 Claude 模型的 thought signature/消息块约束导致 400(上游要求 thinking 块开头等),
// 非 Gemini 模型默认不启用 thinking(除非未来支持完整签名链路)。
isThinkingEnabled
:=
requestedThinkingEnabled
&&
allowDummyThought
// 1. 构建 contents
// 1. 构建 contents
contents
,
err
:=
buildContents
(
claudeReq
.
Messages
,
toolIDToName
,
isThinkingEnabled
,
allowDummyThought
)
contents
,
err
:=
buildContents
(
claudeReq
.
Messages
,
toolIDToName
,
isThinkingEnabled
,
allowDummyThought
)
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -34,15 +31,7 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
...
@@ -34,15 +31,7 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
systemInstruction
:=
buildSystemInstruction
(
claudeReq
.
System
,
claudeReq
.
Model
)
systemInstruction
:=
buildSystemInstruction
(
claudeReq
.
System
,
claudeReq
.
Model
)
// 3. 构建 generationConfig
// 3. 构建 generationConfig
reqForGen
:=
claudeReq
generationConfig
:=
buildGenerationConfig
(
claudeReq
)
if
requestedThinkingEnabled
&&
!
allowDummyThought
{
log
.
Printf
(
"[Warning] Disabling thinking for non-Gemini model in antigravity transform: model=%s"
,
mappedModel
)
// shallow copy to avoid mutating caller's request
clone
:=
*
claudeReq
clone
.
Thinking
=
nil
reqForGen
=
&
clone
}
generationConfig
:=
buildGenerationConfig
(
reqForGen
)
// 4. 构建 tools
// 4. 构建 tools
tools
:=
buildTools
(
claudeReq
.
Tools
)
tools
:=
buildTools
(
claudeReq
.
Tools
)
...
@@ -161,7 +150,6 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
...
@@ -161,7 +150,6 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
parts
=
append
([]
GeminiPart
{{
parts
=
append
([]
GeminiPart
{{
Text
:
"Thinking..."
,
Text
:
"Thinking..."
,
Thought
:
true
,
Thought
:
true
,
ThoughtSignature
:
dummyThoughtSignature
,
}},
parts
...
)
}},
parts
...
)
}
}
}
}
...
@@ -183,34 +171,6 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
...
@@ -183,34 +171,6 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
// 参考: https://ai.google.dev/gemini-api/docs/thought-signatures
// 参考: https://ai.google.dev/gemini-api/docs/thought-signatures
const
dummyThoughtSignature
=
"skip_thought_signature_validator"
const
dummyThoughtSignature
=
"skip_thought_signature_validator"
// isValidThoughtSignature 验证 thought signature 是否有效
// Claude API 要求 signature 必须是 base64 编码的字符串,长度至少 32 字节
func
isValidThoughtSignature
(
signature
string
)
bool
{
// 空字符串无效
if
signature
==
""
{
return
false
}
// signature 应该是 base64 编码,长度至少 40 个字符(约 30 字节)
// 参考 Claude API 文档和实际观察到的有效 signature
if
len
(
signature
)
<
40
{
log
.
Printf
(
"[Debug] Signature too short: len=%d"
,
len
(
signature
))
return
false
}
// 检查是否是有效的 base64 字符
// base64 字符集: A-Z, a-z, 0-9, +, /, =
for
i
,
c
:=
range
signature
{
if
(
c
<
'A'
||
c
>
'Z'
)
&&
(
c
<
'a'
||
c
>
'z'
)
&&
(
c
<
'0'
||
c
>
'9'
)
&&
c
!=
'+'
&&
c
!=
'/'
&&
c
!=
'='
{
log
.
Printf
(
"[Debug] Invalid base64 character at position %d: %c (code=%d)"
,
i
,
c
,
c
)
return
false
}
}
return
true
}
// buildParts 构建消息的 parts
// buildParts 构建消息的 parts
// allowDummyThought: 只有 Gemini 模型支持 dummy thought signature
// allowDummyThought: 只有 Gemini 模型支持 dummy thought signature
func
buildParts
(
content
json
.
RawMessage
,
toolIDToName
map
[
string
]
string
,
allowDummyThought
bool
)
([]
GeminiPart
,
error
)
{
func
buildParts
(
content
json
.
RawMessage
,
toolIDToName
map
[
string
]
string
,
allowDummyThought
bool
)
([]
GeminiPart
,
error
)
{
...
@@ -239,30 +199,22 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
...
@@ -239,30 +199,22 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
}
}
case
"thinking"
:
case
"thinking"
:
if
allowDummyThought
{
part
:=
GeminiPart
{
// Gemini 模型可以使用 dummy signature
parts
=
append
(
parts
,
GeminiPart
{
Text
:
block
.
Thinking
,
Text
:
block
.
Thinking
,
Thought
:
true
,
Thought
:
true
,
ThoughtSignature
:
dummyThoughtSignature
,
})
continue
}
}
// 保留原有 signature(Claude 模型需要有效的 signature)
// Claude 模型:仅在提供有效 signature 时保留 thinking block;否则跳过以避免上游校验失败。
if
block
.
Signature
!=
""
{
signature
:=
strings
.
TrimSpace
(
block
.
Signature
)
part
.
ThoughtSignature
=
block
.
Signature
if
signature
==
""
||
signature
==
dummyThoughtSignature
{
}
else
if
!
allowDummyThought
{
log
.
Printf
(
"[Warning] Skipping thinking block for Claude model (missing or dummy signature)"
)
// Claude 模型需要有效 signature,跳过无 signature 的 thinking block
log
.
Printf
(
"Warning: skipping thinking block without signature for Claude model"
)
continue
continue
}
else
{
// Gemini 模型使用 dummy signature
part
.
ThoughtSignature
=
dummyThoughtSignature
}
}
if
!
isValidThoughtSignature
(
signature
)
{
parts
=
append
(
parts
,
part
)
log
.
Printf
(
"[Debug] Thinking signature may be invalid (passing through anyway): len=%d"
,
len
(
signature
))
}
parts
=
append
(
parts
,
GeminiPart
{
Text
:
block
.
Thinking
,
Thought
:
true
,
ThoughtSignature
:
signature
,
})
case
"image"
:
case
"image"
:
if
block
.
Source
!=
nil
&&
block
.
Source
.
Type
==
"base64"
{
if
block
.
Source
!=
nil
&&
block
.
Source
.
Type
==
"base64"
{
...
@@ -287,9 +239,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
...
@@ -287,9 +239,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
ID
:
block
.
ID
,
ID
:
block
.
ID
,
},
},
}
}
// 只有 Gemini 模型使用 dummy signature
// 保留原有 signature,或对 Gemini 模型使用 dummy signature
// Claude 模型不设置 signature(避免验证问题)
if
block
.
Signature
!=
""
{
if
allowDummyThought
{
part
.
ThoughtSignature
=
block
.
Signature
}
else
if
allowDummyThought
{
part
.
ThoughtSignature
=
dummyThoughtSignature
part
.
ThoughtSignature
=
dummyThoughtSignature
}
}
parts
=
append
(
parts
,
part
)
parts
=
append
(
parts
,
part
)
...
@@ -433,9 +386,9 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
...
@@ -433,9 +386,9 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
// 普通工具
// 普通工具
var
funcDecls
[]
GeminiFunctionDecl
var
funcDecls
[]
GeminiFunctionDecl
for
i
,
tool
:=
range
tools
{
for
_
,
tool
:=
range
tools
{
// 跳过无效工具名称
// 跳过无效工具名称
if
strings
.
TrimSpace
(
tool
.
Name
)
==
""
{
if
tool
.
Name
==
""
{
log
.
Printf
(
"Warning: skipping tool with empty name"
)
log
.
Printf
(
"Warning: skipping tool with empty name"
)
continue
continue
}
}
...
@@ -444,18 +397,10 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
...
@@ -444,18 +397,10 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
var
inputSchema
map
[
string
]
any
var
inputSchema
map
[
string
]
any
// 检查是否为 custom 类型工具 (MCP)
// 检查是否为 custom 类型工具 (MCP)
if
tool
.
Type
==
"custom"
{
if
tool
.
Type
==
"custom"
&&
tool
.
Custom
!=
nil
{
if
tool
.
Custom
==
nil
||
tool
.
Custom
.
InputSchema
==
nil
{
// Custom 格式: 从 custom 字段获取 description 和 input_schema
log
.
Printf
(
"[Warning] Skipping invalid custom tool '%s': missing custom spec or input_schema"
,
tool
.
Name
)
continue
}
description
=
tool
.
Custom
.
Description
description
=
tool
.
Custom
.
Description
inputSchema
=
tool
.
Custom
.
InputSchema
inputSchema
=
tool
.
Custom
.
InputSchema
// 调试日志:记录 custom 工具的 schema
if
schemaJSON
,
err
:=
json
.
Marshal
(
inputSchema
);
err
==
nil
{
log
.
Printf
(
"[Debug] Tool[%d] '%s' (custom) original schema: %s"
,
i
,
tool
.
Name
,
string
(
schemaJSON
))
}
}
else
{
}
else
{
// 标准格式: 从顶层字段获取
// 标准格式: 从顶层字段获取
description
=
tool
.
Description
description
=
tool
.
Description
...
@@ -464,6 +409,7 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
...
@@ -464,6 +409,7 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
// 清理 JSON Schema
// 清理 JSON Schema
params
:=
cleanJSONSchema
(
inputSchema
)
params
:=
cleanJSONSchema
(
inputSchema
)
// 为 nil schema 提供默认值
// 为 nil schema 提供默认值
if
params
==
nil
{
if
params
==
nil
{
params
=
map
[
string
]
any
{
params
=
map
[
string
]
any
{
...
@@ -472,11 +418,6 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
...
@@ -472,11 +418,6 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
}
}
}
}
// 调试日志:记录清理后的 schema
if
paramsJSON
,
err
:=
json
.
Marshal
(
params
);
err
==
nil
{
log
.
Printf
(
"[Debug] Tool[%d] '%s' cleaned schema: %s"
,
i
,
tool
.
Name
,
string
(
paramsJSON
))
}
funcDecls
=
append
(
funcDecls
,
GeminiFunctionDecl
{
funcDecls
=
append
(
funcDecls
,
GeminiFunctionDecl
{
Name
:
tool
.
Name
,
Name
:
tool
.
Name
,
Description
:
description
,
Description
:
description
,
...
@@ -538,54 +479,24 @@ func cleanJSONSchema(schema map[string]any) map[string]any {
...
@@ -538,54 +479,24 @@ func cleanJSONSchema(schema map[string]any) map[string]any {
}
}
// excludedSchemaKeys 不支持的 schema 字段
// excludedSchemaKeys 不支持的 schema 字段
// 基于 Claude API (Vertex AI) 的实际支持情况
// 支持: type, description, enum, properties, required, additionalProperties, items
// 不支持: minItems, maxItems, minLength, maxLength, pattern, minimum, maximum 等验证字段
var
excludedSchemaKeys
=
map
[
string
]
bool
{
var
excludedSchemaKeys
=
map
[
string
]
bool
{
// 元 schema 字段
"$schema"
:
true
,
"$schema"
:
true
,
"$id"
:
true
,
"$id"
:
true
,
"$ref"
:
true
,
"$ref"
:
true
,
"additionalProperties"
:
true
,
// 字符串验证(Gemini 不支持)
"minLength"
:
true
,
"minLength"
:
true
,
"maxLength"
:
true
,
"maxLength"
:
true
,
"
pattern"
:
true
,
"
minItems"
:
true
,
"maxItems"
:
true
,
// 数字验证(Claude API 通过 Vertex AI 不支持这些字段)
"uniqueItems"
:
true
,
"minimum"
:
true
,
"minimum"
:
true
,
"maximum"
:
true
,
"maximum"
:
true
,
"exclusiveMinimum"
:
true
,
"exclusiveMinimum"
:
true
,
"exclusiveMaximum"
:
true
,
"exclusiveMaximum"
:
true
,
"multipleOf"
:
true
,
"pattern"
:
true
,
"format"
:
true
,
// 数组验证(Claude API 通过 Vertex AI 不支持这些字段)
"uniqueItems"
:
true
,
"minItems"
:
true
,
"maxItems"
:
true
,
// 组合 schema(Gemini 不支持)
"oneOf"
:
true
,
"anyOf"
:
true
,
"allOf"
:
true
,
"not"
:
true
,
"if"
:
true
,
"then"
:
true
,
"else"
:
true
,
"$defs"
:
true
,
"definitions"
:
true
,
// 对象验证(仅保留 properties/required/additionalProperties)
"minProperties"
:
true
,
"maxProperties"
:
true
,
"patternProperties"
:
true
,
"propertyNames"
:
true
,
"dependencies"
:
true
,
"dependentSchemas"
:
true
,
"dependentRequired"
:
true
,
// 其他不支持的字段
"default"
:
true
,
"default"
:
true
,
"strict"
:
true
,
"const"
:
true
,
"const"
:
true
,
"examples"
:
true
,
"examples"
:
true
,
"deprecated"
:
true
,
"deprecated"
:
true
,
...
@@ -593,9 +504,6 @@ var excludedSchemaKeys = map[string]bool{
...
@@ -593,9 +504,6 @@ var excludedSchemaKeys = map[string]bool{
"writeOnly"
:
true
,
"writeOnly"
:
true
,
"contentMediaType"
:
true
,
"contentMediaType"
:
true
,
"contentEncoding"
:
true
,
"contentEncoding"
:
true
,
// Claude 特有字段
"strict"
:
true
,
}
}
// cleanSchemaValue 递归清理 schema 值
// cleanSchemaValue 递归清理 schema 值
...
@@ -615,31 +523,6 @@ func cleanSchemaValue(value any) any {
...
@@ -615,31 +523,6 @@ func cleanSchemaValue(value any) any {
continue
continue
}
}
// 特殊处理 format 字段:只保留 Gemini 支持的 format 值
if
k
==
"format"
{
if
formatStr
,
ok
:=
val
.
(
string
);
ok
{
// Gemini 只支持 date-time, date, time
if
formatStr
==
"date-time"
||
formatStr
==
"date"
||
formatStr
==
"time"
{
result
[
k
]
=
val
}
// 其他 format 值直接跳过
}
continue
}
// 特殊处理 additionalProperties:Claude API 只支持布尔值,不支持 schema 对象
if
k
==
"additionalProperties"
{
if
boolVal
,
ok
:=
val
.
(
bool
);
ok
{
result
[
k
]
=
boolVal
log
.
Printf
(
"[Debug] additionalProperties is bool: %v"
,
boolVal
)
}
else
{
// 如果是 schema 对象,转换为 false(更安全的默认值)
result
[
k
]
=
false
log
.
Printf
(
"[Debug] additionalProperties is not bool (type: %T), converting to false"
,
val
)
}
continue
}
// 递归清理所有值
// 递归清理所有值
result
[
k
]
=
cleanSchemaValue
(
val
)
result
[
k
]
=
cleanSchemaValue
(
val
)
}
}
...
...
backend/internal/pkg/antigravity/request_transformer_test.go
deleted
100644 → 0
View file @
8d252303
package
antigravity
import
(
"encoding/json"
"testing"
)
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
func
TestBuildParts_ThinkingBlockWithoutSignature
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
content
string
allowDummyThought
bool
expectedParts
int
description
string
}{
{
name
:
"Claude model - skip thinking block without signature"
,
content
:
`[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"}
]`
,
allowDummyThought
:
false
,
expectedParts
:
2
,
// 只有两个text block
description
:
"Claude模型应该跳过无signature的thinking block"
,
},
{
name
:
"Claude model - keep thinking block with signature"
,
content
:
`[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": "valid_sig"},
{"type": "text", "text": "World"}
]`
,
allowDummyThought
:
false
,
expectedParts
:
3
,
// 三个block都保留
description
:
"Claude模型应该保留有signature的thinking block"
,
},
{
name
:
"Gemini model - use dummy signature"
,
content
:
`[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"}
]`
,
allowDummyThought
:
true
,
expectedParts
:
3
,
// 三个block都保留,thinking使用dummy signature
description
:
"Gemini模型应该为无signature的thinking block使用dummy signature"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
toolIDToName
:=
make
(
map
[
string
]
string
)
parts
,
err
:=
buildParts
(
json
.
RawMessage
(
tt
.
content
),
toolIDToName
,
tt
.
allowDummyThought
)
if
err
!=
nil
{
t
.
Fatalf
(
"buildParts() error = %v"
,
err
)
}
if
len
(
parts
)
!=
tt
.
expectedParts
{
t
.
Errorf
(
"%s: got %d parts, want %d parts"
,
tt
.
description
,
len
(
parts
),
tt
.
expectedParts
)
}
})
}
}
// TestBuildTools_CustomTypeTools 测试custom类型工具转换
func
TestBuildTools_CustomTypeTools
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
tools
[]
ClaudeTool
expectedLen
int
description
string
}{
{
name
:
"Standard tool format"
,
tools
:
[]
ClaudeTool
{
{
Name
:
"get_weather"
,
Description
:
"Get weather information"
,
InputSchema
:
map
[
string
]
any
{
"type"
:
"object"
,
"properties"
:
map
[
string
]
any
{
"location"
:
map
[
string
]
any
{
"type"
:
"string"
},
},
},
},
},
expectedLen
:
1
,
description
:
"标准工具格式应该正常转换"
,
},
{
name
:
"Custom type tool (MCP format)"
,
tools
:
[]
ClaudeTool
{
{
Type
:
"custom"
,
Name
:
"mcp_tool"
,
Custom
:
&
ClaudeCustomToolSpec
{
Description
:
"MCP tool description"
,
InputSchema
:
map
[
string
]
any
{
"type"
:
"object"
,
"properties"
:
map
[
string
]
any
{
"param"
:
map
[
string
]
any
{
"type"
:
"string"
},
},
},
},
},
},
expectedLen
:
1
,
description
:
"Custom类型工具应该从Custom字段读取description和input_schema"
,
},
{
name
:
"Mixed standard and custom tools"
,
tools
:
[]
ClaudeTool
{
{
Name
:
"standard_tool"
,
Description
:
"Standard tool"
,
InputSchema
:
map
[
string
]
any
{
"type"
:
"object"
},
},
{
Type
:
"custom"
,
Name
:
"custom_tool"
,
Custom
:
&
ClaudeCustomToolSpec
{
Description
:
"Custom tool"
,
InputSchema
:
map
[
string
]
any
{
"type"
:
"object"
},
},
},
},
expectedLen
:
1
,
// 返回一个GeminiToolDeclaration,包含2个function declarations
description
:
"混合标准和custom工具应该都能正确转换"
,
},
{
name
:
"Invalid custom tool - nil Custom field"
,
tools
:
[]
ClaudeTool
{
{
Type
:
"custom"
,
Name
:
"invalid_custom"
,
// Custom 为 nil
},
},
expectedLen
:
0
,
// 应该被跳过
description
:
"Custom字段为nil的custom工具应该被跳过"
,
},
{
name
:
"Invalid custom tool - nil InputSchema"
,
tools
:
[]
ClaudeTool
{
{
Type
:
"custom"
,
Name
:
"invalid_custom"
,
Custom
:
&
ClaudeCustomToolSpec
{
Description
:
"Invalid"
,
// InputSchema 为 nil
},
},
},
expectedLen
:
0
,
// 应该被跳过
description
:
"InputSchema为nil的custom工具应该被跳过"
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
result
:=
buildTools
(
tt
.
tools
)
if
len
(
result
)
!=
tt
.
expectedLen
{
t
.
Errorf
(
"%s: got %d tool declarations, want %d"
,
tt
.
description
,
len
(
result
),
tt
.
expectedLen
)
}
// 验证function declarations存在
if
len
(
result
)
>
0
&&
result
[
0
]
.
FunctionDeclarations
!=
nil
{
if
len
(
result
[
0
]
.
FunctionDeclarations
)
!=
len
(
tt
.
tools
)
{
t
.
Errorf
(
"%s: got %d function declarations, want %d"
,
tt
.
description
,
len
(
result
[
0
]
.
FunctionDeclarations
),
len
(
tt
.
tools
))
}
}
})
}
}
backend/internal/pkg/claude/constants.go
View file @
c5c12d4c
...
@@ -16,12 +16,6 @@ const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleav
...
@@ -16,12 +16,6 @@ const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleav
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
const
HaikuBetaHeader
=
BetaOAuth
+
","
+
BetaInterleavedThinking
const
HaikuBetaHeader
=
BetaOAuth
+
","
+
BetaInterleavedThinking
// ApiKeyBetaHeader API-key 账号建议使用的 anthropic-beta header(不包含 oauth)
const
ApiKeyBetaHeader
=
BetaClaudeCode
+
","
+
BetaInterleavedThinking
+
","
+
BetaFineGrainedToolStreaming
// ApiKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不包含 oauth / claude-code)
const
ApiKeyHaikuBetaHeader
=
BetaInterleavedThinking
// Claude Code 客户端默认请求头
// Claude Code 客户端默认请求头
var
DefaultHeaders
=
map
[
string
]
string
{
var
DefaultHeaders
=
map
[
string
]
string
{
"User-Agent"
:
"claude-cli/2.0.62 (external, cli)"
,
"User-Agent"
:
"claude-cli/2.0.62 (external, cli)"
,
...
...
backend/internal/repository/concurrency_cache.go
View file @
c5c12d4c
...
@@ -2,9 +2,7 @@ package repository
...
@@ -2,9 +2,7 @@ package repository
import
(
import
(
"context"
"context"
"errors"
"fmt"
"fmt"
"strconv"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"github.com/redis/go-redis/v9"
...
@@ -29,8 +27,6 @@ const (
...
@@ -29,8 +27,6 @@ const (
userSlotKeyPrefix
=
"concurrency:user:"
userSlotKeyPrefix
=
"concurrency:user:"
// 等待队列计数器格式: concurrency:wait:{userID}
// 等待队列计数器格式: concurrency:wait:{userID}
waitQueueKeyPrefix
=
"concurrency:wait:"
waitQueueKeyPrefix
=
"concurrency:wait:"
// 账号级等待队列计数器格式: wait:account:{accountID}
accountWaitKeyPrefix
=
"wait:account:"
// 默认槽位过期时间(分钟),可通过配置覆盖
// 默认槽位过期时间(分钟),可通过配置覆盖
defaultSlotTTLMinutes
=
15
defaultSlotTTLMinutes
=
15
...
@@ -119,29 +115,6 @@ var (
...
@@ -119,29 +115,6 @@ var (
return 1
return 1
`
)
`
)
// incrementAccountWaitScript - account-level wait queue count
incrementAccountWaitScript
=
redis
.
NewScript
(
`
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
end
if current >= tonumber(ARGV[1]) then
return 0
end
local newVal = redis.call('INCR', KEYS[1])
-- Only set TTL on first creation to avoid refreshing zombie data
if newVal == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[2])
end
return 1
`
)
// decrementWaitScript - same as before
// decrementWaitScript - same as before
decrementWaitScript
=
redis
.
NewScript
(
`
decrementWaitScript
=
redis
.
NewScript
(
`
local current = redis.call('GET', KEYS[1])
local current = redis.call('GET', KEYS[1])
...
@@ -150,78 +123,22 @@ var (
...
@@ -150,78 +123,22 @@ var (
end
end
return 1
return 1
`
)
`
)
// getAccountsLoadBatchScript - batch load query (read-only)
// ARGV[1] = slot TTL (seconds, retained for compatibility)
// ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ...
getAccountsLoadBatchScript
=
redis
.
NewScript
(
`
local result = {}
local i = 2
while i <= #ARGV do
local accountID = ARGV[i]
local maxConcurrency = tonumber(ARGV[i + 1])
local slotKey = 'concurrency:account:' .. accountID
local currentConcurrency = redis.call('ZCARD', slotKey)
local waitKey = 'wait:account:' .. accountID
local waitingCount = redis.call('GET', waitKey)
if waitingCount == false then
waitingCount = 0
else
waitingCount = tonumber(waitingCount)
end
local loadRate = 0
if maxConcurrency > 0 then
loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency)
end
table.insert(result, accountID)
table.insert(result, currentConcurrency)
table.insert(result, waitingCount)
table.insert(result, loadRate)
i = i + 2
end
return result
`
)
// cleanupExpiredSlotsScript - remove expired slots
// KEYS[1] = concurrency:account:{accountID}
// ARGV[1] = TTL (seconds)
cleanupExpiredSlotsScript
=
redis
.
NewScript
(
`
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
`
)
)
)
type
concurrencyCache
struct
{
type
concurrencyCache
struct
{
rdb
*
redis
.
Client
rdb
*
redis
.
Client
slotTTLSeconds
int
// 槽位过期时间(秒)
slotTTLSeconds
int
// 槽位过期时间(秒)
waitQueueTTLSeconds
int
// 等待队列过期时间(秒)
}
}
// NewConcurrencyCache 创建并发控制缓存
// NewConcurrencyCache 创建并发控制缓存
// slotTTLMinutes: 槽位过期时间(分钟),0 或负数使用默认值 15 分钟
// slotTTLMinutes: 槽位过期时间(分钟),0 或负数使用默认值 15 分钟
// waitQueueTTLSeconds: 等待队列过期时间(秒),0 或负数使用 slot TTL
func
NewConcurrencyCache
(
rdb
*
redis
.
Client
,
slotTTLMinutes
int
)
service
.
ConcurrencyCache
{
func
NewConcurrencyCache
(
rdb
*
redis
.
Client
,
slotTTLMinutes
int
,
waitQueueTTLSeconds
int
)
service
.
ConcurrencyCache
{
if
slotTTLMinutes
<=
0
{
if
slotTTLMinutes
<=
0
{
slotTTLMinutes
=
defaultSlotTTLMinutes
slotTTLMinutes
=
defaultSlotTTLMinutes
}
}
if
waitQueueTTLSeconds
<=
0
{
waitQueueTTLSeconds
=
slotTTLMinutes
*
60
}
return
&
concurrencyCache
{
return
&
concurrencyCache
{
rdb
:
rdb
,
rdb
:
rdb
,
slotTTLSeconds
:
slotTTLMinutes
*
60
,
slotTTLSeconds
:
slotTTLMinutes
*
60
,
waitQueueTTLSeconds
:
waitQueueTTLSeconds
,
}
}
}
}
...
@@ -238,10 +155,6 @@ func waitQueueKey(userID int64) string {
...
@@ -238,10 +155,6 @@ func waitQueueKey(userID int64) string {
return
fmt
.
Sprintf
(
"%s%d"
,
waitQueueKeyPrefix
,
userID
)
return
fmt
.
Sprintf
(
"%s%d"
,
waitQueueKeyPrefix
,
userID
)
}
}
func
accountWaitKey
(
accountID
int64
)
string
{
return
fmt
.
Sprintf
(
"%s%d"
,
accountWaitKeyPrefix
,
accountID
)
}
// Account slot operations
// Account slot operations
func
(
c
*
concurrencyCache
)
AcquireAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
{
func
(
c
*
concurrencyCache
)
AcquireAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
{
...
@@ -312,75 +225,3 @@ func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64)
...
@@ -312,75 +225,3 @@ func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64)
_
,
err
:=
decrementWaitScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
})
.
Result
()
_
,
err
:=
decrementWaitScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
})
.
Result
()
return
err
return
err
}
}
// Account wait queue operations
func
(
c
*
concurrencyCache
)
IncrementAccountWaitCount
(
ctx
context
.
Context
,
accountID
int64
,
maxWait
int
)
(
bool
,
error
)
{
key
:=
accountWaitKey
(
accountID
)
result
,
err
:=
incrementAccountWaitScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
maxWait
,
c
.
waitQueueTTLSeconds
)
.
Int
()
if
err
!=
nil
{
return
false
,
err
}
return
result
==
1
,
nil
}
func
(
c
*
concurrencyCache
)
DecrementAccountWaitCount
(
ctx
context
.
Context
,
accountID
int64
)
error
{
key
:=
accountWaitKey
(
accountID
)
_
,
err
:=
decrementWaitScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
})
.
Result
()
return
err
}
func
(
c
*
concurrencyCache
)
GetAccountWaitingCount
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
{
key
:=
accountWaitKey
(
accountID
)
val
,
err
:=
c
.
rdb
.
Get
(
ctx
,
key
)
.
Int
()
if
err
!=
nil
&&
!
errors
.
Is
(
err
,
redis
.
Nil
)
{
return
0
,
err
}
if
errors
.
Is
(
err
,
redis
.
Nil
)
{
return
0
,
nil
}
return
val
,
nil
}
func
(
c
*
concurrencyCache
)
GetAccountsLoadBatch
(
ctx
context
.
Context
,
accounts
[]
service
.
AccountWithConcurrency
)
(
map
[
int64
]
*
service
.
AccountLoadInfo
,
error
)
{
if
len
(
accounts
)
==
0
{
return
map
[
int64
]
*
service
.
AccountLoadInfo
{},
nil
}
args
:=
[]
any
{
c
.
slotTTLSeconds
}
for
_
,
acc
:=
range
accounts
{
args
=
append
(
args
,
acc
.
ID
,
acc
.
MaxConcurrency
)
}
result
,
err
:=
getAccountsLoadBatchScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{},
args
...
)
.
Slice
()
if
err
!=
nil
{
return
nil
,
err
}
loadMap
:=
make
(
map
[
int64
]
*
service
.
AccountLoadInfo
)
for
i
:=
0
;
i
<
len
(
result
);
i
+=
4
{
if
i
+
3
>=
len
(
result
)
{
break
}
accountID
,
_
:=
strconv
.
ParseInt
(
fmt
.
Sprintf
(
"%v"
,
result
[
i
]),
10
,
64
)
currentConcurrency
,
_
:=
strconv
.
Atoi
(
fmt
.
Sprintf
(
"%v"
,
result
[
i
+
1
]))
waitingCount
,
_
:=
strconv
.
Atoi
(
fmt
.
Sprintf
(
"%v"
,
result
[
i
+
2
]))
loadRate
,
_
:=
strconv
.
Atoi
(
fmt
.
Sprintf
(
"%v"
,
result
[
i
+
3
]))
loadMap
[
accountID
]
=
&
service
.
AccountLoadInfo
{
AccountID
:
accountID
,
CurrentConcurrency
:
currentConcurrency
,
WaitingCount
:
waitingCount
,
LoadRate
:
loadRate
,
}
}
return
loadMap
,
nil
}
func
(
c
*
concurrencyCache
)
CleanupExpiredAccountSlots
(
ctx
context
.
Context
,
accountID
int64
)
error
{
key
:=
accountSlotKey
(
accountID
)
_
,
err
:=
cleanupExpiredSlotsScript
.
Run
(
ctx
,
c
.
rdb
,
[]
string
{
key
},
c
.
slotTTLSeconds
)
.
Result
()
return
err
}
backend/internal/repository/concurrency_cache_benchmark_test.go
View file @
c5c12d4c
...
@@ -22,7 +22,7 @@ func BenchmarkAccountConcurrency(b *testing.B) {
...
@@ -22,7 +22,7 @@ func BenchmarkAccountConcurrency(b *testing.B) {
_
=
rdb
.
Close
()
_
=
rdb
.
Close
()
}()
}()
cache
,
_
:=
NewConcurrencyCache
(
rdb
,
benchSlotTTLMinutes
,
int
(
benchSlotTTL
.
Seconds
())
)
.
(
*
concurrencyCache
)
cache
,
_
:=
NewConcurrencyCache
(
rdb
,
benchSlotTTLMinutes
)
.
(
*
concurrencyCache
)
ctx
:=
context
.
Background
()
ctx
:=
context
.
Background
()
for
_
,
size
:=
range
[]
int
{
10
,
100
,
1000
}
{
for
_
,
size
:=
range
[]
int
{
10
,
100
,
1000
}
{
...
...
backend/internal/repository/concurrency_cache_integration_test.go
View file @
c5c12d4c
...
@@ -27,7 +27,7 @@ type ConcurrencyCacheSuite struct {
...
@@ -27,7 +27,7 @@ type ConcurrencyCacheSuite struct {
func
(
s
*
ConcurrencyCacheSuite
)
SetupTest
()
{
func
(
s
*
ConcurrencyCacheSuite
)
SetupTest
()
{
s
.
IntegrationRedisSuite
.
SetupTest
()
s
.
IntegrationRedisSuite
.
SetupTest
()
s
.
cache
=
NewConcurrencyCache
(
s
.
rdb
,
testSlotTTLMinutes
,
int
(
testSlotTTL
.
Seconds
())
)
s
.
cache
=
NewConcurrencyCache
(
s
.
rdb
,
testSlotTTLMinutes
)
}
}
func
(
s
*
ConcurrencyCacheSuite
)
TestAccountSlot_AcquireAndRelease
()
{
func
(
s
*
ConcurrencyCacheSuite
)
TestAccountSlot_AcquireAndRelease
()
{
...
@@ -218,48 +218,6 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
...
@@ -218,48 +218,6 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
require
.
GreaterOrEqual
(
s
.
T
(),
val
,
0
,
"expected non-negative wait count"
)
require
.
GreaterOrEqual
(
s
.
T
(),
val
,
0
,
"expected non-negative wait count"
)
}
}
func
(
s
*
ConcurrencyCacheSuite
)
TestAccountWaitQueue_IncrementAndDecrement
()
{
accountID
:=
int64
(
30
)
waitKey
:=
fmt
.
Sprintf
(
"%s%d"
,
accountWaitKeyPrefix
,
accountID
)
ok
,
err
:=
s
.
cache
.
IncrementAccountWaitCount
(
s
.
ctx
,
accountID
,
2
)
require
.
NoError
(
s
.
T
(),
err
,
"IncrementAccountWaitCount 1"
)
require
.
True
(
s
.
T
(),
ok
)
ok
,
err
=
s
.
cache
.
IncrementAccountWaitCount
(
s
.
ctx
,
accountID
,
2
)
require
.
NoError
(
s
.
T
(),
err
,
"IncrementAccountWaitCount 2"
)
require
.
True
(
s
.
T
(),
ok
)
ok
,
err
=
s
.
cache
.
IncrementAccountWaitCount
(
s
.
ctx
,
accountID
,
2
)
require
.
NoError
(
s
.
T
(),
err
,
"IncrementAccountWaitCount 3"
)
require
.
False
(
s
.
T
(),
ok
,
"expected account wait increment over max to fail"
)
ttl
,
err
:=
s
.
rdb
.
TTL
(
s
.
ctx
,
waitKey
)
.
Result
()
require
.
NoError
(
s
.
T
(),
err
,
"TTL account waitKey"
)
s
.
AssertTTLWithin
(
ttl
,
1
*
time
.
Second
,
testSlotTTL
)
require
.
NoError
(
s
.
T
(),
s
.
cache
.
DecrementAccountWaitCount
(
s
.
ctx
,
accountID
),
"DecrementAccountWaitCount"
)
val
,
err
:=
s
.
rdb
.
Get
(
s
.
ctx
,
waitKey
)
.
Int
()
if
!
errors
.
Is
(
err
,
redis
.
Nil
)
{
require
.
NoError
(
s
.
T
(),
err
,
"Get waitKey"
)
}
require
.
Equal
(
s
.
T
(),
1
,
val
,
"expected account wait count 1"
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestAccountWaitQueue_DecrementNoNegative
()
{
accountID
:=
int64
(
301
)
waitKey
:=
fmt
.
Sprintf
(
"%s%d"
,
accountWaitKeyPrefix
,
accountID
)
require
.
NoError
(
s
.
T
(),
s
.
cache
.
DecrementAccountWaitCount
(
s
.
ctx
,
accountID
),
"DecrementAccountWaitCount on non-existent key"
)
val
,
err
:=
s
.
rdb
.
Get
(
s
.
ctx
,
waitKey
)
.
Int
()
if
!
errors
.
Is
(
err
,
redis
.
Nil
)
{
require
.
NoError
(
s
.
T
(),
err
,
"Get waitKey"
)
}
require
.
GreaterOrEqual
(
s
.
T
(),
val
,
0
,
"expected non-negative account wait count after decrement on empty"
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestGetAccountConcurrency_Missing
()
{
func
(
s
*
ConcurrencyCacheSuite
)
TestGetAccountConcurrency_Missing
()
{
// When no slots exist, GetAccountConcurrency should return 0
// When no slots exist, GetAccountConcurrency should return 0
cur
,
err
:=
s
.
cache
.
GetAccountConcurrency
(
s
.
ctx
,
999
)
cur
,
err
:=
s
.
cache
.
GetAccountConcurrency
(
s
.
ctx
,
999
)
...
@@ -274,139 +232,6 @@ func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() {
...
@@ -274,139 +232,6 @@ func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() {
require
.
Equal
(
s
.
T
(),
0
,
cur
)
require
.
Equal
(
s
.
T
(),
0
,
cur
)
}
}
func
(
s
*
ConcurrencyCacheSuite
)
TestGetAccountsLoadBatch
()
{
s
.
T
()
.
Skip
(
"TODO: Fix this test - CurrentConcurrency returns 0 instead of expected value in CI"
)
// Setup: Create accounts with different load states
account1
:=
int64
(
100
)
account2
:=
int64
(
101
)
account3
:=
int64
(
102
)
// Account 1: 2/3 slots used, 1 waiting
ok
,
err
:=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
account1
,
3
,
"req1"
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
True
(
s
.
T
(),
ok
)
ok
,
err
=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
account1
,
3
,
"req2"
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
True
(
s
.
T
(),
ok
)
ok
,
err
=
s
.
cache
.
IncrementAccountWaitCount
(
s
.
ctx
,
account1
,
5
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
True
(
s
.
T
(),
ok
)
// Account 2: 1/2 slots used, 0 waiting
ok
,
err
=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
account2
,
2
,
"req3"
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
True
(
s
.
T
(),
ok
)
// Account 3: 0/1 slots used, 0 waiting (idle)
// Query batch load
accounts
:=
[]
service
.
AccountWithConcurrency
{
{
ID
:
account1
,
MaxConcurrency
:
3
},
{
ID
:
account2
,
MaxConcurrency
:
2
},
{
ID
:
account3
,
MaxConcurrency
:
1
},
}
loadMap
,
err
:=
s
.
cache
.
GetAccountsLoadBatch
(
s
.
ctx
,
accounts
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
Len
(
s
.
T
(),
loadMap
,
3
)
// Verify account1: (2 + 1) / 3 = 100%
load1
:=
loadMap
[
account1
]
require
.
NotNil
(
s
.
T
(),
load1
)
require
.
Equal
(
s
.
T
(),
account1
,
load1
.
AccountID
)
require
.
Equal
(
s
.
T
(),
2
,
load1
.
CurrentConcurrency
)
require
.
Equal
(
s
.
T
(),
1
,
load1
.
WaitingCount
)
require
.
Equal
(
s
.
T
(),
100
,
load1
.
LoadRate
)
// Verify account2: (1 + 0) / 2 = 50%
load2
:=
loadMap
[
account2
]
require
.
NotNil
(
s
.
T
(),
load2
)
require
.
Equal
(
s
.
T
(),
account2
,
load2
.
AccountID
)
require
.
Equal
(
s
.
T
(),
1
,
load2
.
CurrentConcurrency
)
require
.
Equal
(
s
.
T
(),
0
,
load2
.
WaitingCount
)
require
.
Equal
(
s
.
T
(),
50
,
load2
.
LoadRate
)
// Verify account3: (0 + 0) / 1 = 0%
load3
:=
loadMap
[
account3
]
require
.
NotNil
(
s
.
T
(),
load3
)
require
.
Equal
(
s
.
T
(),
account3
,
load3
.
AccountID
)
require
.
Equal
(
s
.
T
(),
0
,
load3
.
CurrentConcurrency
)
require
.
Equal
(
s
.
T
(),
0
,
load3
.
WaitingCount
)
require
.
Equal
(
s
.
T
(),
0
,
load3
.
LoadRate
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestGetAccountsLoadBatch_Empty
()
{
// Test with empty account list
loadMap
,
err
:=
s
.
cache
.
GetAccountsLoadBatch
(
s
.
ctx
,
[]
service
.
AccountWithConcurrency
{})
require
.
NoError
(
s
.
T
(),
err
)
require
.
Empty
(
s
.
T
(),
loadMap
)
}
func
(
s
*
ConcurrencyCacheSuite
)
TestCleanupExpiredAccountSlots
()
{
accountID
:=
int64
(
200
)
slotKey
:=
fmt
.
Sprintf
(
"%s%d"
,
accountSlotKeyPrefix
,
accountID
)
// Acquire 3 slots
ok
,
err
:=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
accountID
,
5
,
"req1"
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
True
(
s
.
T
(),
ok
)
ok
,
err
=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
accountID
,
5
,
"req2"
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
True
(
s
.
T
(),
ok
)
ok
,
err
=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
accountID
,
5
,
"req3"
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
True
(
s
.
T
(),
ok
)
// Verify 3 slots exist
cur
,
err
:=
s
.
cache
.
GetAccountConcurrency
(
s
.
ctx
,
accountID
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
Equal
(
s
.
T
(),
3
,
cur
)
// Manually set old timestamps for req1 and req2 (simulate expired slots)
now
:=
time
.
Now
()
.
Unix
()
expiredTime
:=
now
-
int64
(
testSlotTTL
.
Seconds
())
-
10
// 10 seconds past TTL
err
=
s
.
rdb
.
ZAdd
(
s
.
ctx
,
slotKey
,
redis
.
Z
{
Score
:
float64
(
expiredTime
),
Member
:
"req1"
})
.
Err
()
require
.
NoError
(
s
.
T
(),
err
)
err
=
s
.
rdb
.
ZAdd
(
s
.
ctx
,
slotKey
,
redis
.
Z
{
Score
:
float64
(
expiredTime
),
Member
:
"req2"
})
.
Err
()
require
.
NoError
(
s
.
T
(),
err
)
// Run cleanup
err
=
s
.
cache
.
CleanupExpiredAccountSlots
(
s
.
ctx
,
accountID
)
require
.
NoError
(
s
.
T
(),
err
)
// Verify only 1 slot remains (req3)
cur
,
err
=
s
.
cache
.
GetAccountConcurrency
(
s
.
ctx
,
accountID
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
Equal
(
s
.
T
(),
1
,
cur
)
// Verify req3 still exists
members
,
err
:=
s
.
rdb
.
ZRange
(
s
.
ctx
,
slotKey
,
0
,
-
1
)
.
Result
()
require
.
NoError
(
s
.
T
(),
err
)
require
.
Len
(
s
.
T
(),
members
,
1
)
require
.
Equal
(
s
.
T
(),
"req3"
,
members
[
0
])
}
func
(
s
*
ConcurrencyCacheSuite
)
TestCleanupExpiredAccountSlots_NoExpired
()
{
accountID
:=
int64
(
201
)
// Acquire 2 fresh slots
ok
,
err
:=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
accountID
,
5
,
"req1"
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
True
(
s
.
T
(),
ok
)
ok
,
err
=
s
.
cache
.
AcquireAccountSlot
(
s
.
ctx
,
accountID
,
5
,
"req2"
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
True
(
s
.
T
(),
ok
)
// Run cleanup (should not remove anything)
err
=
s
.
cache
.
CleanupExpiredAccountSlots
(
s
.
ctx
,
accountID
)
require
.
NoError
(
s
.
T
(),
err
)
// Verify both slots still exist
cur
,
err
:=
s
.
cache
.
GetAccountConcurrency
(
s
.
ctx
,
accountID
)
require
.
NoError
(
s
.
T
(),
err
)
require
.
Equal
(
s
.
T
(),
2
,
cur
)
}
func
TestConcurrencyCacheSuite
(
t
*
testing
.
T
)
{
func
TestConcurrencyCacheSuite
(
t
*
testing
.
T
)
{
suite
.
Run
(
t
,
new
(
ConcurrencyCacheSuite
))
suite
.
Run
(
t
,
new
(
ConcurrencyCacheSuite
))
}
}
backend/internal/repository/wire.go
View file @
c5c12d4c
...
@@ -15,14 +15,7 @@ import (
...
@@ -15,14 +15,7 @@ import (
// ProvideConcurrencyCache 创建并发控制缓存,从配置读取 TTL 参数
// ProvideConcurrencyCache 创建并发控制缓存,从配置读取 TTL 参数
// 性能优化:TTL 可配置,支持长时间运行的 LLM 请求场景
// 性能优化:TTL 可配置,支持长时间运行的 LLM 请求场景
func
ProvideConcurrencyCache
(
rdb
*
redis
.
Client
,
cfg
*
config
.
Config
)
service
.
ConcurrencyCache
{
func
ProvideConcurrencyCache
(
rdb
*
redis
.
Client
,
cfg
*
config
.
Config
)
service
.
ConcurrencyCache
{
waitTTLSeconds
:=
int
(
cfg
.
Gateway
.
Scheduling
.
StickySessionWaitTimeout
.
Seconds
())
return
NewConcurrencyCache
(
rdb
,
cfg
.
Gateway
.
ConcurrencySlotTTLMinutes
)
if
cfg
.
Gateway
.
Scheduling
.
FallbackWaitTimeout
>
cfg
.
Gateway
.
Scheduling
.
StickySessionWaitTimeout
{
waitTTLSeconds
=
int
(
cfg
.
Gateway
.
Scheduling
.
FallbackWaitTimeout
.
Seconds
())
}
if
waitTTLSeconds
<=
0
{
waitTTLSeconds
=
cfg
.
Gateway
.
ConcurrencySlotTTLMinutes
*
60
}
return
NewConcurrencyCache
(
rdb
,
cfg
.
Gateway
.
ConcurrencySlotTTLMinutes
,
waitTTLSeconds
)
}
}
// ProviderSet is the Wire provider set for all repositories
// ProviderSet is the Wire provider set for all repositories
...
...
backend/internal/service/antigravity_gateway_service.go
View file @
c5c12d4c
...
@@ -358,15 +358,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
...
@@ -358,15 +358,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return
nil
,
fmt
.
Errorf
(
"transform request: %w"
,
err
)
return
nil
,
fmt
.
Errorf
(
"transform request: %w"
,
err
)
}
}
// 调试:记录转换后的请求体(仅记录前 2000 字符)
if
bodyJSON
,
err
:=
json
.
Marshal
(
geminiBody
);
err
==
nil
{
truncated
:=
string
(
bodyJSON
)
if
len
(
truncated
)
>
2000
{
truncated
=
truncated
[
:
2000
]
+
"..."
}
log
.
Printf
(
"[Debug] Transformed Gemini request: %s"
,
truncated
)
}
// 构建上游 action
// 构建上游 action
action
:=
"generateContent"
action
:=
"generateContent"
if
claudeReq
.
Stream
{
if
claudeReq
.
Stream
{
...
...
backend/internal/service/concurrency_service.go
View file @
c5c12d4c
...
@@ -18,11 +18,6 @@ type ConcurrencyCache interface {
...
@@ -18,11 +18,6 @@ type ConcurrencyCache interface {
ReleaseAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
requestID
string
)
error
ReleaseAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
requestID
string
)
error
GetAccountConcurrency
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
GetAccountConcurrency
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
// 账号等待队列(账号级)
IncrementAccountWaitCount
(
ctx
context
.
Context
,
accountID
int64
,
maxWait
int
)
(
bool
,
error
)
DecrementAccountWaitCount
(
ctx
context
.
Context
,
accountID
int64
)
error
GetAccountWaitingCount
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
// 用户槽位管理
// 用户槽位管理
// 键格式: concurrency:user:{userID}(有序集合,成员为 requestID)
// 键格式: concurrency:user:{userID}(有序集合,成员为 requestID)
AcquireUserSlot
(
ctx
context
.
Context
,
userID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
AcquireUserSlot
(
ctx
context
.
Context
,
userID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
...
@@ -32,12 +27,6 @@ type ConcurrencyCache interface {
...
@@ -32,12 +27,6 @@ type ConcurrencyCache interface {
// 等待队列计数(只在首次创建时设置 TTL)
// 等待队列计数(只在首次创建时设置 TTL)
IncrementWaitCount
(
ctx
context
.
Context
,
userID
int64
,
maxWait
int
)
(
bool
,
error
)
IncrementWaitCount
(
ctx
context
.
Context
,
userID
int64
,
maxWait
int
)
(
bool
,
error
)
DecrementWaitCount
(
ctx
context
.
Context
,
userID
int64
)
error
DecrementWaitCount
(
ctx
context
.
Context
,
userID
int64
)
error
// 批量负载查询(只读)
GetAccountsLoadBatch
(
ctx
context
.
Context
,
accounts
[]
AccountWithConcurrency
)
(
map
[
int64
]
*
AccountLoadInfo
,
error
)
// 清理过期槽位(后台任务)
CleanupExpiredAccountSlots
(
ctx
context
.
Context
,
accountID
int64
)
error
}
}
// generateRequestID generates a unique request ID for concurrency slot tracking
// generateRequestID generates a unique request ID for concurrency slot tracking
...
@@ -72,18 +61,6 @@ type AcquireResult struct {
...
@@ -72,18 +61,6 @@ type AcquireResult struct {
ReleaseFunc
func
()
// Must be called when done (typically via defer)
ReleaseFunc
func
()
// Must be called when done (typically via defer)
}
}
type
AccountWithConcurrency
struct
{
ID
int64
MaxConcurrency
int
}
type
AccountLoadInfo
struct
{
AccountID
int64
CurrentConcurrency
int
WaitingCount
int
LoadRate
int
// 0-100+ (percent)
}
// AcquireAccountSlot attempts to acquire a concurrency slot for an account.
// AcquireAccountSlot attempts to acquire a concurrency slot for an account.
// If the account is at max concurrency, it waits until a slot is available or timeout.
// If the account is at max concurrency, it waits until a slot is available or timeout.
// Returns a release function that MUST be called when the request completes.
// Returns a release function that MUST be called when the request completes.
...
@@ -200,42 +177,6 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6
...
@@ -200,42 +177,6 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6
}
}
}
}
// IncrementAccountWaitCount increments the wait queue counter for an account.
func
(
s
*
ConcurrencyService
)
IncrementAccountWaitCount
(
ctx
context
.
Context
,
accountID
int64
,
maxWait
int
)
(
bool
,
error
)
{
if
s
.
cache
==
nil
{
return
true
,
nil
}
result
,
err
:=
s
.
cache
.
IncrementAccountWaitCount
(
ctx
,
accountID
,
maxWait
)
if
err
!=
nil
{
log
.
Printf
(
"Warning: increment wait count failed for account %d: %v"
,
accountID
,
err
)
return
true
,
nil
}
return
result
,
nil
}
// DecrementAccountWaitCount decrements the wait queue counter for an account.
func
(
s
*
ConcurrencyService
)
DecrementAccountWaitCount
(
ctx
context
.
Context
,
accountID
int64
)
{
if
s
.
cache
==
nil
{
return
}
bgCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
5
*
time
.
Second
)
defer
cancel
()
if
err
:=
s
.
cache
.
DecrementAccountWaitCount
(
bgCtx
,
accountID
);
err
!=
nil
{
log
.
Printf
(
"Warning: decrement wait count failed for account %d: %v"
,
accountID
,
err
)
}
}
// GetAccountWaitingCount gets current wait queue count for an account.
func
(
s
*
ConcurrencyService
)
GetAccountWaitingCount
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
{
if
s
.
cache
==
nil
{
return
0
,
nil
}
return
s
.
cache
.
GetAccountWaitingCount
(
ctx
,
accountID
)
}
// CalculateMaxWait calculates the maximum wait queue size for a user
// CalculateMaxWait calculates the maximum wait queue size for a user
// maxWait = userConcurrency + defaultExtraWaitSlots
// maxWait = userConcurrency + defaultExtraWaitSlots
func
CalculateMaxWait
(
userConcurrency
int
)
int
{
func
CalculateMaxWait
(
userConcurrency
int
)
int
{
...
@@ -245,57 +186,6 @@ func CalculateMaxWait(userConcurrency int) int {
...
@@ -245,57 +186,6 @@ func CalculateMaxWait(userConcurrency int) int {
return
userConcurrency
+
defaultExtraWaitSlots
return
userConcurrency
+
defaultExtraWaitSlots
}
}
// GetAccountsLoadBatch returns load info for multiple accounts.
func
(
s
*
ConcurrencyService
)
GetAccountsLoadBatch
(
ctx
context
.
Context
,
accounts
[]
AccountWithConcurrency
)
(
map
[
int64
]
*
AccountLoadInfo
,
error
)
{
if
s
.
cache
==
nil
{
return
map
[
int64
]
*
AccountLoadInfo
{},
nil
}
return
s
.
cache
.
GetAccountsLoadBatch
(
ctx
,
accounts
)
}
// CleanupExpiredAccountSlots removes expired slots for one account (background task).
func
(
s
*
ConcurrencyService
)
CleanupExpiredAccountSlots
(
ctx
context
.
Context
,
accountID
int64
)
error
{
if
s
.
cache
==
nil
{
return
nil
}
return
s
.
cache
.
CleanupExpiredAccountSlots
(
ctx
,
accountID
)
}
// StartSlotCleanupWorker starts a background cleanup worker for expired account slots.
func
(
s
*
ConcurrencyService
)
StartSlotCleanupWorker
(
accountRepo
AccountRepository
,
interval
time
.
Duration
)
{
if
s
==
nil
||
s
.
cache
==
nil
||
accountRepo
==
nil
||
interval
<=
0
{
return
}
runCleanup
:=
func
()
{
listCtx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
5
*
time
.
Second
)
accounts
,
err
:=
accountRepo
.
ListSchedulable
(
listCtx
)
cancel
()
if
err
!=
nil
{
log
.
Printf
(
"Warning: list schedulable accounts failed: %v"
,
err
)
return
}
for
_
,
account
:=
range
accounts
{
accountCtx
,
accountCancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
err
:=
s
.
cache
.
CleanupExpiredAccountSlots
(
accountCtx
,
account
.
ID
)
accountCancel
()
if
err
!=
nil
{
log
.
Printf
(
"Warning: cleanup expired slots failed for account %d: %v"
,
account
.
ID
,
err
)
}
}
}
go
func
()
{
ticker
:=
time
.
NewTicker
(
interval
)
defer
ticker
.
Stop
()
runCleanup
()
for
range
ticker
.
C
{
runCleanup
()
}
}()
}
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
// Returns a map of accountID -> current concurrency count
// Returns a map of accountID -> current concurrency count
func
(
s
*
ConcurrencyService
)
GetAccountConcurrencyBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
int
,
error
)
{
func
(
s
*
ConcurrencyService
)
GetAccountConcurrencyBatch
(
ctx
context
.
Context
,
accountIDs
[]
int64
)
(
map
[
int64
]
int
,
error
)
{
...
...
backend/internal/service/gateway_multiplatform_test.go
View file @
c5c12d4c
...
@@ -261,34 +261,6 @@ func TestGatewayService_SelectAccountForModelWithPlatform_PriorityAndLastUsed(t
...
@@ -261,34 +261,6 @@ func TestGatewayService_SelectAccountForModelWithPlatform_PriorityAndLastUsed(t
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"同优先级应选择最久未用的账户"
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"同优先级应选择最久未用的账户"
)
}
}
func
TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Type
:
AccountTypeApiKey
},
{
ID
:
2
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Type
:
AccountTypeOAuth
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccountForModelWithPlatform
(
ctx
,
nil
,
""
,
"gemini-2.5-pro"
,
nil
,
PlatformGemini
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"同优先级且未使用时应优先选择OAuth账户"
)
}
// TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts 测试无可用账户
// TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts 测试无可用账户
func
TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts
(
t
*
testing
.
T
)
{
func
TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
ctx
:=
context
.
Background
()
...
@@ -604,32 +576,6 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
...
@@ -604,32 +576,6 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
func
TestGatewayService_selectAccountWithMixedScheduling
(
t
*
testing
.
T
)
{
func
TestGatewayService_selectAccountWithMixedScheduling
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
ctx
:=
context
.
Background
()
t
.
Run
(
"混合调度-Gemini优先选择OAuth账户"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Type
:
AccountTypeApiKey
},
{
ID
:
2
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Type
:
AccountTypeOAuth
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccountWithMixedScheduling
(
ctx
,
nil
,
""
,
"gemini-2.5-pro"
,
nil
,
PlatformGemini
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"同优先级且未使用时应优先选择OAuth账户"
)
})
t
.
Run
(
"混合调度-包含启用mixed_scheduling的antigravity账户"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"混合调度-包含启用mixed_scheduling的antigravity账户"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
accounts
:
[]
Account
{
...
@@ -837,160 +783,3 @@ func TestAccount_IsMixedSchedulingEnabled(t *testing.T) {
...
@@ -837,160 +783,3 @@ func TestAccount_IsMixedSchedulingEnabled(t *testing.T) {
})
})
}
}
}
}
// mockConcurrencyService for testing
type
mockConcurrencyService
struct
{
accountLoads
map
[
int64
]
*
AccountLoadInfo
accountWaitCounts
map
[
int64
]
int
acquireResults
map
[
int64
]
bool
}
func
(
m
*
mockConcurrencyService
)
GetAccountsLoadBatch
(
ctx
context
.
Context
,
accounts
[]
AccountWithConcurrency
)
(
map
[
int64
]
*
AccountLoadInfo
,
error
)
{
if
m
.
accountLoads
==
nil
{
return
map
[
int64
]
*
AccountLoadInfo
{},
nil
}
result
:=
make
(
map
[
int64
]
*
AccountLoadInfo
)
for
_
,
acc
:=
range
accounts
{
if
load
,
ok
:=
m
.
accountLoads
[
acc
.
ID
];
ok
{
result
[
acc
.
ID
]
=
load
}
else
{
result
[
acc
.
ID
]
=
&
AccountLoadInfo
{
AccountID
:
acc
.
ID
,
CurrentConcurrency
:
0
,
WaitingCount
:
0
,
LoadRate
:
0
,
}
}
}
return
result
,
nil
}
func
(
m
*
mockConcurrencyService
)
GetAccountWaitingCount
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
{
if
m
.
accountWaitCounts
==
nil
{
return
0
,
nil
}
return
m
.
accountWaitCounts
[
accountID
],
nil
}
// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection
func
TestGatewayService_SelectAccountWithLoadAwareness
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
t
.
Run
(
"禁用负载批量查询-降级到传统选择"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
false
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
nil
,
// No concurrency service
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
1
),
result
.
Account
.
ID
,
"应选择优先级最高的账号"
)
})
t
.
Run
(
"无ConcurrencyService-降级到传统选择"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
true
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
nil
,
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Account
.
ID
,
"应选择优先级最高的账号"
)
})
t
.
Run
(
"排除账号-不选择被排除的账号"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
false
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
nil
,
}
excludedIDs
:=
map
[
int64
]
struct
{}{
1
:
{}}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
excludedIDs
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Account
.
ID
,
"不应选择被排除的账号"
)
})
t
.
Run
(
"无可用账号-返回错误"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{},
accountsByID
:
map
[
int64
]
*
Account
{},
}
cache
:=
&
mockGatewayCacheForPlatform
{}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
false
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
nil
,
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
result
)
require
.
Contains
(
t
,
err
.
Error
(),
"no available accounts"
)
})
}
backend/internal/service/gateway_service.go
View file @
c5c12d4c
This diff is collapsed.
Click to expand it.
backend/internal/service/gemini_messages_compat_service.go
View file @
c5c12d4c
...
@@ -2278,13 +2278,11 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
...
@@ -2278,13 +2278,11 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
"properties"
:
map
[
string
]
any
{},
"properties"
:
map
[
string
]
any
{},
}
}
}
}
// 清理 JSON Schema
cleanedParams
:=
cleanToolSchema
(
params
)
funcDecls
=
append
(
funcDecls
,
map
[
string
]
any
{
funcDecls
=
append
(
funcDecls
,
map
[
string
]
any
{
"name"
:
name
,
"name"
:
name
,
"description"
:
desc
,
"description"
:
desc
,
"parameters"
:
cleanedP
arams
,
"parameters"
:
p
arams
,
})
})
}
}
...
@@ -2298,41 +2296,6 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
...
@@ -2298,41 +2296,6 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
}
}
}
}
// cleanToolSchema 清理工具的 JSON Schema,移除 Gemini 不支持的字段
func
cleanToolSchema
(
schema
any
)
any
{
if
schema
==
nil
{
return
nil
}
switch
v
:=
schema
.
(
type
)
{
case
map
[
string
]
any
:
cleaned
:=
make
(
map
[
string
]
any
)
for
key
,
value
:=
range
v
{
// 跳过不支持的字段
if
key
==
"$schema"
||
key
==
"$id"
||
key
==
"$ref"
||
key
==
"additionalProperties"
||
key
==
"minLength"
||
key
==
"maxLength"
||
key
==
"minItems"
||
key
==
"maxItems"
{
continue
}
// 递归清理嵌套对象
cleaned
[
key
]
=
cleanToolSchema
(
value
)
}
// 规范化 type 字段为大写
if
typeVal
,
ok
:=
cleaned
[
"type"
]
.
(
string
);
ok
{
cleaned
[
"type"
]
=
strings
.
ToUpper
(
typeVal
)
}
return
cleaned
case
[]
any
:
cleaned
:=
make
([]
any
,
len
(
v
))
for
i
,
item
:=
range
v
{
cleaned
[
i
]
=
cleanToolSchema
(
item
)
}
return
cleaned
default
:
return
v
}
}
func
convertClaudeGenerationConfig
(
req
map
[
string
]
any
)
map
[
string
]
any
{
func
convertClaudeGenerationConfig
(
req
map
[
string
]
any
)
map
[
string
]
any
{
out
:=
make
(
map
[
string
]
any
)
out
:=
make
(
map
[
string
]
any
)
if
mt
,
ok
:=
asInt
(
req
[
"max_tokens"
]);
ok
&&
mt
>
0
{
if
mt
,
ok
:=
asInt
(
req
[
"max_tokens"
]);
ok
&&
mt
>
0
{
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment