Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
2fe8932c
Unverified
Commit
2fe8932c
authored
Feb 03, 2026
by
Call White
Committed by
GitHub
Feb 03, 2026
Browse files
Merge pull request #3 from cyhhao/main
merge to main
parents
2f2e76f9
adb77af1
Changes
267
Show whitespace changes
Inline
Side-by-side
backend/internal/service/auth_service.go
View file @
2fe8932c
...
...
@@ -153,8 +153,8 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return
""
,
nil
,
ErrServiceUnavailable
}
// 应用优惠码(如果提供)
if
promoCode
!=
""
&&
s
.
promoService
!=
nil
{
// 应用优惠码(如果提供
且功能已启用
)
if
promoCode
!=
""
&&
s
.
promoService
!=
nil
&&
s
.
settingService
!=
nil
&&
s
.
settingService
.
IsPromoCodeEnabled
(
ctx
)
{
if
err
:=
s
.
promoService
.
ApplyPromoCode
(
ctx
,
user
.
ID
,
promoCode
);
err
!=
nil
{
// 优惠码应用失败不影响注册,只记录日志
log
.
Printf
(
"[Auth] Failed to apply promo code for user %d: %v"
,
user
.
ID
,
err
)
...
...
@@ -580,3 +580,149 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
// 生成新token
return
s
.
GenerateToken
(
user
)
}
// IsPasswordResetEnabled 检查是否启用密码重置功能
// 要求:必须同时开启邮件验证且 SMTP 配置正确
func
(
s
*
AuthService
)
IsPasswordResetEnabled
(
ctx
context
.
Context
)
bool
{
if
s
.
settingService
==
nil
{
return
false
}
// Must have email verification enabled and SMTP configured
if
!
s
.
settingService
.
IsEmailVerifyEnabled
(
ctx
)
{
return
false
}
return
s
.
settingService
.
IsPasswordResetEnabled
(
ctx
)
}
// preparePasswordReset validates the password reset request and returns necessary data
// Returns (siteName, resetURL, shouldProceed)
// shouldProceed is false when we should silently return success (to prevent enumeration)
func
(
s
*
AuthService
)
preparePasswordReset
(
ctx
context
.
Context
,
email
,
frontendBaseURL
string
)
(
string
,
string
,
bool
)
{
// Check if user exists (but don't reveal this to the caller)
user
,
err
:=
s
.
userRepo
.
GetByEmail
(
ctx
,
email
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrUserNotFound
)
{
// Security: Log but don't reveal that user doesn't exist
log
.
Printf
(
"[Auth] Password reset requested for non-existent email: %s"
,
email
)
return
""
,
""
,
false
}
log
.
Printf
(
"[Auth] Database error checking email for password reset: %v"
,
err
)
return
""
,
""
,
false
}
// Check if user is active
if
!
user
.
IsActive
()
{
log
.
Printf
(
"[Auth] Password reset requested for inactive user: %s"
,
email
)
return
""
,
""
,
false
}
// Get site name
siteName
:=
"Sub2API"
if
s
.
settingService
!=
nil
{
siteName
=
s
.
settingService
.
GetSiteName
(
ctx
)
}
// Build reset URL base
resetURL
:=
fmt
.
Sprintf
(
"%s/reset-password"
,
strings
.
TrimSuffix
(
frontendBaseURL
,
"/"
))
return
siteName
,
resetURL
,
true
}
// RequestPasswordReset 请求密码重置(同步发送)
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
func
(
s
*
AuthService
)
RequestPasswordReset
(
ctx
context
.
Context
,
email
,
frontendBaseURL
string
)
error
{
if
!
s
.
IsPasswordResetEnabled
(
ctx
)
{
return
infraerrors
.
Forbidden
(
"PASSWORD_RESET_DISABLED"
,
"password reset is not enabled"
)
}
if
s
.
emailService
==
nil
{
return
ErrServiceUnavailable
}
siteName
,
resetURL
,
shouldProceed
:=
s
.
preparePasswordReset
(
ctx
,
email
,
frontendBaseURL
)
if
!
shouldProceed
{
return
nil
// Silent success to prevent enumeration
}
if
err
:=
s
.
emailService
.
SendPasswordResetEmail
(
ctx
,
email
,
siteName
,
resetURL
);
err
!=
nil
{
log
.
Printf
(
"[Auth] Failed to send password reset email to %s: %v"
,
email
,
err
)
return
nil
// Silent success to prevent enumeration
}
log
.
Printf
(
"[Auth] Password reset email sent to: %s"
,
email
)
return
nil
}
// RequestPasswordResetAsync 异步请求密码重置(队列发送)
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
func
(
s
*
AuthService
)
RequestPasswordResetAsync
(
ctx
context
.
Context
,
email
,
frontendBaseURL
string
)
error
{
if
!
s
.
IsPasswordResetEnabled
(
ctx
)
{
return
infraerrors
.
Forbidden
(
"PASSWORD_RESET_DISABLED"
,
"password reset is not enabled"
)
}
if
s
.
emailQueueService
==
nil
{
return
ErrServiceUnavailable
}
siteName
,
resetURL
,
shouldProceed
:=
s
.
preparePasswordReset
(
ctx
,
email
,
frontendBaseURL
)
if
!
shouldProceed
{
return
nil
// Silent success to prevent enumeration
}
if
err
:=
s
.
emailQueueService
.
EnqueuePasswordReset
(
email
,
siteName
,
resetURL
);
err
!=
nil
{
log
.
Printf
(
"[Auth] Failed to enqueue password reset email for %s: %v"
,
email
,
err
)
return
nil
// Silent success to prevent enumeration
}
log
.
Printf
(
"[Auth] Password reset email enqueued for: %s"
,
email
)
return
nil
}
// ResetPassword 重置密码
// Security: Increments TokenVersion to invalidate all existing JWT tokens
func
(
s
*
AuthService
)
ResetPassword
(
ctx
context
.
Context
,
email
,
token
,
newPassword
string
)
error
{
// Check if password reset is enabled
if
!
s
.
IsPasswordResetEnabled
(
ctx
)
{
return
infraerrors
.
Forbidden
(
"PASSWORD_RESET_DISABLED"
,
"password reset is not enabled"
)
}
if
s
.
emailService
==
nil
{
return
ErrServiceUnavailable
}
// Verify and consume the reset token (one-time use)
if
err
:=
s
.
emailService
.
ConsumePasswordResetToken
(
ctx
,
email
,
token
);
err
!=
nil
{
return
err
}
// Get user
user
,
err
:=
s
.
userRepo
.
GetByEmail
(
ctx
,
email
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrUserNotFound
)
{
return
ErrInvalidResetToken
// Token was valid but user was deleted
}
log
.
Printf
(
"[Auth] Database error getting user for password reset: %v"
,
err
)
return
ErrServiceUnavailable
}
// Check if user is active
if
!
user
.
IsActive
()
{
return
ErrUserNotActive
}
// Hash new password
hashedPassword
,
err
:=
s
.
HashPassword
(
newPassword
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"hash password: %w"
,
err
)
}
// Update password and increment TokenVersion
user
.
PasswordHash
=
hashedPassword
user
.
TokenVersion
++
// Invalidate all existing tokens
if
err
:=
s
.
userRepo
.
Update
(
ctx
,
user
);
err
!=
nil
{
log
.
Printf
(
"[Auth] Database error updating password for user %d: %v"
,
user
.
ID
,
err
)
return
ErrServiceUnavailable
}
log
.
Printf
(
"[Auth] Password reset successful for user: %s"
,
email
)
return
nil
}
backend/internal/service/auth_service_register_test.go
View file @
2fe8932c
...
...
@@ -71,6 +71,26 @@ func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email strin
return
nil
}
func
(
s
*
emailCacheStub
)
GetPasswordResetToken
(
ctx
context
.
Context
,
email
string
)
(
*
PasswordResetTokenData
,
error
)
{
return
nil
,
nil
}
func
(
s
*
emailCacheStub
)
SetPasswordResetToken
(
ctx
context
.
Context
,
email
string
,
data
*
PasswordResetTokenData
,
ttl
time
.
Duration
)
error
{
return
nil
}
func
(
s
*
emailCacheStub
)
DeletePasswordResetToken
(
ctx
context
.
Context
,
email
string
)
error
{
return
nil
}
func
(
s
*
emailCacheStub
)
IsPasswordResetEmailInCooldown
(
ctx
context
.
Context
,
email
string
)
bool
{
return
false
}
func
(
s
*
emailCacheStub
)
SetPasswordResetEmailCooldown
(
ctx
context
.
Context
,
email
string
,
ttl
time
.
Duration
)
error
{
return
nil
}
func
newAuthService
(
repo
*
userRepoStub
,
settings
map
[
string
]
string
,
emailCache
EmailCache
)
*
AuthService
{
cfg
:=
&
config
.
Config
{
JWT
:
config
.
JWTConfig
{
...
...
backend/internal/service/claude_token_provider.go
View file @
2fe8932c
...
...
@@ -181,8 +181,18 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return
""
,
errors
.
New
(
"access_token not found in credentials"
)
}
// 3. 存入缓存
// 3. 存入缓存
(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
if
p
.
tokenCache
!=
nil
{
latestAccount
,
isStale
:=
CheckTokenVersion
(
ctx
,
account
,
p
.
accountRepo
)
if
isStale
&&
latestAccount
!=
nil
{
// 版本过时,使用 DB 中的最新 token
slog
.
Debug
(
"claude_token_version_stale_use_latest"
,
"account_id"
,
account
.
ID
)
accessToken
=
latestAccount
.
GetCredential
(
"access_token"
)
if
strings
.
TrimSpace
(
accessToken
)
==
""
{
return
""
,
errors
.
New
(
"access_token not found after version check"
)
}
// 不写入缓存,让下次请求重新处理
}
else
{
ttl
:=
30
*
time
.
Minute
if
refreshFailed
{
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
...
...
@@ -203,6 +213,7 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
slog
.
Warn
(
"claude_token_cache_set_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
}
return
accessToken
,
nil
}
backend/internal/service/dashboard_aggregation_service.go
View file @
2fe8932c
...
...
@@ -21,11 +21,15 @@ var (
ErrDashboardBackfillDisabled
=
errors
.
New
(
"仪表盘聚合回填已禁用"
)
// ErrDashboardBackfillTooLarge 当回填跨度超过限制时返回。
ErrDashboardBackfillTooLarge
=
errors
.
New
(
"回填时间跨度过大"
)
errDashboardAggregationRunning
=
errors
.
New
(
"聚合作业正在运行"
)
)
// DashboardAggregationRepository 定义仪表盘预聚合仓储接口。
type
DashboardAggregationRepository
interface
{
AggregateRange
(
ctx
context
.
Context
,
start
,
end
time
.
Time
)
error
// RecomputeRange 重新计算指定时间范围内的聚合数据(包含活跃用户等派生表)。
// 设计目的:当 usage_logs 被批量删除/回滚后,确保聚合表可恢复一致性。
RecomputeRange
(
ctx
context
.
Context
,
start
,
end
time
.
Time
)
error
GetAggregationWatermark
(
ctx
context
.
Context
)
(
time
.
Time
,
error
)
UpdateAggregationWatermark
(
ctx
context
.
Context
,
aggregatedAt
time
.
Time
)
error
CleanupAggregates
(
ctx
context
.
Context
,
hourlyCutoff
,
dailyCutoff
time
.
Time
)
error
...
...
@@ -112,6 +116,41 @@ func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) erro
return
nil
}
// TriggerRecomputeRange 触发指定范围的重新计算(异步)。
// 与 TriggerBackfill 不同:
// - 不依赖 backfill_enabled(这是内部一致性修复)
// - 不更新 watermark(避免影响正常增量聚合游标)
func
(
s
*
DashboardAggregationService
)
TriggerRecomputeRange
(
start
,
end
time
.
Time
)
error
{
if
s
==
nil
||
s
.
repo
==
nil
{
return
errors
.
New
(
"聚合服务未初始化"
)
}
if
!
s
.
cfg
.
Enabled
{
return
errors
.
New
(
"聚合服务已禁用"
)
}
if
!
end
.
After
(
start
)
{
return
errors
.
New
(
"重新计算时间范围无效"
)
}
go
func
()
{
const
maxRetries
=
3
for
i
:=
0
;
i
<
maxRetries
;
i
++
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
defaultDashboardAggregationBackfillTimeout
)
err
:=
s
.
recomputeRange
(
ctx
,
start
,
end
)
cancel
()
if
err
==
nil
{
return
}
if
!
errors
.
Is
(
err
,
errDashboardAggregationRunning
)
{
log
.
Printf
(
"[DashboardAggregation] 重新计算失败: %v"
,
err
)
return
}
time
.
Sleep
(
5
*
time
.
Second
)
}
log
.
Printf
(
"[DashboardAggregation] 重新计算放弃: 聚合作业持续占用"
)
}()
return
nil
}
func
(
s
*
DashboardAggregationService
)
recomputeRecentDays
()
{
days
:=
s
.
cfg
.
RecomputeDays
if
days
<=
0
{
...
...
@@ -128,6 +167,24 @@ func (s *DashboardAggregationService) recomputeRecentDays() {
}
}
func
(
s
*
DashboardAggregationService
)
recomputeRange
(
ctx
context
.
Context
,
start
,
end
time
.
Time
)
error
{
if
!
atomic
.
CompareAndSwapInt32
(
&
s
.
running
,
0
,
1
)
{
return
errDashboardAggregationRunning
}
defer
atomic
.
StoreInt32
(
&
s
.
running
,
0
)
jobStart
:=
time
.
Now
()
.
UTC
()
if
err
:=
s
.
repo
.
RecomputeRange
(
ctx
,
start
,
end
);
err
!=
nil
{
return
err
}
log
.
Printf
(
"[DashboardAggregation] 重新计算完成 (start=%s end=%s duration=%s)"
,
start
.
UTC
()
.
Format
(
time
.
RFC3339
),
end
.
UTC
()
.
Format
(
time
.
RFC3339
),
time
.
Since
(
jobStart
)
.
String
(),
)
return
nil
}
func
(
s
*
DashboardAggregationService
)
runScheduledAggregation
()
{
if
!
atomic
.
CompareAndSwapInt32
(
&
s
.
running
,
0
,
1
)
{
return
...
...
@@ -179,7 +236,7 @@ func (s *DashboardAggregationService) runScheduledAggregation() {
func
(
s
*
DashboardAggregationService
)
backfillRange
(
ctx
context
.
Context
,
start
,
end
time
.
Time
)
error
{
if
!
atomic
.
CompareAndSwapInt32
(
&
s
.
running
,
0
,
1
)
{
return
err
ors
.
New
(
"聚合作业正在运行"
)
return
err
DashboardAggregationRunning
}
defer
atomic
.
StoreInt32
(
&
s
.
running
,
0
)
...
...
backend/internal/service/dashboard_aggregation_service_test.go
View file @
2fe8932c
...
...
@@ -27,6 +27,10 @@ func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, s
return
s
.
aggregateErr
}
func
(
s
*
dashboardAggregationRepoTestStub
)
RecomputeRange
(
ctx
context
.
Context
,
start
,
end
time
.
Time
)
error
{
return
s
.
AggregateRange
(
ctx
,
start
,
end
)
}
func
(
s
*
dashboardAggregationRepoTestStub
)
GetAggregationWatermark
(
ctx
context
.
Context
)
(
time
.
Time
,
error
)
{
return
s
.
watermark
,
nil
}
...
...
backend/internal/service/dashboard_service.go
View file @
2fe8932c
...
...
@@ -124,16 +124,16 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D
return
stats
,
nil
}
func
(
s
*
DashboardService
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
stream
*
bool
)
([]
usagestats
.
TrendDataPoint
,
error
)
{
trend
,
err
:=
s
.
usageRepo
.
GetUsageTrendWithFilters
(
ctx
,
startTime
,
endTime
,
granularity
,
userID
,
apiKeyID
,
accountID
,
groupID
,
model
,
stream
)
func
(
s
*
DashboardService
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
TrendDataPoint
,
error
)
{
trend
,
err
:=
s
.
usageRepo
.
GetUsageTrendWithFilters
(
ctx
,
startTime
,
endTime
,
granularity
,
userID
,
apiKeyID
,
accountID
,
groupID
,
model
,
stream
,
billingType
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get usage trend with filters: %w"
,
err
)
}
return
trend
,
nil
}
func
(
s
*
DashboardService
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
stream
*
bool
)
([]
usagestats
.
ModelStat
,
error
)
{
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
startTime
,
endTime
,
userID
,
apiKeyID
,
accountID
,
groupID
,
stream
)
func
(
s
*
DashboardService
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
stream
*
bool
,
billingType
*
int8
)
([]
usagestats
.
ModelStat
,
error
)
{
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
startTime
,
endTime
,
userID
,
apiKeyID
,
accountID
,
groupID
,
stream
,
billingType
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get model stats with filters: %w"
,
err
)
}
...
...
backend/internal/service/dashboard_service_test.go
View file @
2fe8932c
...
...
@@ -101,6 +101,10 @@ func (s *dashboardAggregationRepoStub) AggregateRange(ctx context.Context, start
return
nil
}
func
(
s
*
dashboardAggregationRepoStub
)
RecomputeRange
(
ctx
context
.
Context
,
start
,
end
time
.
Time
)
error
{
return
nil
}
func
(
s
*
dashboardAggregationRepoStub
)
GetAggregationWatermark
(
ctx
context
.
Context
)
(
time
.
Time
,
error
)
{
if
s
.
err
!=
nil
{
return
time
.
Time
{},
s
.
err
...
...
backend/internal/service/domain_constants.go
View file @
2fe8932c
...
...
@@ -71,6 +71,8 @@ const (
// 注册设置
SettingKeyRegistrationEnabled
=
"registration_enabled"
// 是否开放注册
SettingKeyEmailVerifyEnabled
=
"email_verify_enabled"
// 是否开启邮件验证
SettingKeyPromoCodeEnabled
=
"promo_code_enabled"
// 是否启用优惠码功能
SettingKeyPasswordResetEnabled
=
"password_reset_enabled"
// 是否启用忘记密码功能(需要先开启邮件验证)
// 邮件服务设置
SettingKeySMTPHost
=
"smtp_host"
// SMTP服务器地址
...
...
@@ -86,6 +88,9 @@ const (
SettingKeyTurnstileSiteKey
=
"turnstile_site_key"
// Turnstile Site Key
SettingKeyTurnstileSecretKey
=
"turnstile_secret_key"
// Turnstile Secret Key
// TOTP 双因素认证设置
SettingKeyTotpEnabled
=
"totp_enabled"
// 是否启用 TOTP 2FA 功能
// LinuxDo Connect OAuth 登录设置
SettingKeyLinuxDoConnectEnabled
=
"linuxdo_connect_enabled"
SettingKeyLinuxDoConnectClientID
=
"linuxdo_connect_client_id"
...
...
@@ -100,6 +105,9 @@ const (
SettingKeyContactInfo
=
"contact_info"
// 客服联系方式
SettingKeyDocURL
=
"doc_url"
// 文档链接
SettingKeyHomeContent
=
"home_content"
// 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src)
SettingKeyHideCcsImportButton
=
"hide_ccs_import_button"
// 是否隐藏 API Keys 页面的导入 CCS 按钮
SettingKeyPurchaseSubscriptionEnabled
=
"purchase_subscription_enabled"
// 是否展示“购买订阅”页面入口
SettingKeyPurchaseSubscriptionURL
=
"purchase_subscription_url"
// “购买订阅”页面 URL(作为 iframe src)
// 默认配置
SettingKeyDefaultConcurrency
=
"default_concurrency"
// 新用户默认并发量
...
...
backend/internal/service/email_queue_service.go
View file @
2fe8932c
...
...
@@ -8,11 +8,18 @@ import (
"time"
)
// Task type constants
const
(
TaskTypeVerifyCode
=
"verify_code"
TaskTypePasswordReset
=
"password_reset"
)
// EmailTask 邮件发送任务
type
EmailTask
struct
{
Email
string
SiteName
string
TaskType
string
// "verify_code"
TaskType
string
// "verify_code" or "password_reset"
ResetURL
string
// Only used for password_reset task type
}
// EmailQueueService 异步邮件队列服务
...
...
@@ -73,12 +80,18 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) {
defer
cancel
()
switch
task
.
TaskType
{
case
"v
erify
_c
ode
"
:
case
TaskTypeV
erify
C
ode
:
if
err
:=
s
.
emailService
.
SendVerifyCode
(
ctx
,
task
.
Email
,
task
.
SiteName
);
err
!=
nil
{
log
.
Printf
(
"[EmailQueue] Worker %d failed to send verify code to %s: %v"
,
workerID
,
task
.
Email
,
err
)
}
else
{
log
.
Printf
(
"[EmailQueue] Worker %d sent verify code to %s"
,
workerID
,
task
.
Email
)
}
case
TaskTypePasswordReset
:
if
err
:=
s
.
emailService
.
SendPasswordResetEmailWithCooldown
(
ctx
,
task
.
Email
,
task
.
SiteName
,
task
.
ResetURL
);
err
!=
nil
{
log
.
Printf
(
"[EmailQueue] Worker %d failed to send password reset to %s: %v"
,
workerID
,
task
.
Email
,
err
)
}
else
{
log
.
Printf
(
"[EmailQueue] Worker %d sent password reset to %s"
,
workerID
,
task
.
Email
)
}
default
:
log
.
Printf
(
"[EmailQueue] Worker %d unknown task type: %s"
,
workerID
,
task
.
TaskType
)
}
...
...
@@ -89,7 +102,7 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
task
:=
EmailTask
{
Email
:
email
,
SiteName
:
siteName
,
TaskType
:
"v
erify
_c
ode
"
,
TaskType
:
TaskTypeV
erify
C
ode
,
}
select
{
...
...
@@ -101,6 +114,24 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
}
}
// EnqueuePasswordReset 将密码重置邮件任务加入队列
func
(
s
*
EmailQueueService
)
EnqueuePasswordReset
(
email
,
siteName
,
resetURL
string
)
error
{
task
:=
EmailTask
{
Email
:
email
,
SiteName
:
siteName
,
TaskType
:
TaskTypePasswordReset
,
ResetURL
:
resetURL
,
}
select
{
case
s
.
taskChan
<-
task
:
log
.
Printf
(
"[EmailQueue] Enqueued password reset task for %s"
,
email
)
return
nil
default
:
return
fmt
.
Errorf
(
"email queue is full"
)
}
}
// Stop 停止队列服务
func
(
s
*
EmailQueueService
)
Stop
()
{
close
(
s
.
stopChan
)
...
...
backend/internal/service/email_service.go
View file @
2fe8932c
...
...
@@ -3,11 +3,14 @@ package service
import
(
"context"
"crypto/rand"
"crypto/subtle"
"crypto/tls"
"encoding/hex"
"fmt"
"log"
"math/big"
"net/smtp"
"net/url"
"strconv"
"time"
...
...
@@ -19,6 +22,9 @@ var (
ErrInvalidVerifyCode
=
infraerrors
.
BadRequest
(
"INVALID_VERIFY_CODE"
,
"invalid or expired verification code"
)
ErrVerifyCodeTooFrequent
=
infraerrors
.
TooManyRequests
(
"VERIFY_CODE_TOO_FREQUENT"
,
"please wait before requesting a new code"
)
ErrVerifyCodeMaxAttempts
=
infraerrors
.
TooManyRequests
(
"VERIFY_CODE_MAX_ATTEMPTS"
,
"too many failed attempts, please request a new code"
)
// Password reset errors
ErrInvalidResetToken
=
infraerrors
.
BadRequest
(
"INVALID_RESET_TOKEN"
,
"invalid or expired password reset token"
)
)
// EmailCache defines cache operations for email service
...
...
@@ -26,6 +32,16 @@ type EmailCache interface {
GetVerificationCode
(
ctx
context
.
Context
,
email
string
)
(
*
VerificationCodeData
,
error
)
SetVerificationCode
(
ctx
context
.
Context
,
email
string
,
data
*
VerificationCodeData
,
ttl
time
.
Duration
)
error
DeleteVerificationCode
(
ctx
context
.
Context
,
email
string
)
error
// Password reset token methods
GetPasswordResetToken
(
ctx
context
.
Context
,
email
string
)
(
*
PasswordResetTokenData
,
error
)
SetPasswordResetToken
(
ctx
context
.
Context
,
email
string
,
data
*
PasswordResetTokenData
,
ttl
time
.
Duration
)
error
DeletePasswordResetToken
(
ctx
context
.
Context
,
email
string
)
error
// Password reset email cooldown methods
// Returns true if in cooldown period (email was sent recently)
IsPasswordResetEmailInCooldown
(
ctx
context
.
Context
,
email
string
)
bool
SetPasswordResetEmailCooldown
(
ctx
context
.
Context
,
email
string
,
ttl
time
.
Duration
)
error
}
// VerificationCodeData represents verification code data
...
...
@@ -35,10 +51,22 @@ type VerificationCodeData struct {
CreatedAt
time
.
Time
}
// PasswordResetTokenData represents password reset token data
type
PasswordResetTokenData
struct
{
Token
string
CreatedAt
time
.
Time
}
const
(
verifyCodeTTL
=
15
*
time
.
Minute
verifyCodeCooldown
=
1
*
time
.
Minute
maxVerifyCodeAttempts
=
5
// Password reset token settings
passwordResetTokenTTL
=
30
*
time
.
Minute
// Password reset email cooldown (prevent email bombing)
passwordResetEmailCooldown
=
30
*
time
.
Second
)
// SMTPConfig SMTP配置
...
...
@@ -254,8 +282,8 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
return
ErrVerifyCodeMaxAttempts
}
// 验证码不匹配
if
data
.
Code
!=
code
{
// 验证码不匹配
(constant-time comparison to prevent timing attacks)
if
subtle
.
ConstantTimeCompare
([]
byte
(
data
.
Code
),
[]
byte
(
code
))
!=
1
{
data
.
Attempts
++
if
err
:=
s
.
cache
.
SetVerificationCode
(
ctx
,
email
,
data
,
verifyCodeTTL
);
err
!=
nil
{
log
.
Printf
(
"[Email] Failed to update verification attempt count: %v"
,
err
)
...
...
@@ -357,3 +385,157 @@ func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
return
client
.
Quit
()
}
// GeneratePasswordResetToken generates a secure 32-byte random token (64 hex characters)
func
(
s
*
EmailService
)
GeneratePasswordResetToken
()
(
string
,
error
)
{
bytes
:=
make
([]
byte
,
32
)
if
_
,
err
:=
rand
.
Read
(
bytes
);
err
!=
nil
{
return
""
,
err
}
return
hex
.
EncodeToString
(
bytes
),
nil
}
// SendPasswordResetEmail sends a password reset email with a reset link
func
(
s
*
EmailService
)
SendPasswordResetEmail
(
ctx
context
.
Context
,
email
,
siteName
,
resetURL
string
)
error
{
var
token
string
var
needSaveToken
bool
// Check if token already exists
existing
,
err
:=
s
.
cache
.
GetPasswordResetToken
(
ctx
,
email
)
if
err
==
nil
&&
existing
!=
nil
{
// Token exists, reuse it (allows resending email without generating new token)
token
=
existing
.
Token
needSaveToken
=
false
}
else
{
// Generate new token
token
,
err
=
s
.
GeneratePasswordResetToken
()
if
err
!=
nil
{
return
fmt
.
Errorf
(
"generate token: %w"
,
err
)
}
needSaveToken
=
true
}
// Save token to Redis (only if new token generated)
if
needSaveToken
{
data
:=
&
PasswordResetTokenData
{
Token
:
token
,
CreatedAt
:
time
.
Now
(),
}
if
err
:=
s
.
cache
.
SetPasswordResetToken
(
ctx
,
email
,
data
,
passwordResetTokenTTL
);
err
!=
nil
{
return
fmt
.
Errorf
(
"save reset token: %w"
,
err
)
}
}
// Build full reset URL with URL-encoded token and email
fullResetURL
:=
fmt
.
Sprintf
(
"%s?email=%s&token=%s"
,
resetURL
,
url
.
QueryEscape
(
email
),
url
.
QueryEscape
(
token
))
// Build email content
subject
:=
fmt
.
Sprintf
(
"[%s] 密码重置请求"
,
siteName
)
body
:=
s
.
buildPasswordResetEmailBody
(
fullResetURL
,
siteName
)
// Send email
if
err
:=
s
.
SendEmail
(
ctx
,
email
,
subject
,
body
);
err
!=
nil
{
return
fmt
.
Errorf
(
"send email: %w"
,
err
)
}
return
nil
}
// SendPasswordResetEmailWithCooldown sends password reset email with cooldown check (called by queue worker)
// This method wraps SendPasswordResetEmail with email cooldown to prevent email bombing
func
(
s
*
EmailService
)
SendPasswordResetEmailWithCooldown
(
ctx
context
.
Context
,
email
,
siteName
,
resetURL
string
)
error
{
// Check email cooldown to prevent email bombing
if
s
.
cache
.
IsPasswordResetEmailInCooldown
(
ctx
,
email
)
{
log
.
Printf
(
"[Email] Password reset email skipped (cooldown): %s"
,
email
)
return
nil
// Silent success to prevent revealing cooldown to attackers
}
// Send email using core method
if
err
:=
s
.
SendPasswordResetEmail
(
ctx
,
email
,
siteName
,
resetURL
);
err
!=
nil
{
return
err
}
// Set cooldown marker (Redis TTL handles expiration)
if
err
:=
s
.
cache
.
SetPasswordResetEmailCooldown
(
ctx
,
email
,
passwordResetEmailCooldown
);
err
!=
nil
{
log
.
Printf
(
"[Email] Failed to set password reset cooldown for %s: %v"
,
email
,
err
)
}
return
nil
}
// VerifyPasswordResetToken verifies the password reset token without consuming it
func
(
s
*
EmailService
)
VerifyPasswordResetToken
(
ctx
context
.
Context
,
email
,
token
string
)
error
{
data
,
err
:=
s
.
cache
.
GetPasswordResetToken
(
ctx
,
email
)
if
err
!=
nil
||
data
==
nil
{
return
ErrInvalidResetToken
}
// Use constant-time comparison to prevent timing attacks
if
subtle
.
ConstantTimeCompare
([]
byte
(
data
.
Token
),
[]
byte
(
token
))
!=
1
{
return
ErrInvalidResetToken
}
return
nil
}
// ConsumePasswordResetToken verifies and deletes the token (one-time use)
func
(
s
*
EmailService
)
ConsumePasswordResetToken
(
ctx
context
.
Context
,
email
,
token
string
)
error
{
// Verify first
if
err
:=
s
.
VerifyPasswordResetToken
(
ctx
,
email
,
token
);
err
!=
nil
{
return
err
}
// Delete after verification (one-time use)
if
err
:=
s
.
cache
.
DeletePasswordResetToken
(
ctx
,
email
);
err
!=
nil
{
log
.
Printf
(
"[Email] Failed to delete password reset token after consumption: %v"
,
err
)
}
return
nil
}
// buildPasswordResetEmailBody builds the HTML content for password reset email
func
(
s
*
EmailService
)
buildPasswordResetEmailBody
(
resetURL
,
siteName
string
)
string
{
return
fmt
.
Sprintf
(
`
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 600px; margin: 0 auto; background-color: #ffffff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
.header { background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 30px; text-align: center; }
.header h1 { margin: 0; font-size: 24px; }
.content { padding: 40px 30px; text-align: center; }
.button { display: inline-block; background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 14px 32px; text-decoration: none; border-radius: 8px; font-size: 16px; font-weight: 600; margin: 20px 0; }
.button:hover { opacity: 0.9; }
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; }
.link-fallback { color: #666; font-size: 12px; word-break: break-all; margin-top: 20px; padding: 15px; background-color: #f8f9fa; border-radius: 4px; }
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
.warning { color: #e74c3c; font-weight: 500; }
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>%s</h1>
</div>
<div class="content">
<p style="font-size: 18px; color: #333;">密码重置请求</p>
<p style="color: #666;">您已请求重置密码。请点击下方按钮设置新密码:</p>
<a href="%s" class="button">重置密码</a>
<div class="info">
<p>此链接将在 <strong>30 分钟</strong>后失效。</p>
<p class="warning">如果您没有请求重置密码,请忽略此邮件。您的密码将保持不变。</p>
</div>
<div class="link-fallback">
<p>如果按钮无法点击,请复制以下链接到浏览器中打开:</p>
<p>%s</p>
</div>
</div>
<div class="footer">
<p>这是一封自动发送的邮件,请勿回复。</p>
</div>
</div>
</body>
</html>
`
,
siteName
,
resetURL
,
resetURL
)
}
backend/internal/service/gateway_beta_test.go
0 → 100644
View file @
2fe8932c
package
service
import
(
"testing"
"github.com/stretchr/testify/require"
)
func
TestMergeAnthropicBeta
(
t
*
testing
.
T
)
{
got
:=
mergeAnthropicBeta
(
[]
string
{
"oauth-2025-04-20"
,
"interleaved-thinking-2025-05-14"
},
"foo, oauth-2025-04-20,bar, foo"
,
)
require
.
Equal
(
t
,
"oauth-2025-04-20,interleaved-thinking-2025-05-14,foo,bar"
,
got
)
}
func
TestMergeAnthropicBeta_EmptyIncoming
(
t
*
testing
.
T
)
{
got
:=
mergeAnthropicBeta
(
[]
string
{
"oauth-2025-04-20"
,
"interleaved-thinking-2025-05-14"
},
""
,
)
require
.
Equal
(
t
,
"oauth-2025-04-20,interleaved-thinking-2025-05-14"
,
got
)
}
backend/internal/service/gateway_multiplatform_test.go
View file @
2fe8932c
...
...
@@ -105,6 +105,9 @@ func (m *mockAccountRepoForPlatform) BatchUpdateLastUsed(ctx context.Context, up
func
(
m
*
mockAccountRepoForPlatform
)
SetError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
ClearError
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
SetSchedulable
(
ctx
context
.
Context
,
id
int64
,
schedulable
bool
)
error
{
return
nil
}
...
...
@@ -179,6 +182,7 @@ var _ AccountRepository = (*mockAccountRepoForPlatform)(nil)
// mockGatewayCacheForPlatform 单平台测试用的 cache mock
type
mockGatewayCacheForPlatform
struct
{
sessionBindings
map
[
string
]
int64
deletedSessions
map
[
string
]
int
}
func
(
m
*
mockGatewayCacheForPlatform
)
GetSessionAccountID
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
)
(
int64
,
error
)
{
...
...
@@ -200,6 +204,18 @@ func (m *mockGatewayCacheForPlatform) RefreshSessionTTL(ctx context.Context, gro
return
nil
}
func
(
m
*
mockGatewayCacheForPlatform
)
DeleteSessionAccountID
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
)
error
{
if
m
.
sessionBindings
==
nil
{
return
nil
}
if
m
.
deletedSessions
==
nil
{
m
.
deletedSessions
=
make
(
map
[
string
]
int
)
}
m
.
deletedSessions
[
sessionHash
]
++
delete
(
m
.
sessionBindings
,
sessionHash
)
return
nil
}
type
mockGroupRepoForGateway
struct
{
groups
map
[
int64
]
*
Group
getByIDCalls
int
...
...
@@ -623,76 +639,96 @@ func TestGatewayService_SelectAccountForModelWithPlatform_StickySession(t *testi
})
}
func
TestGatewayService_isModelSupportedByAccount
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
func
TestGatewayService_SelectAccountForModelWithExclusions_ForcePlatform
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
ForcePlatform
,
PlatformAntigravity
)
tests
:=
[]
struct
{
name
string
account
*
Account
model
string
expected
bool
}{
{
name
:
"Antigravity平台-支持claude模型"
,
account
:
&
Account
{
Platform
:
PlatformAntigravity
},
model
:
"claude-3-5-sonnet-20241022"
,
expected
:
true
,
},
{
name
:
"Antigravity平台-支持gemini模型"
,
account
:
&
Account
{
Platform
:
PlatformAntigravity
},
model
:
"gemini-2.5-flash"
,
expected
:
true
,
},
{
name
:
"Antigravity平台-不支持gpt模型"
,
account
:
&
Account
{
Platform
:
PlatformAntigravity
},
model
:
"gpt-4"
,
expected
:
false
,
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
{
name
:
"Anthropic平台-无映射配置-支持所有模型"
,
account
:
&
Account
{
Platform
:
PlatformAnthropic
},
model
:
"claude-3-5-sonnet-20241022"
,
expected
:
true
,
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
)
require
.
Equal
(
t
,
PlatformAntigravity
,
acc
.
Platform
)
}
func
TestGatewayService_SelectAccountForModelWithPlatform_RoutedStickySessionClears
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
10
)
requestedModel
:=
"claude-3-5-sonnet-20241022"
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusDisabled
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
{
name
:
"Anthropic平台-有映射配置-只支持配置的模型"
,
account
:
&
Account
{
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
"session-123"
:
1
},
}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
Name
:
"route-group"
,
Platform
:
PlatformAnthropic
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"claude-opus-4"
:
"x"
}},
},
model
:
"claude-3-5-sonnet-20241022"
,
expected
:
false
,
Status
:
StatusActive
,
Hydrated
:
true
,
ModelRoutingEnabled
:
true
,
ModelRouting
:
map
[
string
][]
int64
{
requestedModel
:
{
1
,
2
},
},
{
name
:
"Anthropic平台-有映射配置-支持配置的模型"
,
account
:
&
Account
{
Platform
:
PlatformAnthropic
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"claude-3-5-sonnet-20241022"
:
"x"
}},
},
model
:
"claude-3-5-sonnet-20241022"
,
expected
:
true
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
svc
.
isModelSupportedByAccount
(
tt
.
account
,
tt
.
model
)
require
.
Equal
(
t
,
tt
.
expected
,
got
)
})
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
groupRepo
:
groupRepo
,
}
acc
,
err
:=
svc
.
selectAccountForModelWithPlatform
(
ctx
,
&
groupID
,
"session-123"
,
requestedModel
,
nil
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
)
require
.
Equal
(
t
,
1
,
cache
.
deletedSessions
[
"session-123"
])
require
.
Equal
(
t
,
int64
(
2
),
cache
.
sessionBindings
[
"session-123"
])
}
// TestGatewayService_selectAccountWithMixedScheduling 测试混合调度
func
TestGatewayService_selectAccountWithMixedScheduling
(
t
*
testing
.
T
)
{
func
TestGatewayService_SelectAccountForModelWithPlatform_RoutedStickySessionHit
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
11
)
requestedModel
:=
"claude-3-5-sonnet-20241022"
t
.
Run
(
"混合调度-Gemini优先选择OAuth账户"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
Platform
Gemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Type
:
AccountTypeAPIKey
},
{
ID
:
2
,
Platform
:
Platform
Gemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Type
:
AccountTypeOAuth
},
{
ID
:
1
,
Platform
:
Platform
Anthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
Platform
Anthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
...
...
@@ -700,25 +736,48 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
"session-456"
:
1
},
}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
Name
:
"route-group-hit"
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Hydrated
:
true
,
ModelRoutingEnabled
:
true
,
ModelRouting
:
map
[
string
][]
int64
{
requestedModel
:
{
1
,
2
},
},
},
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
groupRepo
:
groupRepo
,
}
acc
,
err
:=
svc
.
selectAccount
WithMixedScheduling
(
ctx
,
nil
,
""
,
"gemini-2.5-pro"
,
nil
,
Platform
Gemini
)
acc
,
err
:=
svc
.
selectAccount
ForModelWithPlatform
(
ctx
,
&
groupID
,
"session-456"
,
requestedModel
,
nil
,
Platform
Anthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"同优先级且未使用时应优先选择OAuth账户"
)
})
require
.
Equal
(
t
,
int64
(
1
),
acc
.
ID
)
}
func
TestGatewayService_SelectAccountForModelWithPlatform_RoutedFallbackToNormal
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
12
)
requestedModel
:=
"claude-3-5-sonnet-20241022"
t
.
Run
(
"混合调度-包含启用mixed_scheduling的antigravity账户"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAnt
igravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Extra
:
map
[
string
]
any
{
"mixed_scheduling"
:
true
}
},
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAnt
hropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
...
...
@@ -728,23 +787,48 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
cache
:=
&
mockGatewayCacheForPlatform
{}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
Name
:
"route-fallback"
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Hydrated
:
true
,
ModelRoutingEnabled
:
true
,
ModelRouting
:
map
[
string
][]
int64
{
requestedModel
:
{
99
},
},
},
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
groupRepo
:
groupRepo
,
}
acc
,
err
:=
svc
.
selectAccount
WithMixedScheduling
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAnthropic
)
acc
,
err
:=
svc
.
selectAccount
ForModelWithPlatform
(
ctx
,
&
groupID
,
""
,
requestedModel
,
nil
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"应选择优先级最高的账户(包含启用混合调度的antigravity)"
)
})
require
.
Equal
(
t
,
int64
(
1
),
acc
.
ID
)
}
func
TestGatewayService_SelectAccountForModelWithPlatform_NoModelSupport
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
t
.
Run
(
"混合调度-过滤未启用mixed_scheduling的antigravity账户"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
// 未启用 mixed_scheduling
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"claude-3-5-haiku-20241022"
:
"claude-3-5-haiku-20241022"
}},
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
...
...
@@ -760,18 +844,19 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccountWithMixedScheduling
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
1
),
acc
.
ID
,
"未启用mixed_scheduling的antigravity账户应被过滤"
)
require
.
Equal
(
t
,
PlatformAnthropic
,
acc
.
Platform
)
})
acc
,
err
:=
svc
.
selectAccountForModelWithPlatform
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAnthropic
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
acc
)
require
.
Contains
(
t
,
err
.
Error
(),
"supporting model"
)
}
func
TestGatewayService_SelectAccountForModelWithPlatform_GeminiPreferOAuth
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
t
.
Run
(
"混合调度-粘性会话命中启用mixed_scheduling的antigravity账户"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
Platform
Anthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
Platform
Antigravity
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
Extra
:
map
[
string
]
any
{
"mixed_scheduling"
:
true
}
},
{
ID
:
1
,
Platform
:
Platform
Gemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Type
:
AccountTypeAPIKey
},
{
ID
:
2
,
Platform
:
Platform
Gemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Type
:
AccountTypeOAuth
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
...
...
@@ -779,9 +864,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
"session-123"
:
2
},
}
cache
:=
&
mockGatewayCacheForPlatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
...
...
@@ -789,17 +872,20 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccount
WithMixedScheduling
(
ctx
,
nil
,
"session-123"
,
"claude-3-5-sonnet-20241022
"
,
nil
,
Platform
Anthropic
)
acc
,
err
:=
svc
.
selectAccount
ForModelWithPlatform
(
ctx
,
nil
,
""
,
"gemini-2.5-pro
"
,
nil
,
Platform
Gemini
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"应返回粘性会话绑定的启用mixed_scheduling的antigravity账户"
)
})
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
)
}
func
TestGatewayService_SelectAccountForModelWithPlatform_StickyInGroup
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
50
)
t
.
Run
(
"混合调度-粘性会话命中未启用mixed_scheduling的antigravity账户-降级选择"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAnt
igravity
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
}
,
// 未启用 mixed_scheduling
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
AccountGroups
:
[]
AccountGroup
{{
GroupID
:
groupID
}}
},
{
ID
:
2
,
Platform
:
PlatformAnt
hropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
AccountGroups
:
[]
AccountGroup
{{
GroupID
:
groupID
}}},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
...
...
@@ -808,7 +894,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
"session-
123
"
:
2
},
sessionBindings
:
map
[
string
]
int64
{
"session-
group
"
:
1
},
}
svc
:=
&
GatewayService
{
...
...
@@ -817,16 +903,26 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccount
WithMixedScheduling
(
ctx
,
nil
,
"session-123"
,
"claude-3-5-sonnet-20241022
"
,
nil
,
PlatformAnthropic
)
acc
,
err
:=
svc
.
selectAccount
ForModelWithPlatform
(
ctx
,
&
groupID
,
"session-group"
,
"
"
,
nil
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
1
),
acc
.
ID
,
"粘性会话绑定的账户未启用mixed_scheduling,应降级选择anthropic账户"
)
})
require
.
Equal
(
t
,
int64
(
1
),
acc
.
ID
)
}
func
TestGatewayService_SelectAccountForModelWithPlatform_StickyModelMismatchFallback
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
t
.
Run
(
"混合调度-仅有启用mixed_scheduling的antigravity账户"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Extra
:
map
[
string
]
any
{
"mixed_scheduling"
:
true
}},
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"claude-3-5-haiku-20241022"
:
"claude-3-5-haiku-20241022"
}},
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
...
...
@@ -834,7 +930,9 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
"session-miss"
:
1
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
...
...
@@ -842,17 +940,20 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccount
WithMixedScheduling
(
ctx
,
nil
,
"
"
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAnthropic
)
acc
,
err
:=
svc
.
selectAccount
ForModelWithPlatform
(
ctx
,
nil
,
"session-miss
"
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
1
),
acc
.
ID
)
require
.
Equal
(
t
,
PlatformAntigravity
,
acc
.
Platform
)
})
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
)
}
func
TestGatewayService_SelectAccountForModelWithPlatform_PreferNeverUsed
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
lastUsed
:=
time
.
Now
()
.
Add
(
-
1
*
time
.
Hour
)
t
.
Run
(
"混合调度-无可用账户"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
// 未启用 mixed_scheduling
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
LastUsedAt
:
&
lastUsed
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
...
...
@@ -868,171 +969,1505 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccountWithMixedScheduling
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAnthropic
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
acc
)
require
.
Contains
(
t
,
err
.
Error
(),
"no available accounts"
)
})
acc
,
err
:=
svc
.
selectAccountForModelWithPlatform
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
)
}
// TestAccount_IsMixedSchedulingEnabled 测试混合调度开关检查
func
TestAccount_IsMixedSchedulingEnabled
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
account
Account
func
TestGatewayService_SelectAccountForModelWithPlatform_NoAccounts
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{},
accountsByID
:
map
[
int64
]
*
Account
{},
}
cache
:=
&
mockGatewayCacheForPlatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccountForModelWithPlatform
(
ctx
,
nil
,
""
,
""
,
nil
,
PlatformAnthropic
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
acc
)
require
.
Contains
(
t
,
err
.
Error
(),
"no available accounts"
)
}
func
TestGatewayService_isModelSupportedByAccount
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
tests
:=
[]
struct
{
name
string
account
*
Account
model
string
expected
bool
}{
{
name
:
"非antigravity平台-返回false"
,
account
:
Account
{
Platform
:
PlatformAnthropic
},
expected
:
false
,
},
{
name
:
"antigravity平台-无extra-返回false"
,
account
:
Account
{
Platform
:
PlatformAntigravity
},
expected
:
false
,
name
:
"Antigravity平台-支持claude模型"
,
account
:
&
Account
{
Platform
:
PlatformAntigravity
},
model
:
"claude-3-5-sonnet-20241022"
,
expected
:
true
,
},
{
name
:
"antigravity平台-extra无mixed_scheduling-返回false"
,
account
:
Account
{
Platform
:
PlatformAntigravity
,
Extra
:
map
[
string
]
any
{}},
expected
:
false
,
name
:
"Antigravity平台-支持gemini模型"
,
account
:
&
Account
{
Platform
:
PlatformAntigravity
},
model
:
"gemini-2.5-flash"
,
expected
:
true
,
},
{
name
:
"antigravity平台-mixed_scheduling=false-返回false"
,
account
:
Account
{
Platform
:
PlatformAntigravity
,
Extra
:
map
[
string
]
any
{
"mixed_scheduling"
:
false
}},
name
:
"Antigravity平台-不支持gpt模型"
,
account
:
&
Account
{
Platform
:
PlatformAntigravity
},
model
:
"gpt-4"
,
expected
:
false
,
},
{
name
:
"antigravity平台-mixed_scheduling=true-返回true"
,
account
:
Account
{
Platform
:
PlatformAntigravity
,
Extra
:
map
[
string
]
any
{
"mixed_scheduling"
:
true
}},
name
:
"Anthropic平台-无映射配置-支持所有模型"
,
account
:
&
Account
{
Platform
:
PlatformAnthropic
},
model
:
"claude-3-5-sonnet-20241022"
,
expected
:
true
,
},
{
name
:
"antigravity平台-mixed_scheduling非bool类型-返回false"
,
account
:
Account
{
Platform
:
PlatformAntigravity
,
Extra
:
map
[
string
]
any
{
"mixed_scheduling"
:
"true"
}},
name
:
"Anthropic平台-有映射配置-只支持配置的模型"
,
account
:
&
Account
{
Platform
:
PlatformAnthropic
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"claude-opus-4"
:
"x"
}},
},
model
:
"claude-3-5-sonnet-20241022"
,
expected
:
false
,
},
{
name
:
"Anthropic平台-有映射配置-支持配置的模型"
,
account
:
&
Account
{
Platform
:
PlatformAnthropic
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"claude-3-5-sonnet-20241022"
:
"x"
}},
},
model
:
"claude-3-5-sonnet-20241022"
,
expected
:
true
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
tt
.
account
.
IsMixedSchedulingEnabled
(
)
got
:=
svc
.
isModelSupportedByAccount
(
tt
.
account
,
tt
.
model
)
require
.
Equal
(
t
,
tt
.
expected
,
got
)
})
}
}
// mockConcurrencyService for testing
type
mockConcurrencyService
struct
{
accountLoads
map
[
int64
]
*
AccountLoadInfo
accountWaitCounts
map
[
int64
]
int
acquireResults
map
[
int64
]
bool
}
// TestGatewayService_selectAccountWithMixedScheduling 测试混合调度
func
TestGatewayService_selectAccountWithMixedScheduling
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
func
(
m
*
mockConcurrencyService
)
GetAccountsLoadBatch
(
ctx
context
.
Context
,
accounts
[]
AccountWithConcurrency
)
(
map
[
int64
]
*
AccountLoadInfo
,
error
)
{
if
m
.
accountLoads
==
nil
{
return
map
[
int64
]
*
AccountLoadInfo
{},
nil
}
result
:=
make
(
map
[
int64
]
*
AccountLoadInfo
)
for
_
,
acc
:=
range
accounts
{
if
load
,
ok
:=
m
.
accountLoads
[
acc
.
ID
];
ok
{
result
[
acc
.
ID
]
=
load
}
else
{
result
[
acc
.
ID
]
=
&
AccountLoadInfo
{
AccountID
:
acc
.
ID
,
CurrentConcurrency
:
0
,
WaitingCount
:
0
,
LoadRate
:
0
,
}
t
.
Run
(
"混合调度-Gemini优先选择OAuth账户"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Type
:
AccountTypeAPIKey
},
{
ID
:
2
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Type
:
AccountTypeOAuth
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
return
result
,
nil
}
func
(
m
*
mockConcurrencyService
)
GetAccountWaitingCount
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
{
if
m
.
accountWaitCounts
==
nil
{
return
0
,
nil
cache
:=
&
mockGatewayCacheForPlatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
return
m
.
accountWaitCounts
[
accountID
],
nil
}
type
mockConcurrencyCache
struct
{
acquireAccountCalls
int
loadBatchCalls
int
}
acc
,
err
:=
svc
.
selectAccountWithMixedScheduling
(
ctx
,
nil
,
""
,
"gemini-2.5-pro"
,
nil
,
PlatformGemini
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"同优先级且未使用时应优先选择OAuth账户"
)
})
func
(
m
*
mockConcurrencyCache
)
AcquireAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
{
m
.
acquireAccountCalls
++
return
true
,
nil
}
t
.
Run
(
"混合调度-包含启用mixed_scheduling的antigravity账户"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Extra
:
map
[
string
]
any
{
"mixed_scheduling"
:
true
}},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
func
(
m
*
mockConcurrencyCache
)
ReleaseAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
requestID
string
)
error
{
return
nil
}
cache
:=
&
mockGatewayCacheForPlatform
{}
func
(
m
*
mockConcurrencyCache
)
GetAccountConcurrency
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
{
return
0
,
nil
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
func
(
m
*
mockConcurrencyCache
)
IncrementAccountWaitCount
(
ctx
context
.
Context
,
accountID
int64
,
maxWait
int
)
(
bool
,
error
)
{
return
true
,
nil
}
acc
,
err
:=
svc
.
selectAccountWithMixedScheduling
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"应选择优先级最高的账户(包含启用混合调度的antigravity)"
)
})
func
(
m
*
mockConcurrencyCache
)
DecrementAccountWaitCount
(
ctx
context
.
Context
,
accountID
int64
)
error
{
return
nil
}
t
.
Run
(
"混合调度-路由优先选择路由账号"
,
func
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
30
)
requestedModel
:=
"claude-3-5-sonnet-20241022"
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
Extra
:
map
[
string
]
any
{
"mixed_scheduling"
:
true
}},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
func
(
m
*
mockConcurrencyCache
)
GetAccountWaitingCount
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
{
return
0
,
nil
}
cache
:=
&
mockGatewayCacheForPlatform
{}
func
(
m
*
mockConcurrencyCache
)
AcquireUserSlot
(
ctx
context
.
Context
,
userID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
{
return
true
,
nil
}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
Name
:
"route-mixed-select"
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Hydrated
:
true
,
ModelRoutingEnabled
:
true
,
ModelRouting
:
map
[
string
][]
int64
{
requestedModel
:
{
2
},
},
},
},
}
func
(
m
*
mockConcurrencyCache
)
ReleaseUserSlot
(
ctx
context
.
Context
,
userID
int64
,
requestID
string
)
error
{
return
nil
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
groupRepo
:
groupRepo
,
}
func
(
m
*
mockConcurrencyCache
)
GetUserConcurrency
(
ctx
context
.
Context
,
userID
int64
)
(
int
,
error
)
{
return
0
,
nil
}
acc
,
err
:=
svc
.
selectAccountWithMixedScheduling
(
ctx
,
&
groupID
,
""
,
requestedModel
,
nil
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
)
})
func
(
m
*
mockConcurrencyCache
)
IncrementWaitCount
(
ctx
context
.
Context
,
userID
int64
,
maxWait
int
)
(
bool
,
error
)
{
return
true
,
nil
}
t
.
Run
(
"混合调度-路由粘性命中"
,
func
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
31
)
requestedModel
:=
"claude-3-5-sonnet-20241022"
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
Extra
:
map
[
string
]
any
{
"mixed_scheduling"
:
true
},
AccountGroups
:
[]
AccountGroup
{{
GroupID
:
groupID
}}},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
func
(
m
*
mockConcurrencyCache
)
DecrementWaitCount
(
ctx
context
.
Context
,
userID
int64
)
err
or
{
return
nil
}
cache
:=
&
mockGatewayCacheForPlatf
or
m
{
sessionBindings
:
map
[
string
]
int64
{
"session-777"
:
2
},
}
func
(
m
*
mockConcurrencyCache
)
GetAccountsLoadBatch
(
ctx
context
.
Context
,
accounts
[]
AccountWithConcurrency
)
(
map
[
int64
]
*
AccountLoadInfo
,
error
)
{
m
.
loadBatchCalls
++
result
:=
make
(
map
[
int64
]
*
AccountLoadInfo
,
len
(
accounts
))
for
_
,
acc
:=
range
accounts
{
result
[
acc
.
ID
]
=
&
AccountLoadInfo
{
AccountID
:
acc
.
ID
,
CurrentConcurrency
:
0
,
WaitingCount
:
0
,
LoadRate
:
0
,
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
Name
:
"route-mixed-sticky"
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Hydrated
:
true
,
ModelRoutingEnabled
:
true
,
ModelRouting
:
map
[
string
][]
int64
{
requestedModel
:
{
2
},
},
},
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
groupRepo
:
groupRepo
,
}
acc
,
err
:=
svc
.
selectAccountWithMixedScheduling
(
ctx
,
&
groupID
,
"session-777"
,
requestedModel
,
nil
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
)
})
t
.
Run
(
"混合调度-路由账号缺失回退"
,
func
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
32
)
requestedModel
:=
"claude-3-5-sonnet-20241022"
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
Extra
:
map
[
string
]
any
{
"mixed_scheduling"
:
true
}},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
Name
:
"route-mixed-miss"
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Hydrated
:
true
,
ModelRoutingEnabled
:
true
,
ModelRouting
:
map
[
string
][]
int64
{
requestedModel
:
{
99
},
},
},
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
groupRepo
:
groupRepo
,
}
acc
,
err
:=
svc
.
selectAccountWithMixedScheduling
(
ctx
,
&
groupID
,
""
,
requestedModel
,
nil
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
1
),
acc
.
ID
)
})
t
.
Run
(
"混合调度-路由账号未启用mixed_scheduling回退"
,
func
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
33
)
requestedModel
:=
"claude-3-5-sonnet-20241022"
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
// 未启用 mixed_scheduling
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
Name
:
"route-mixed-disabled"
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Hydrated
:
true
,
ModelRoutingEnabled
:
true
,
ModelRouting
:
map
[
string
][]
int64
{
requestedModel
:
{
2
},
},
},
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
groupRepo
:
groupRepo
,
}
acc
,
err
:=
svc
.
selectAccountWithMixedScheduling
(
ctx
,
&
groupID
,
""
,
requestedModel
,
nil
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
1
),
acc
.
ID
)
})
t
.
Run
(
"混合调度-路由过滤覆盖"
,
func
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
35
)
requestedModel
:=
"claude-3-5-sonnet-20241022"
resetAt
:=
time
.
Now
()
.
Add
(
10
*
time
.
Minute
)
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
false
},
{
ID
:
3
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
4
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Extra
:
map
[
string
]
any
{
"model_rate_limits"
:
map
[
string
]
any
{
"claude_sonnet"
:
map
[
string
]
any
{
"rate_limit_reset_at"
:
resetAt
.
Format
(
time
.
RFC3339
),
},
},
},
},
{
ID
:
5
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"claude-3-5-haiku-20241022"
:
"claude-3-5-haiku-20241022"
}},
},
{
ID
:
6
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
7
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
Name
:
"route-mixed-filter"
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Hydrated
:
true
,
ModelRoutingEnabled
:
true
,
ModelRouting
:
map
[
string
][]
int64
{
requestedModel
:
{
1
,
2
,
3
,
4
,
5
,
6
,
7
},
},
},
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
groupRepo
:
groupRepo
,
}
excluded
:=
map
[
int64
]
struct
{}{
1
:
{}}
acc
,
err
:=
svc
.
selectAccountWithMixedScheduling
(
ctx
,
&
groupID
,
""
,
requestedModel
,
excluded
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
7
),
acc
.
ID
)
})
t
.
Run
(
"混合调度-粘性命中分组账号"
,
func
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
34
)
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
AccountGroups
:
[]
AccountGroup
{{
GroupID
:
groupID
}}},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
AccountGroups
:
[]
AccountGroup
{{
GroupID
:
groupID
}}},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
"session-group"
:
1
},
}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Hydrated
:
true
,
},
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
groupRepo
:
groupRepo
,
}
acc
,
err
:=
svc
.
selectAccountWithMixedScheduling
(
ctx
,
&
groupID
,
"session-group"
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
1
),
acc
.
ID
)
})
t
.
Run
(
"混合调度-过滤未启用mixed_scheduling的antigravity账户"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
// 未启用 mixed_scheduling
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccountWithMixedScheduling
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
1
),
acc
.
ID
,
"未启用mixed_scheduling的antigravity账户应被过滤"
)
require
.
Equal
(
t
,
PlatformAnthropic
,
acc
.
Platform
)
})
t
.
Run
(
"混合调度-粘性会话命中启用mixed_scheduling的antigravity账户"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
Extra
:
map
[
string
]
any
{
"mixed_scheduling"
:
true
}},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
"session-123"
:
2
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccountWithMixedScheduling
(
ctx
,
nil
,
"session-123"
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"应返回粘性会话绑定的启用mixed_scheduling的antigravity账户"
)
})
t
.
Run
(
"混合调度-粘性会话命中未启用mixed_scheduling的antigravity账户-降级选择"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformAntigravity
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
// 未启用 mixed_scheduling
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
"session-123"
:
2
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccountWithMixedScheduling
(
ctx
,
nil
,
"session-123"
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
1
),
acc
.
ID
,
"粘性会话绑定的账户未启用mixed_scheduling,应降级选择anthropic账户"
)
})
t
.
Run
(
"混合调度-粘性会话不可调度-清理并回退"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusDisabled
,
Schedulable
:
true
,
Extra
:
map
[
string
]
any
{
"mixed_scheduling"
:
true
}},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
"session-123"
:
1
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccountWithMixedScheduling
(
ctx
,
nil
,
"session-123"
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
)
require
.
Equal
(
t
,
1
,
cache
.
deletedSessions
[
"session-123"
])
require
.
Equal
(
t
,
int64
(
2
),
cache
.
sessionBindings
[
"session-123"
])
})
t
.
Run
(
"混合调度-路由粘性不可调度-清理并回退"
,
func
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
12
)
requestedModel
:=
"claude-3-5-sonnet-20241022"
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusDisabled
,
Schedulable
:
true
,
Extra
:
map
[
string
]
any
{
"mixed_scheduling"
:
true
}},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
"session-123"
:
1
},
}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
Name
:
"route-mixed"
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Hydrated
:
true
,
ModelRoutingEnabled
:
true
,
ModelRouting
:
map
[
string
][]
int64
{
requestedModel
:
{
1
,
2
},
},
},
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
groupRepo
:
groupRepo
,
}
acc
,
err
:=
svc
.
selectAccountWithMixedScheduling
(
ctx
,
&
groupID
,
"session-123"
,
requestedModel
,
nil
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
)
require
.
Equal
(
t
,
1
,
cache
.
deletedSessions
[
"session-123"
])
require
.
Equal
(
t
,
int64
(
2
),
cache
.
sessionBindings
[
"session-123"
])
})
t
.
Run
(
"混合调度-仅有启用mixed_scheduling的antigravity账户"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Extra
:
map
[
string
]
any
{
"mixed_scheduling"
:
true
}},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccountWithMixedScheduling
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
1
),
acc
.
ID
)
require
.
Equal
(
t
,
PlatformAntigravity
,
acc
.
Platform
)
})
t
.
Run
(
"混合调度-无可用账户"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
// 未启用 mixed_scheduling
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccountWithMixedScheduling
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAnthropic
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
acc
)
require
.
Contains
(
t
,
err
.
Error
(),
"no available accounts"
)
})
t
.
Run
(
"混合调度-不支持模型返回错误"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"claude-3-5-haiku-20241022"
:
"claude-3-5-haiku-20241022"
}},
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccountWithMixedScheduling
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAnthropic
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
acc
)
require
.
Contains
(
t
,
err
.
Error
(),
"supporting model"
)
})
t
.
Run
(
"混合调度-优先未使用账号"
,
func
(
t
*
testing
.
T
)
{
lastUsed
:=
time
.
Now
()
.
Add
(
-
2
*
time
.
Hour
)
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
LastUsedAt
:
&
lastUsed
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
testConfig
(),
}
acc
,
err
:=
svc
.
selectAccountWithMixedScheduling
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
PlatformAnthropic
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
)
})
}
// TestAccount_IsMixedSchedulingEnabled 测试混合调度开关检查
func
TestAccount_IsMixedSchedulingEnabled
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
account
Account
expected
bool
}{
{
name
:
"非antigravity平台-返回false"
,
account
:
Account
{
Platform
:
PlatformAnthropic
},
expected
:
false
,
},
{
name
:
"antigravity平台-无extra-返回false"
,
account
:
Account
{
Platform
:
PlatformAntigravity
},
expected
:
false
,
},
{
name
:
"antigravity平台-extra无mixed_scheduling-返回false"
,
account
:
Account
{
Platform
:
PlatformAntigravity
,
Extra
:
map
[
string
]
any
{}},
expected
:
false
,
},
{
name
:
"antigravity平台-mixed_scheduling=false-返回false"
,
account
:
Account
{
Platform
:
PlatformAntigravity
,
Extra
:
map
[
string
]
any
{
"mixed_scheduling"
:
false
}},
expected
:
false
,
},
{
name
:
"antigravity平台-mixed_scheduling=true-返回true"
,
account
:
Account
{
Platform
:
PlatformAntigravity
,
Extra
:
map
[
string
]
any
{
"mixed_scheduling"
:
true
}},
expected
:
true
,
},
{
name
:
"antigravity平台-mixed_scheduling非bool类型-返回false"
,
account
:
Account
{
Platform
:
PlatformAntigravity
,
Extra
:
map
[
string
]
any
{
"mixed_scheduling"
:
"true"
}},
expected
:
false
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
tt
.
account
.
IsMixedSchedulingEnabled
()
require
.
Equal
(
t
,
tt
.
expected
,
got
)
})
}
}
// mockConcurrencyService for testing
type
mockConcurrencyService
struct
{
accountLoads
map
[
int64
]
*
AccountLoadInfo
accountWaitCounts
map
[
int64
]
int
acquireResults
map
[
int64
]
bool
}
func
(
m
*
mockConcurrencyService
)
GetAccountsLoadBatch
(
ctx
context
.
Context
,
accounts
[]
AccountWithConcurrency
)
(
map
[
int64
]
*
AccountLoadInfo
,
error
)
{
if
m
.
accountLoads
==
nil
{
return
map
[
int64
]
*
AccountLoadInfo
{},
nil
}
result
:=
make
(
map
[
int64
]
*
AccountLoadInfo
)
for
_
,
acc
:=
range
accounts
{
if
load
,
ok
:=
m
.
accountLoads
[
acc
.
ID
];
ok
{
result
[
acc
.
ID
]
=
load
}
else
{
result
[
acc
.
ID
]
=
&
AccountLoadInfo
{
AccountID
:
acc
.
ID
,
CurrentConcurrency
:
0
,
WaitingCount
:
0
,
LoadRate
:
0
,
}
}
}
return
result
,
nil
}
func
(
m
*
mockConcurrencyService
)
GetAccountWaitingCount
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
{
if
m
.
accountWaitCounts
==
nil
{
return
0
,
nil
}
return
m
.
accountWaitCounts
[
accountID
],
nil
}
type
mockConcurrencyCache
struct
{
acquireAccountCalls
int
loadBatchCalls
int
acquireResults
map
[
int64
]
bool
loadBatchErr
error
loadMap
map
[
int64
]
*
AccountLoadInfo
waitCounts
map
[
int64
]
int
skipDefaultLoad
bool
}
func
(
m
*
mockConcurrencyCache
)
AcquireAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
{
m
.
acquireAccountCalls
++
if
m
.
acquireResults
!=
nil
{
if
result
,
ok
:=
m
.
acquireResults
[
accountID
];
ok
{
return
result
,
nil
}
}
return
true
,
nil
}
func
(
m
*
mockConcurrencyCache
)
ReleaseAccountSlot
(
ctx
context
.
Context
,
accountID
int64
,
requestID
string
)
error
{
return
nil
}
func
(
m
*
mockConcurrencyCache
)
GetAccountConcurrency
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
{
return
0
,
nil
}
func
(
m
*
mockConcurrencyCache
)
IncrementAccountWaitCount
(
ctx
context
.
Context
,
accountID
int64
,
maxWait
int
)
(
bool
,
error
)
{
return
true
,
nil
}
func
(
m
*
mockConcurrencyCache
)
DecrementAccountWaitCount
(
ctx
context
.
Context
,
accountID
int64
)
error
{
return
nil
}
func
(
m
*
mockConcurrencyCache
)
GetAccountWaitingCount
(
ctx
context
.
Context
,
accountID
int64
)
(
int
,
error
)
{
if
m
.
waitCounts
!=
nil
{
if
count
,
ok
:=
m
.
waitCounts
[
accountID
];
ok
{
return
count
,
nil
}
}
return
0
,
nil
}
func
(
m
*
mockConcurrencyCache
)
AcquireUserSlot
(
ctx
context
.
Context
,
userID
int64
,
maxConcurrency
int
,
requestID
string
)
(
bool
,
error
)
{
return
true
,
nil
}
func
(
m
*
mockConcurrencyCache
)
ReleaseUserSlot
(
ctx
context
.
Context
,
userID
int64
,
requestID
string
)
error
{
return
nil
}
func
(
m
*
mockConcurrencyCache
)
GetUserConcurrency
(
ctx
context
.
Context
,
userID
int64
)
(
int
,
error
)
{
return
0
,
nil
}
func
(
m
*
mockConcurrencyCache
)
IncrementWaitCount
(
ctx
context
.
Context
,
userID
int64
,
maxWait
int
)
(
bool
,
error
)
{
return
true
,
nil
}
func
(
m
*
mockConcurrencyCache
)
DecrementWaitCount
(
ctx
context
.
Context
,
userID
int64
)
error
{
return
nil
}
func
(
m
*
mockConcurrencyCache
)
GetAccountsLoadBatch
(
ctx
context
.
Context
,
accounts
[]
AccountWithConcurrency
)
(
map
[
int64
]
*
AccountLoadInfo
,
error
)
{
m
.
loadBatchCalls
++
if
m
.
loadBatchErr
!=
nil
{
return
nil
,
m
.
loadBatchErr
}
result
:=
make
(
map
[
int64
]
*
AccountLoadInfo
,
len
(
accounts
))
if
m
.
skipDefaultLoad
&&
m
.
loadMap
!=
nil
{
for
_
,
acc
:=
range
accounts
{
if
load
,
ok
:=
m
.
loadMap
[
acc
.
ID
];
ok
{
result
[
acc
.
ID
]
=
load
}
}
return
result
,
nil
}
for
_
,
acc
:=
range
accounts
{
if
m
.
loadMap
!=
nil
{
if
load
,
ok
:=
m
.
loadMap
[
acc
.
ID
];
ok
{
result
[
acc
.
ID
]
=
load
continue
}
}
result
[
acc
.
ID
]
=
&
AccountLoadInfo
{
AccountID
:
acc
.
ID
,
CurrentConcurrency
:
0
,
WaitingCount
:
0
,
LoadRate
:
0
,
}
}
return
result
,
nil
}
func
(
m
*
mockConcurrencyCache
)
CleanupExpiredAccountSlots
(
ctx
context
.
Context
,
accountID
int64
)
error
{
return
nil
}
// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection
func
TestGatewayService_SelectAccountWithLoadAwareness
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
t
.
Run
(
"禁用负载批量查询-降级到传统选择"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
false
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
nil
,
// No concurrency service
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
1
),
result
.
Account
.
ID
,
"应选择优先级最高的账号"
)
})
t
.
Run
(
"模型路由-无ConcurrencyService也生效"
,
func
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
1
)
sessionHash
:=
"sticky"
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
,
AccountGroups
:
[]
AccountGroup
{{
GroupID
:
groupID
}}},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
,
AccountGroups
:
[]
AccountGroup
{{
GroupID
:
groupID
}}},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
sessionHash
:
1
},
}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Hydrated
:
true
,
ModelRoutingEnabled
:
true
,
ModelRouting
:
map
[
string
][]
int64
{
"claude-a"
:
{
1
},
"claude-b"
:
{
2
},
},
},
},
}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
true
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
nil
,
// legacy path
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
&
groupID
,
sessionHash
,
"claude-b"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Account
.
ID
,
"切换到 claude-b 时应按模型路由切换账号"
)
require
.
Equal
(
t
,
int64
(
2
),
cache
.
sessionBindings
[
sessionHash
],
"粘性绑定应更新为路由选择的账号"
)
})
t
.
Run
(
"无ConcurrencyService-降级到传统选择"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
true
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
nil
,
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Account
.
ID
,
"应选择优先级最高的账号"
)
})
t
.
Run
(
"排除账号-不选择被排除的账号"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
false
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
nil
,
}
excludedIDs
:=
map
[
int64
]
struct
{}{
1
:
{}}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
excludedIDs
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Account
.
ID
,
"不应选择被排除的账号"
)
})
t
.
Run
(
"粘性命中-不调用GetByID"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
"sticky"
:
1
},
}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
true
concurrencyCache
:=
&
mockConcurrencyCache
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
"sticky"
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
1
),
result
.
Account
.
ID
)
require
.
Equal
(
t
,
0
,
repo
.
getByIDCalls
,
"粘性命中不应调用GetByID"
)
require
.
Equal
(
t
,
0
,
concurrencyCache
.
loadBatchCalls
,
"粘性命中应在负载批量查询前返回"
)
})
t
.
Run
(
"粘性账号不在候选集-回退负载感知选择"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
"sticky"
:
1
},
}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
true
concurrencyCache
:=
&
mockConcurrencyCache
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
"sticky"
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Account
.
ID
,
"粘性账号不在候选集时应回退到可用账号"
)
require
.
Equal
(
t
,
0
,
repo
.
getByIDCalls
,
"粘性账号缺失不应回退到GetByID"
)
require
.
Equal
(
t
,
1
,
concurrencyCache
.
loadBatchCalls
,
"应继续进行负载批量查询"
)
})
t
.
Run
(
"粘性账号禁用-清理会话并回退选择"
,
func
(
t
*
testing
.
T
)
{
testCtx
:=
context
.
WithValue
(
ctx
,
ctxkey
.
ForcePlatform
,
PlatformAnthropic
)
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
false
,
Concurrency
:
5
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
repo
.
listPlatformFunc
=
func
(
ctx
context
.
Context
,
platform
string
)
([]
Account
,
error
)
{
return
repo
.
accounts
,
nil
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
"sticky"
:
1
},
}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
true
concurrencyCache
:=
&
mockConcurrencyCache
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
testCtx
,
nil
,
"sticky"
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Account
.
ID
,
"粘性账号禁用时应回退到可用账号"
)
updatedID
,
ok
:=
cache
.
sessionBindings
[
"sticky"
]
require
.
True
(
t
,
ok
,
"粘性会话应更新绑定"
)
require
.
Equal
(
t
,
int64
(
2
),
updatedID
,
"粘性会话应绑定到新账号"
)
})
t
.
Run
(
"无可用账号-返回错误"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{},
accountsByID
:
map
[
int64
]
*
Account
{},
}
cache
:=
&
mockGatewayCacheForPlatform
{}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
false
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
nil
,
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
result
)
require
.
Contains
(
t
,
err
.
Error
(),
"no available accounts"
)
})
t
.
Run
(
"过滤不可调度账号-限流账号被跳过"
,
func
(
t
*
testing
.
T
)
{
now
:=
time
.
Now
()
resetAt
:=
now
.
Add
(
10
*
time
.
Minute
)
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
,
RateLimitResetAt
:
&
resetAt
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
false
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
nil
,
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Account
.
ID
,
"应跳过限流账号,选择可用账号"
)
})
t
.
Run
(
"过滤不可调度账号-过载账号被跳过"
,
func
(
t
*
testing
.
T
)
{
now
:=
time
.
Now
()
overloadUntil
:=
now
.
Add
(
10
*
time
.
Minute
)
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
,
OverloadUntil
:
&
overloadUntil
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
false
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
nil
,
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Account
.
ID
,
"应跳过过载账号,选择可用账号"
)
})
t
.
Run
(
"粘性账号槽位满-返回粘性等待计划"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
"sticky"
:
1
},
}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
true
cfg
.
Gateway
.
Scheduling
.
StickySessionMaxWaiting
=
1
concurrencyCache
:=
&
mockConcurrencyCache
{
acquireResults
:
map
[
int64
]
bool
{
1
:
false
},
waitCounts
:
map
[
int64
]
int
{
1
:
0
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
"sticky"
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
WaitPlan
)
require
.
Equal
(
t
,
int64
(
1
),
result
.
Account
.
ID
)
require
.
Equal
(
t
,
0
,
concurrencyCache
.
loadBatchCalls
)
})
t
.
Run
(
"负载批量查询失败-降级旧顺序选择"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
true
concurrencyCache
:=
&
mockConcurrencyCache
{
loadBatchErr
:
errors
.
New
(
"load batch failed"
),
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
"legacy"
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Account
.
ID
)
require
.
Equal
(
t
,
int64
(
2
),
cache
.
sessionBindings
[
"legacy"
])
})
t
.
Run
(
"模型路由-粘性账号等待计划"
,
func
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
20
)
sessionHash
:=
"route-sticky"
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
sessionHash
:
1
},
}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Hydrated
:
true
,
ModelRoutingEnabled
:
true
,
ModelRouting
:
map
[
string
][]
int64
{
"claude-3-5-sonnet-20241022"
:
{
1
,
2
},
},
},
},
}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
true
cfg
.
Gateway
.
Scheduling
.
StickySessionMaxWaiting
=
1
concurrencyCache
:=
&
mockConcurrencyCache
{
acquireResults
:
map
[
int64
]
bool
{
1
:
false
},
waitCounts
:
map
[
int64
]
int
{
1
:
0
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
&
groupID
,
sessionHash
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
WaitPlan
)
require
.
Equal
(
t
,
int64
(
1
),
result
.
Account
.
ID
)
})
t
.
Run
(
"模型路由-粘性账号命中"
,
func
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
20
)
sessionHash
:=
"route-hit"
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
sessionHash
:
1
},
}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Hydrated
:
true
,
ModelRoutingEnabled
:
true
,
ModelRouting
:
map
[
string
][]
int64
{
"claude-3-5-sonnet-20241022"
:
{
1
,
2
},
},
},
},
}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
true
concurrencyCache
:=
&
mockConcurrencyCache
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
&
groupID
,
sessionHash
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
1
),
result
.
Account
.
ID
)
require
.
Equal
(
t
,
0
,
concurrencyCache
.
loadBatchCalls
)
})
t
.
Run
(
"模型路由-粘性账号缺失-清理并回退"
,
func
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
22
)
sessionHash
:=
"route-missing"
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
sessionHash
:
1
},
}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Hydrated
:
true
,
ModelRoutingEnabled
:
true
,
ModelRouting
:
map
[
string
][]
int64
{
"claude-3-5-sonnet-20241022"
:
{
1
,
2
},
},
},
},
}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
true
concurrencyCache
:=
&
mockConcurrencyCache
{}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
return
result
,
nil
}
func
(
m
*
mockConcurrencyCache
)
CleanupExpiredAccountSlots
(
ctx
context
.
Context
,
accountID
int64
)
error
{
return
nil
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
&
groupID
,
sessionHash
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Account
.
ID
)
require
.
Equal
(
t
,
1
,
cache
.
deletedSessions
[
sessionHash
])
require
.
Equal
(
t
,
int64
(
2
),
cache
.
sessionBindings
[
sessionHash
])
})
// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection
func
TestGatewayService_SelectAccountWithLoadAwareness
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
t
.
Run
(
"模型路由-按负载选择账号"
,
func
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
21
)
t
.
Run
(
"禁用负载批量查询-降级到传统选择"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
...
...
@@ -1042,31 +2477,54 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
cache
:=
&
mockGatewayCacheForPlatform
{}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Hydrated
:
true
,
ModelRoutingEnabled
:
true
,
ModelRouting
:
map
[
string
][]
int64
{
"claude-3-5-sonnet-20241022"
:
{
1
,
2
},
},
},
},
}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
false
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
true
concurrencyCache
:=
&
mockConcurrencyCache
{
loadMap
:
map
[
int64
]
*
AccountLoadInfo
{
1
:
{
AccountID
:
1
,
LoadRate
:
80
},
2
:
{
AccountID
:
2
,
LoadRate
:
20
},
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
nil
,
// No c
oncurrency
s
ervice
concurrencyService
:
NewC
oncurrency
S
ervice
(
concurrencyCache
),
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
"
"
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
&
groupID
,
"route
"
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
1
),
result
.
Account
.
ID
,
"应选择优先级最高的账号"
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Account
.
ID
)
require
.
Equal
(
t
,
int64
(
2
),
cache
.
sessionBindings
[
"route"
])
})
t
.
Run
(
"模型路由-无ConcurrencyService也生效"
,
func
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
1
)
sessionHash
:=
"sticky"
t
.
Run
(
"模型路由-路由账号全满返回等待计划"
,
func
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
23
)
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
,
AccountGroups
:
[]
AccountGroup
{{
GroupID
:
groupID
}}
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
,
AccountGroups
:
[]
AccountGroup
{{
GroupID
:
groupID
}}
},
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
...
...
@@ -1074,9 +2532,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
sessionHash
:
1
},
}
cache
:=
&
mockGatewayCacheForPlatform
{}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
...
...
@@ -1087,8 +2543,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
Hydrated
:
true
,
ModelRoutingEnabled
:
true
,
ModelRouting
:
map
[
string
][]
int64
{
"claude-a"
:
{
1
},
"claude-b"
:
{
2
},
"claude-3-5-sonnet-20241022"
:
{
1
,
2
},
},
},
},
...
...
@@ -1097,27 +2552,37 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
true
concurrencyCache
:=
&
mockConcurrencyCache
{
acquireResults
:
map
[
int64
]
bool
{
1
:
false
,
2
:
false
},
loadMap
:
map
[
int64
]
*
AccountLoadInfo
{
1
:
{
AccountID
:
1
,
LoadRate
:
10
},
2
:
{
AccountID
:
2
,
LoadRate
:
20
},
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
nil
,
// legacy path
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
&
groupID
,
sessionHash
,
"claude-b
"
,
nil
,
""
)
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
&
groupID
,
"route-full"
,
"claude-3-5-sonnet-20241022
"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Account
.
ID
,
"切换到 claude-b 时应按模型路由切换账号"
)
require
.
Equal
(
t
,
int64
(
2
),
cache
.
sessionBindings
[
sessionHash
],
"粘性绑定应更新为路由选择的账号"
)
require
.
NotNil
(
t
,
result
.
WaitPlan
)
require
.
Equal
(
t
,
int64
(
1
),
result
.
Account
.
ID
)
})
t
.
Run
(
"无ConcurrencyService-降级到传统选择"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"模型路由-路由账号全满-回退普通选择"
,
func
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
22
)
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
{
ID
:
3
,
Platform
:
PlatformAnthropic
,
Priority
:
0
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
...
...
@@ -1127,24 +2592,49 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
cache
:=
&
mockGatewayCacheForPlatform
{}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Hydrated
:
true
,
ModelRoutingEnabled
:
true
,
ModelRouting
:
map
[
string
][]
int64
{
"claude-3-5-sonnet-20241022"
:
{
1
,
2
},
},
},
},
}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
true
concurrencyCache
:=
&
mockConcurrencyCache
{
loadMap
:
map
[
int64
]
*
AccountLoadInfo
{
1
:
{
AccountID
:
1
,
LoadRate
:
100
},
2
:
{
AccountID
:
2
,
LoadRate
:
100
},
3
:
{
AccountID
:
3
,
LoadRate
:
0
},
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
nil
,
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
)
,
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
"
"
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
&
groupID
,
"fallback
"
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Account
.
ID
,
"应选择优先级最高的账号"
)
require
.
Equal
(
t
,
int64
(
3
),
result
.
Account
.
ID
)
require
.
Equal
(
t
,
int64
(
3
),
cache
.
sessionBindings
[
"fallback"
])
})
t
.
Run
(
"
排除账号-不选择被排除的账号
"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"
负载批量失败且无法获取-兜底等待
"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
...
...
@@ -1159,27 +2649,34 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
cache
:=
&
mockGatewayCacheForPlatform
{}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
false
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
true
concurrencyCache
:=
&
mockConcurrencyCache
{
loadBatchErr
:
errors
.
New
(
"load batch failed"
),
acquireResults
:
map
[
int64
]
bool
{
1
:
false
,
2
:
false
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
nil
,
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
)
,
}
excludedIDs
:=
map
[
int64
]
struct
{}{
1
:
{}}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
excludedIDs
,
""
)
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Account
.
ID
,
"不应选择被排除的账号"
)
require
.
NotNil
(
t
,
result
.
WaitPlan
)
require
.
Equal
(
t
,
int64
(
1
),
result
.
Account
.
ID
)
})
t
.
Run
(
"粘性命中-不调用GetByID"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"Gemini负载排序-优先OAuth"
,
func
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
24
)
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
,
Type
:
AccountTypeAPIKey
},
{
ID
:
2
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
,
Type
:
AccountTypeOAuth
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
...
...
@@ -1187,35 +2684,77 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
"sticky"
:
1
},
cache
:=
&
mockGatewayCacheForPlatform
{}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
Platform
:
PlatformGemini
,
Status
:
StatusActive
,
Hydrated
:
true
,
},
},
}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
true
concurrencyCache
:=
&
mockConcurrencyCache
{}
concurrencyCache
:=
&
mockConcurrencyCache
{
loadMap
:
map
[
int64
]
*
AccountLoadInfo
{
1
:
{
AccountID
:
1
,
LoadRate
:
10
},
2
:
{
AccountID
:
2
,
LoadRate
:
10
},
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
"sticky"
,
"claude-3-5-sonnet-20241022
"
,
nil
,
""
)
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
&
groupID
,
"gemini"
,
"gemini-2.5-pro
"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
1
),
result
.
Account
.
ID
)
require
.
Equal
(
t
,
0
,
repo
.
getByIDCalls
,
"粘性命中不应调用GetByID"
)
require
.
Equal
(
t
,
0
,
concurrencyCache
.
loadBatchCalls
,
"粘性命中应在负载批量查询前返回"
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Account
.
ID
)
})
t
.
Run
(
"粘性账号不在候选集-回退负载感知选择"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"模型路由-过滤路径覆盖"
,
func
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
70
)
now
:=
time
.
Now
()
.
Add
(
10
*
time
.
Minute
)
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
{
ID
:
3
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
false
,
Concurrency
:
5
},
{
ID
:
4
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
{
ID
:
5
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
,
Extra
:
map
[
string
]
any
{
"model_rate_limits"
:
map
[
string
]
any
{
"claude_sonnet"
:
map
[
string
]
any
{
"rate_limit_reset_at"
:
now
.
Format
(
time
.
RFC3339
),
},
},
},
},
{
ID
:
6
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"claude-3-5-haiku-20241022"
:
"claude-3-5-haiku-20241022"
}},
},
{
ID
:
7
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
...
...
@@ -1223,8 +2762,21 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
"sticky"
:
1
},
cache
:=
&
mockGatewayCacheForPlatform
{}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Hydrated
:
true
,
ModelRoutingEnabled
:
true
,
ModelRouting
:
map
[
string
][]
int64
{
"claude-3-5-sonnet-20241022"
:
{
1
,
2
,
3
,
4
,
5
,
6
},
},
},
},
}
cfg
:=
testConfig
()
...
...
@@ -1234,51 +2786,110 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
"sticky"
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
excluded
:=
map
[
int64
]
struct
{}{
1
:
{}}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
&
groupID
,
""
,
"claude-3-5-sonnet-20241022"
,
excluded
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Account
.
ID
,
"粘性账号不在候选集时应回退到可用账号"
)
require
.
Equal
(
t
,
0
,
repo
.
getByIDCalls
,
"粘性账号缺失不应回退到GetByID"
)
require
.
Equal
(
t
,
1
,
concurrencyCache
.
loadBatchCalls
,
"应继续进行负载批量查询"
)
require
.
Equal
(
t
,
int64
(
7
),
result
.
Account
.
ID
)
})
t
.
Run
(
"无可用账号-返回错误"
,
func
(
t
*
testing
.
T
)
{
t
.
Run
(
"ClaudeCode限制-回退分组"
,
func
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
60
)
fallbackID
:=
int64
(
61
)
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{},
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Hydrated
:
true
,
ClaudeCodeOnly
:
true
,
FallbackGroupID
:
func
()
*
int64
{
v
:=
fallbackID
return
&
v
}(),
},
fallbackID
:
{
ID
:
fallbackID
,
Platform
:
PlatformGemini
,
Status
:
StatusActive
,
Hydrated
:
true
,
},
},
}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
false
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
groupRepo
:
groupRepo
,
cache
:
&
mockGatewayCacheForPlatform
{},
cfg
:
cfg
,
concurrencyService
:
nil
,
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
&
groupID
,
""
,
"gemini-2.5-pro"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
1
),
result
.
Account
.
ID
)
})
t
.
Run
(
"ClaudeCode限制-无降级返回错误"
,
func
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
62
)
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Hydrated
:
true
,
ClaudeCodeOnly
:
true
,
},
},
}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
false
svc
:=
&
GatewayService
{
accountRepo
:
&
mockAccountRepoForPlatform
{},
groupRepo
:
groupRepo
,
cache
:
&
mockGatewayCacheForPlatform
{},
cfg
:
cfg
,
concurrencyService
:
nil
,
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
&
groupID
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
result
)
require
.
Contain
s
(
t
,
err
.
Err
or
(),
"no available accounts"
)
require
.
ErrorI
s
(
t
,
err
,
Err
ClaudeCodeOnly
)
})
t
.
Run
(
"过滤不可调度账号-限流账号被跳过"
,
func
(
t
*
testing
.
T
)
{
now
:=
time
.
Now
()
resetAt
:=
now
.
Add
(
10
*
time
.
Minute
)
t
.
Run
(
"负载可用但无法获取槽位-兜底等待"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
,
RateLimitResetAt
:
&
resetAt
},
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
...
...
@@ -1288,31 +2899,37 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
}
cache
:=
&
mockGatewayCacheForPlatform
{}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
false
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
true
concurrencyCache
:=
&
mockConcurrencyCache
{
acquireResults
:
map
[
int64
]
bool
{
1
:
false
,
2
:
false
},
loadMap
:
map
[
int64
]
*
AccountLoadInfo
{
1
:
{
AccountID
:
1
,
LoadRate
:
10
},
2
:
{
AccountID
:
2
,
LoadRate
:
20
},
},
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
nil
,
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
)
,
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
"
wait
"
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Account
.
ID
,
"应跳过限流账号,选择可用账号"
)
require
.
NotNil
(
t
,
result
.
WaitPlan
)
require
.
Equal
(
t
,
int64
(
1
),
result
.
Account
.
ID
)
})
t
.
Run
(
"过滤不可调度账号-过载账号被跳过"
,
func
(
t
*
testing
.
T
)
{
now
:=
time
.
Now
()
overloadUntil
:=
now
.
Add
(
10
*
time
.
Minute
)
t
.
Run
(
"负载信息缺失-使用默认负载"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
,
OverloadUntil
:
&
overloadUntil
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
...
...
@@ -1321,21 +2938,29 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
}
cache
:=
&
mockGatewayCacheForPlatform
{}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
false
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
true
concurrencyCache
:=
&
mockConcurrencyCache
{
loadMap
:
map
[
int64
]
*
AccountLoadInfo
{
1
:
{
AccountID
:
1
,
LoadRate
:
50
},
},
skipDefaultLoad
:
true
,
}
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
nil
,
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
)
,
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
"
missing-load
"
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Account
.
ID
,
"应跳过过载账号,选择可用账号"
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Account
.
ID
)
})
}
...
...
backend/internal/service/gateway_oauth_metadata_test.go
0 → 100644
View file @
2fe8932c
package
service
import
(
"regexp"
"testing"
"github.com/stretchr/testify/require"
)
func
TestBuildOAuthMetadataUserID_FallbackWithoutAccountUUID
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
parsed
:=
&
ParsedRequest
{
Model
:
"claude-sonnet-4-5"
,
Stream
:
true
,
MetadataUserID
:
""
,
System
:
nil
,
Messages
:
nil
,
}
account
:=
&
Account
{
ID
:
123
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{},
// intentionally missing account_uuid / claude_user_id
}
fp
:=
&
Fingerprint
{
ClientID
:
"deadbeef"
}
// should be used as user id in legacy format
got
:=
svc
.
buildOAuthMetadataUserID
(
parsed
,
account
,
fp
)
require
.
NotEmpty
(
t
,
got
)
// Legacy format: user_{client}_account__session_{uuid}
re
:=
regexp
.
MustCompile
(
`^user_[a-zA-Z0-9]+_account__session_[a-f0-9-]{36}$`
)
require
.
True
(
t
,
re
.
MatchString
(
got
),
"unexpected user_id format: %s"
,
got
)
}
func
TestBuildOAuthMetadataUserID_UsesAccountUUIDWhenPresent
(
t
*
testing
.
T
)
{
svc
:=
&
GatewayService
{}
parsed
:=
&
ParsedRequest
{
Model
:
"claude-sonnet-4-5"
,
Stream
:
true
,
MetadataUserID
:
""
,
}
account
:=
&
Account
{
ID
:
123
,
Type
:
AccountTypeOAuth
,
Extra
:
map
[
string
]
any
{
"account_uuid"
:
"acc-uuid"
,
"claude_user_id"
:
"clientid123"
,
"anthropic_user_id"
:
""
,
},
}
got
:=
svc
.
buildOAuthMetadataUserID
(
parsed
,
account
,
nil
)
require
.
NotEmpty
(
t
,
got
)
// New format: user_{client}_account_{account_uuid}_session_{uuid}
re
:=
regexp
.
MustCompile
(
`^user_clientid123_account_acc-uuid_session_[a-f0-9-]{36}$`
)
require
.
True
(
t
,
re
.
MatchString
(
got
),
"unexpected user_id format: %s"
,
got
)
}
backend/internal/service/gateway_prompt_test.go
View file @
2fe8932c
...
...
@@ -2,6 +2,7 @@ package service
import
(
"encoding/json"
"strings"
"testing"
"github.com/stretchr/testify/require"
...
...
@@ -134,6 +135,8 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) {
}
func
TestInjectClaudeCodePrompt
(
t
*
testing
.
T
)
{
claudePrefix
:=
strings
.
TrimSpace
(
claudeCodeSystemPrompt
)
tests
:=
[]
struct
{
name
string
body
string
...
...
@@ -162,7 +165,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
system
:
"Custom prompt"
,
wantSystemLen
:
2
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantSecondText
:
"
Custom prompt"
,
wantSecondText
:
claudePrefix
+
"
\n\n
Custom prompt"
,
},
{
name
:
"string system equals Claude Code prompt"
,
...
...
@@ -178,7 +181,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
// Claude Code + Custom = 2
wantSystemLen
:
2
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantSecondText
:
"
Custom"
,
wantSecondText
:
claudePrefix
+
"
\n\n
Custom"
,
},
{
name
:
"array system with existing Claude Code prompt (should dedupe)"
,
...
...
@@ -190,7 +193,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
// Claude Code at start + Other = 2 (deduped)
wantSystemLen
:
2
,
wantFirstText
:
claudeCodeSystemPrompt
,
wantSecondText
:
"
Other"
,
wantSecondText
:
claudePrefix
+
"
\n\n
Other"
,
},
{
name
:
"empty array"
,
...
...
backend/internal/service/gateway_sanitize_test.go
0 → 100644
View file @
2fe8932c
package
service
import
(
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func
TestSanitizeOpenCodeText_RewritesCanonicalSentence
(
t
*
testing
.
T
)
{
in
:=
"You are OpenCode, the best coding agent on the planet."
got
:=
sanitizeSystemText
(
in
)
require
.
Equal
(
t
,
strings
.
TrimSpace
(
claudeCodeSystemPrompt
),
got
)
}
func
TestSanitizeToolDescription_DoesNotRewriteKeywords
(
t
*
testing
.
T
)
{
in
:=
"OpenCode and opencode are mentioned."
got
:=
sanitizeToolDescription
(
in
)
// We no longer rewrite tool descriptions; only redact obvious path leaks.
require
.
Equal
(
t
,
in
,
got
)
}
backend/internal/service/gateway_service.go
View file @
2fe8932c
...
...
@@ -11,6 +11,8 @@ import (
"fmt"
"io"
"log"
"log/slog"
mathrand
"math/rand"
"net/http"
"os"
"regexp"
...
...
@@ -37,15 +39,27 @@ const (
claudeAPICountTokensURL
=
"https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL
=
time
.
Hour
// 粘性会话TTL
defaultMaxLineSize
=
40
*
1024
*
1024
// Canonical Claude Code banner. Keep it EXACT (no trailing whitespace/newlines)
// to match real Claude CLI traffic as closely as possible. When we need a visual
// separator between system blocks, we add "\n\n" at concatenation time.
claudeCodeSystemPrompt
=
"You are Claude Code, Anthropic's official CLI for Claude."
maxCacheControlBlocks
=
4
// Anthropic API 允许的最大 cache_control 块数量
)
const
(
claudeMimicDebugInfoKey
=
"claude_mimic_debug_info"
)
func
(
s
*
GatewayService
)
debugModelRoutingEnabled
()
bool
{
v
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
os
.
Getenv
(
"SUB2API_DEBUG_MODEL_ROUTING"
)))
return
v
==
"1"
||
v
==
"true"
||
v
==
"yes"
||
v
==
"on"
}
func
(
s
*
GatewayService
)
debugClaudeMimicEnabled
()
bool
{
v
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
os
.
Getenv
(
"SUB2API_DEBUG_CLAUDE_MIMIC"
)))
return
v
==
"1"
||
v
==
"true"
||
v
==
"yes"
||
v
==
"on"
}
func
shortSessionHash
(
sessionHash
string
)
string
{
if
sessionHash
==
""
{
return
""
...
...
@@ -56,6 +70,138 @@ func shortSessionHash(sessionHash string) string {
return
sessionHash
[
:
8
]
}
func
redactAuthHeaderValue
(
v
string
)
string
{
v
=
strings
.
TrimSpace
(
v
)
if
v
==
""
{
return
""
}
// Keep scheme for debugging, redact secret.
if
strings
.
HasPrefix
(
strings
.
ToLower
(
v
),
"bearer "
)
{
return
"Bearer [redacted]"
}
return
"[redacted]"
}
func
safeHeaderValueForLog
(
key
string
,
v
string
)
string
{
key
=
strings
.
ToLower
(
strings
.
TrimSpace
(
key
))
switch
key
{
case
"authorization"
,
"x-api-key"
:
return
redactAuthHeaderValue
(
v
)
default
:
return
strings
.
TrimSpace
(
v
)
}
}
func
extractSystemPreviewFromBody
(
body
[]
byte
)
string
{
if
len
(
body
)
==
0
{
return
""
}
sys
:=
gjson
.
GetBytes
(
body
,
"system"
)
if
!
sys
.
Exists
()
{
return
""
}
switch
{
case
sys
.
IsArray
()
:
for
_
,
item
:=
range
sys
.
Array
()
{
if
!
item
.
IsObject
()
{
continue
}
if
strings
.
EqualFold
(
item
.
Get
(
"type"
)
.
String
(),
"text"
)
{
if
t
:=
item
.
Get
(
"text"
)
.
String
();
strings
.
TrimSpace
(
t
)
!=
""
{
return
t
}
}
}
return
""
case
sys
.
Type
==
gjson
.
String
:
return
sys
.
String
()
default
:
return
""
}
}
func
buildClaudeMimicDebugLine
(
req
*
http
.
Request
,
body
[]
byte
,
account
*
Account
,
tokenType
string
,
mimicClaudeCode
bool
)
string
{
if
req
==
nil
{
return
""
}
// Only log a minimal fingerprint to avoid leaking user content.
interesting
:=
[]
string
{
"user-agent"
,
"x-app"
,
"anthropic-dangerous-direct-browser-access"
,
"anthropic-version"
,
"anthropic-beta"
,
"x-stainless-lang"
,
"x-stainless-package-version"
,
"x-stainless-os"
,
"x-stainless-arch"
,
"x-stainless-runtime"
,
"x-stainless-runtime-version"
,
"x-stainless-retry-count"
,
"x-stainless-timeout"
,
"authorization"
,
"x-api-key"
,
"content-type"
,
"accept"
,
"x-stainless-helper-method"
,
}
h
:=
make
([]
string
,
0
,
len
(
interesting
))
for
_
,
k
:=
range
interesting
{
if
v
:=
req
.
Header
.
Get
(
k
);
v
!=
""
{
h
=
append
(
h
,
fmt
.
Sprintf
(
"%s=%q"
,
k
,
safeHeaderValueForLog
(
k
,
v
)))
}
}
metaUserID
:=
strings
.
TrimSpace
(
gjson
.
GetBytes
(
body
,
"metadata.user_id"
)
.
String
())
sysPreview
:=
strings
.
TrimSpace
(
extractSystemPreviewFromBody
(
body
))
// Truncate preview to keep logs sane.
if
len
(
sysPreview
)
>
300
{
sysPreview
=
sysPreview
[
:
300
]
+
"..."
}
sysPreview
=
strings
.
ReplaceAll
(
sysPreview
,
"
\n
"
,
"
\\
n"
)
sysPreview
=
strings
.
ReplaceAll
(
sysPreview
,
"
\r
"
,
"
\\
r"
)
aid
:=
int64
(
0
)
aname
:=
""
if
account
!=
nil
{
aid
=
account
.
ID
aname
=
account
.
Name
}
return
fmt
.
Sprintf
(
"url=%s account=%d(%s) tokenType=%s mimic=%t meta.user_id=%q system.preview=%q headers={%s}"
,
req
.
URL
.
String
(),
aid
,
aname
,
tokenType
,
mimicClaudeCode
,
metaUserID
,
sysPreview
,
strings
.
Join
(
h
,
" "
),
)
}
func
logClaudeMimicDebug
(
req
*
http
.
Request
,
body
[]
byte
,
account
*
Account
,
tokenType
string
,
mimicClaudeCode
bool
)
{
line
:=
buildClaudeMimicDebugLine
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
)
if
line
==
""
{
return
}
log
.
Printf
(
"[ClaudeMimicDebug] %s"
,
line
)
}
func
isClaudeCodeCredentialScopeError
(
msg
string
)
bool
{
m
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
msg
))
if
m
==
""
{
return
false
}
return
strings
.
Contains
(
m
,
"only authorized for use with claude code"
)
&&
strings
.
Contains
(
m
,
"cannot be used for other api requests"
)
}
// sseDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var
(
...
...
@@ -69,7 +215,6 @@ var (
modelFieldRe
=
regexp
.
MustCompile
(
`"model"\s*:\s*"([^"]+)"`
)
toolDescAbsPathRe
=
regexp
.
MustCompile
(
`/\/?(?:home|Users|tmp|var|opt|usr|etc)\/[^\s,\)"'\]]+`
)
toolDescWinPathRe
=
regexp
.
MustCompile
(
`(?i)[A-Z]:\\[^\s,\)"'\]]+`
)
opencodeTextRe
=
regexp
.
MustCompile
(
`(?i)opencode`
)
claudeToolNameOverrides
=
map
[
string
]
string
{
"bash"
:
"Bash"
,
...
...
@@ -134,11 +279,24 @@ var allowedHeaders = map[string]bool{
"content-type"
:
true
,
}
// GatewayCache defines cache operations for gateway service
// GatewayCache 定义网关服务的缓存操作接口。
// 提供粘性会话(Sticky Session)的存储、查询、刷新和删除功能。
//
// GatewayCache defines cache operations for gateway service.
// Provides sticky session storage, retrieval, refresh and deletion capabilities.
type
GatewayCache
interface
{
// GetSessionAccountID 获取粘性会话绑定的账号 ID
// Get the account ID bound to a sticky session
GetSessionAccountID
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
)
(
int64
,
error
)
// SetSessionAccountID 设置粘性会话与账号的绑定关系
// Set the binding between sticky session and account
SetSessionAccountID
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
,
accountID
int64
,
ttl
time
.
Duration
)
error
// RefreshSessionTTL 刷新粘性会话的过期时间
// Refresh the expiration time of a sticky session
RefreshSessionTTL
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
,
ttl
time
.
Duration
)
error
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
// Delete sticky session binding, used to proactively clean up when account becomes unavailable
DeleteSessionAccountID
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
)
error
}
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
...
...
@@ -149,6 +307,28 @@ func derefGroupID(groupID *int64) int64 {
return
*
groupID
}
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
// 当账号状态为错误、禁用、不可调度,或处于临时不可调度期间时,返回 true。
// 这确保后续请求不会继续使用不可用的账号。
//
// shouldClearStickySession checks if an account is in an unschedulable state
// and the sticky session binding should be cleared.
// Returns true when account status is error/disabled, schedulable is false,
// or within temporary unschedulable period.
// This ensures subsequent requests won't continue using unavailable accounts.
func
shouldClearStickySession
(
account
*
Account
)
bool
{
if
account
==
nil
{
return
false
}
if
account
.
Status
==
StatusError
||
account
.
Status
==
StatusDisabled
||
!
account
.
Schedulable
{
return
true
}
if
account
.
TempUnschedulableUntil
!=
nil
&&
time
.
Now
()
.
Before
(
*
account
.
TempUnschedulableUntil
)
{
return
true
}
return
false
}
type
AccountWaitPlan
struct
{
AccountID
int64
MaxConcurrency
int
...
...
@@ -305,6 +485,19 @@ func (s *GatewayService) BindStickySession(ctx context.Context, groupID *int64,
return
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
accountID
,
stickySessionTTL
)
}
// GetCachedSessionAccountID retrieves the account ID bound to a sticky session.
// Returns 0 if no binding exists or on error.
func
(
s
*
GatewayService
)
GetCachedSessionAccountID
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
)
(
int64
,
error
)
{
if
sessionHash
==
""
||
s
.
cache
==
nil
{
return
0
,
nil
}
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
if
err
!=
nil
{
return
0
,
err
}
return
accountID
,
nil
}
func
(
s
*
GatewayService
)
extractCacheableContent
(
parsed
*
ParsedRequest
)
string
{
if
parsed
==
nil
{
return
""
...
...
@@ -503,12 +696,21 @@ func normalizeParamNameForOpenCode(name string, cache map[string]string) string
return
name
}
func
sanitizeOpenCodeText
(
text
string
)
string
{
// sanitizeSystemText rewrites only the fixed OpenCode identity sentence (if present).
// We intentionally avoid broad keyword replacement in system prompts to prevent
// accidentally changing user-provided instructions.
func
sanitizeSystemText
(
text
string
)
string
{
if
text
==
""
{
return
text
}
text
=
strings
.
ReplaceAll
(
text
,
"OpenCode"
,
"Claude Code"
)
text
=
opencodeTextRe
.
ReplaceAllString
(
text
,
"Claude"
)
// Some clients include a fixed OpenCode identity sentence. Anthropic may treat
// this as a non-Claude-Code fingerprint, so rewrite it to the canonical
// Claude Code banner before generic "OpenCode"/"opencode" replacements.
text
=
strings
.
ReplaceAll
(
text
,
"You are OpenCode, the best coding agent on the planet."
,
strings
.
TrimSpace
(
claudeCodeSystemPrompt
),
)
return
text
}
...
...
@@ -518,7 +720,9 @@ func sanitizeToolDescription(description string) string {
}
description
=
toolDescAbsPathRe
.
ReplaceAllString
(
description
,
"[path]"
)
description
=
toolDescWinPathRe
.
ReplaceAllString
(
description
,
"[path]"
)
return
sanitizeOpenCodeText
(
description
)
// Intentionally do NOT rewrite tool descriptions (OpenCode/Claude strings).
// Tool names/skill names may rely on exact wording, and rewriting can be misleading.
return
description
}
func
normalizeToolInputSchema
(
inputSchema
any
,
cache
map
[
string
]
string
)
{
...
...
@@ -593,7 +797,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
if
system
,
ok
:=
req
[
"system"
];
ok
{
switch
v
:=
system
.
(
type
)
{
case
string
:
sanitized
:=
sanitize
OpenCode
Text
(
v
)
sanitized
:=
sanitize
System
Text
(
v
)
if
sanitized
!=
v
{
req
[
"system"
]
=
sanitized
}
...
...
@@ -610,7 +814,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
if
!
ok
||
text
==
""
{
continue
}
sanitized
:=
sanitize
OpenCode
Text
(
text
)
sanitized
:=
sanitize
System
Text
(
text
)
if
sanitized
!=
text
{
block
[
"text"
]
=
sanitized
}
...
...
@@ -743,17 +947,15 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account
if
parsed
.
MetadataUserID
!=
""
{
return
""
}
accountUUID
:=
account
.
GetExtraString
(
"account_uuid"
)
if
accountUUID
==
""
{
return
""
}
userID
:=
strings
.
TrimSpace
(
account
.
GetClaudeUserID
())
if
userID
==
""
&&
fp
!=
nil
{
userID
=
fp
.
ClientID
}
if
userID
==
""
{
return
""
// Fall back to a random, well-formed client id so we can still satisfy
// Claude Code OAuth requirements when account metadata is incomplete.
userID
=
generateClientID
()
}
sessionHash
:=
s
.
GenerateSessionHash
(
parsed
)
...
...
@@ -762,7 +964,14 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account
seed
:=
fmt
.
Sprintf
(
"%d::%s"
,
account
.
ID
,
sessionHash
)
sessionID
=
generateSessionUUID
(
seed
)
}
// Prefer the newer format that includes account_uuid (if present),
// otherwise fall back to the legacy Claude Code format.
accountUUID
:=
strings
.
TrimSpace
(
account
.
GetExtraString
(
"account_uuid"
))
if
accountUUID
!=
""
{
return
fmt
.
Sprintf
(
"user_%s_account_%s_session_%s"
,
userID
,
accountUUID
,
sessionID
)
}
return
fmt
.
Sprintf
(
"user_%s_account__session_%s"
,
userID
,
sessionID
)
}
func
generateSessionUUID
(
seed
string
)
string
{
...
...
@@ -819,11 +1028,20 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
}
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
// metadataUserID:
原始 metadata.user_id 字段(用于提取会话 UUID 进行会话数量限制)
// metadataUserID:
已废弃参数,会话限制现在统一使用 sessionHash
func
(
s
*
GatewayService
)
SelectAccountWithLoadAwareness
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{},
metadataUserID
string
)
(
*
AccountSelectionResult
,
error
)
{
// 调试日志:记录调度入口参数
excludedIDsList
:=
make
([]
int64
,
0
,
len
(
excludedIDs
))
for
id
:=
range
excludedIDs
{
excludedIDsList
=
append
(
excludedIDsList
,
id
)
}
slog
.
Debug
(
"account_scheduling_starting"
,
"group_id"
,
derefGroupID
(
groupID
),
"model"
,
requestedModel
,
"session"
,
shortSessionHash
(
sessionHash
),
"excluded_ids"
,
excludedIDsList
)
cfg
:=
s
.
schedulingConfig
()
// 提取会话 UUID(用于会话数量限制)
sessionUUID
:=
extractSessionUUID
(
metadataUserID
)
var
stickyAccountID
int64
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
...
...
@@ -849,18 +1067,39 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
if
s
.
concurrencyService
==
nil
||
!
cfg
.
LoadBatchEnabled
{
account
,
err
:=
s
.
SelectAccountForModelWithExclusions
(
ctx
,
groupID
,
sessionHash
,
requestedModel
,
excludedIDs
)
// 复制排除列表,用于会话限制拒绝时的重试
localExcluded
:=
make
(
map
[
int64
]
struct
{})
for
k
,
v
:=
range
excludedIDs
{
localExcluded
[
k
]
=
v
}
for
{
account
,
err
:=
s
.
SelectAccountForModelWithExclusions
(
ctx
,
groupID
,
sessionHash
,
requestedModel
,
localExcluded
)
if
err
!=
nil
{
return
nil
,
err
}
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
account
.
ID
,
account
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
// 获取槽位后检查会话限制(使用 sessionHash 作为会话标识符)
if
!
s
.
checkAndRegisterSession
(
ctx
,
account
,
sessionHash
)
{
result
.
ReleaseFunc
()
// 释放槽位
localExcluded
[
account
.
ID
]
=
struct
{}{}
// 排除此账号
continue
// 重新选择
}
return
&
AccountSelectionResult
{
Account
:
account
,
Acquired
:
true
,
ReleaseFunc
:
result
.
ReleaseFunc
,
},
nil
}
// 对于等待计划的情况,也需要先检查会话限制
if
!
s
.
checkAndRegisterSession
(
ctx
,
account
,
sessionHash
)
{
localExcluded
[
account
.
ID
]
=
struct
{}{}
continue
}
if
stickyAccountID
>
0
&&
stickyAccountID
==
account
.
ID
&&
s
.
concurrencyService
!=
nil
{
waitingCount
,
_
:=
s
.
concurrencyService
.
GetAccountWaitingCount
(
ctx
,
account
.
ID
)
if
waitingCount
<
cfg
.
StickySessionMaxWaiting
{
...
...
@@ -885,6 +1124,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
},
},
nil
}
}
platform
,
hasForcePlatform
,
err
:=
s
.
resolvePlatform
(
ctx
,
groupID
,
group
)
if
err
!=
nil
{
...
...
@@ -999,7 +1239,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
stickyAccountID
,
stickyAccount
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
// 会话数量限制检查
if
!
s
.
checkAndRegisterSession
(
ctx
,
stickyAccount
,
session
UUID
)
{
if
!
s
.
checkAndRegisterSession
(
ctx
,
stickyAccount
,
session
Hash
)
{
result
.
ReleaseFunc
()
// 释放槽位
// 继续到负载感知选择
}
else
{
...
...
@@ -1017,6 +1257,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
waitingCount
,
_
:=
s
.
concurrencyService
.
GetAccountWaitingCount
(
ctx
,
stickyAccountID
)
if
waitingCount
<
cfg
.
StickySessionMaxWaiting
{
// 会话数量限制检查(等待计划也需要占用会话配额)
if
!
s
.
checkAndRegisterSession
(
ctx
,
stickyAccount
,
sessionHash
)
{
// 会话限制已满,继续到负载感知选择
}
else
{
return
&
AccountSelectionResult
{
Account
:
stickyAccount
,
WaitPlan
:
&
AccountWaitPlan
{
...
...
@@ -1027,8 +1271,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
},
},
nil
}
}
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
}
}
else
{
_
=
s
.
cache
.
DeleteSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
}
}
}
...
...
@@ -1086,7 +1333,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
item
.
account
.
ID
,
item
.
account
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
// 会话数量限制检查
if
!
s
.
checkAndRegisterSession
(
ctx
,
item
.
account
,
session
UUID
)
{
if
!
s
.
checkAndRegisterSession
(
ctx
,
item
.
account
,
session
Hash
)
{
result
.
ReleaseFunc
()
// 释放槽位,继续尝试下一个账号
continue
}
...
...
@@ -1104,21 +1351,27 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
}
// 5. 所有路由账号槽位满,返回等待计划(选择负载最低的)
acc
:=
routingAvailable
[
0
]
.
account
// 5. 所有路由账号槽位满,尝试返回等待计划(选择负载最低的)
// 遍历找到第一个满足会话限制的账号
for
_
,
item
:=
range
routingAvailable
{
if
!
s
.
checkAndRegisterSession
(
ctx
,
item
.
account
,
sessionHash
)
{
continue
// 会话限制已满,尝试下一个
}
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
acc
.
ID
)
log
.
Printf
(
"[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
item
.
account
.
ID
)
}
return
&
AccountSelectionResult
{
Account
:
acc
,
Account
:
item
.
account
,
WaitPlan
:
&
AccountWaitPlan
{
AccountID
:
acc
.
ID
,
MaxConcurrency
:
acc
.
Concurrency
,
AccountID
:
item
.
account
.
ID
,
MaxConcurrency
:
item
.
account
.
Concurrency
,
Timeout
:
cfg
.
StickySessionWaitTimeout
,
MaxWaiting
:
cfg
.
StickySessionMaxWaiting
,
},
},
nil
}
// 所有路由账号会话限制都已满,继续到 Layer 2 回退
}
// 路由列表中的账号都不可用(负载率 >= 100),继续到 Layer 2 回退
log
.
Printf
(
"[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection"
,
requestedModel
)
}
...
...
@@ -1129,7 +1382,14 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
if
err
==
nil
&&
accountID
>
0
&&
!
isExcluded
(
accountID
)
{
account
,
ok
:=
accountByID
[
accountID
]
if
ok
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
if
ok
{
// 检查账户是否需要清理粘性会话绑定
// Check if the account needs sticky session cleanup
clearSticky
:=
shouldClearStickySession
(
account
)
if
clearSticky
{
_
=
s
.
cache
.
DeleteSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
}
if
!
clearSticky
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
s
.
isAccountAllowedForPlatform
(
account
,
platform
,
useMixed
)
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
&&
...
...
@@ -1137,7 +1397,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
// 会话数量限制检查
if
!
s
.
checkAndRegisterSession
(
ctx
,
account
,
sessionUUID
)
{
// Session count limit check
if
!
s
.
checkAndRegisterSession
(
ctx
,
account
,
sessionHash
)
{
result
.
ReleaseFunc
()
// 释放槽位,继续到 Layer 2
}
else
{
_
=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
)
...
...
@@ -1151,6 +1412,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
waitingCount
,
_
:=
s
.
concurrencyService
.
GetAccountWaitingCount
(
ctx
,
accountID
)
if
waitingCount
<
cfg
.
StickySessionMaxWaiting
{
// 会话数量限制检查(等待计划也需要占用会话配额)
// Session count limit check (wait plan also requires session quota)
if
!
s
.
checkAndRegisterSession
(
ctx
,
account
,
sessionHash
)
{
// 会话限制已满,继续到 Layer 2
// Session limit full, continue to Layer 2
}
else
{
return
&
AccountSelectionResult
{
Account
:
account
,
WaitPlan
:
&
AccountWaitPlan
{
...
...
@@ -1164,6 +1431,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
}
}
}
}
// ============ Layer 2: 负载感知选择 ============
candidates
:=
make
([]
*
Account
,
0
,
len
(
accounts
))
...
...
@@ -1208,7 +1477,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
loadMap
,
err
:=
s
.
concurrencyService
.
GetAccountsLoadBatch
(
ctx
,
accountLoads
)
if
err
!=
nil
{
if
result
,
ok
:=
s
.
tryAcquireByLegacyOrder
(
ctx
,
candidates
,
groupID
,
sessionHash
,
preferOAuth
,
sessionUUID
);
ok
{
if
result
,
ok
:=
s
.
tryAcquireByLegacyOrder
(
ctx
,
candidates
,
groupID
,
sessionHash
,
preferOAuth
);
ok
{
return
result
,
nil
}
}
else
{
...
...
@@ -1258,7 +1527,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
item
.
account
.
ID
,
item
.
account
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
// 会话数量限制检查
if
!
s
.
checkAndRegisterSession
(
ctx
,
item
.
account
,
session
UUID
)
{
if
!
s
.
checkAndRegisterSession
(
ctx
,
item
.
account
,
session
Hash
)
{
result
.
ReleaseFunc
()
// 释放槽位,继续尝试下一个账号
continue
}
...
...
@@ -1276,8 +1545,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
// ============ Layer 3: 兜底排队 ============
sort
AccountsByPriorityAndLastUsed
(
candidates
,
preferOAuth
)
s
.
sort
CandidatesForFallback
(
candidates
,
preferOAuth
,
cfg
.
FallbackSelectionMode
)
for
_
,
acc
:=
range
candidates
{
// 会话数量限制检查(等待计划也需要占用会话配额)
if
!
s
.
checkAndRegisterSession
(
ctx
,
acc
,
sessionHash
)
{
continue
// 会话限制已满,尝试下一个账号
}
return
&
AccountSelectionResult
{
Account
:
acc
,
WaitPlan
:
&
AccountWaitPlan
{
...
...
@@ -1291,7 +1564,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return
nil
,
errors
.
New
(
"no available accounts"
)
}
func
(
s
*
GatewayService
)
tryAcquireByLegacyOrder
(
ctx
context
.
Context
,
candidates
[]
*
Account
,
groupID
*
int64
,
sessionHash
string
,
preferOAuth
bool
,
sessionUUID
string
)
(
*
AccountSelectionResult
,
bool
)
{
func
(
s
*
GatewayService
)
tryAcquireByLegacyOrder
(
ctx
context
.
Context
,
candidates
[]
*
Account
,
groupID
*
int64
,
sessionHash
string
,
preferOAuth
bool
)
(
*
AccountSelectionResult
,
bool
)
{
ordered
:=
append
([]
*
Account
(
nil
),
candidates
...
)
sortAccountsByPriorityAndLastUsed
(
ordered
,
preferOAuth
)
...
...
@@ -1299,7 +1572,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
acc
.
ID
,
acc
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
// 会话数量限制检查
if
!
s
.
checkAndRegisterSession
(
ctx
,
acc
,
session
UUID
)
{
if
!
s
.
checkAndRegisterSession
(
ctx
,
acc
,
session
Hash
)
{
result
.
ReleaseFunc
()
// 释放槽位,继续尝试下一个账号
continue
}
...
...
@@ -1456,7 +1729,24 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr
func
(
s
*
GatewayService
)
listSchedulableAccounts
(
ctx
context
.
Context
,
groupID
*
int64
,
platform
string
,
hasForcePlatform
bool
)
([]
Account
,
bool
,
error
)
{
if
s
.
schedulerSnapshot
!=
nil
{
return
s
.
schedulerSnapshot
.
ListSchedulableAccounts
(
ctx
,
groupID
,
platform
,
hasForcePlatform
)
accounts
,
useMixed
,
err
:=
s
.
schedulerSnapshot
.
ListSchedulableAccounts
(
ctx
,
groupID
,
platform
,
hasForcePlatform
)
if
err
==
nil
{
slog
.
Debug
(
"account_scheduling_list_snapshot"
,
"group_id"
,
derefGroupID
(
groupID
),
"platform"
,
platform
,
"use_mixed"
,
useMixed
,
"count"
,
len
(
accounts
))
for
_
,
acc
:=
range
accounts
{
slog
.
Debug
(
"account_scheduling_account_detail"
,
"account_id"
,
acc
.
ID
,
"name"
,
acc
.
Name
,
"platform"
,
acc
.
Platform
,
"type"
,
acc
.
Type
,
"status"
,
acc
.
Status
,
"tls_fingerprint"
,
acc
.
IsTLSFingerprintEnabled
())
}
}
return
accounts
,
useMixed
,
err
}
useMixed
:=
(
platform
==
PlatformAnthropic
||
platform
==
PlatformGemini
)
&&
!
hasForcePlatform
if
useMixed
{
...
...
@@ -1469,6 +1759,10 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatforms
(
ctx
,
platforms
)
}
if
err
!=
nil
{
slog
.
Debug
(
"account_scheduling_list_failed"
,
"group_id"
,
derefGroupID
(
groupID
),
"platform"
,
platform
,
"error"
,
err
)
return
nil
,
useMixed
,
err
}
filtered
:=
make
([]
Account
,
0
,
len
(
accounts
))
...
...
@@ -1478,6 +1772,20 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
}
filtered
=
append
(
filtered
,
acc
)
}
slog
.
Debug
(
"account_scheduling_list_mixed"
,
"group_id"
,
derefGroupID
(
groupID
),
"platform"
,
platform
,
"raw_count"
,
len
(
accounts
),
"filtered_count"
,
len
(
filtered
))
for
_
,
acc
:=
range
filtered
{
slog
.
Debug
(
"account_scheduling_account_detail"
,
"account_id"
,
acc
.
ID
,
"name"
,
acc
.
Name
,
"platform"
,
acc
.
Platform
,
"type"
,
acc
.
Type
,
"status"
,
acc
.
Status
,
"tls_fingerprint"
,
acc
.
IsTLSFingerprintEnabled
())
}
return
filtered
,
useMixed
,
nil
}
...
...
@@ -1492,8 +1800,25 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
accounts
,
err
=
s
.
accountRepo
.
ListSchedulableByPlatform
(
ctx
,
platform
)
}
if
err
!=
nil
{
slog
.
Debug
(
"account_scheduling_list_failed"
,
"group_id"
,
derefGroupID
(
groupID
),
"platform"
,
platform
,
"error"
,
err
)
return
nil
,
useMixed
,
err
}
slog
.
Debug
(
"account_scheduling_list_single"
,
"group_id"
,
derefGroupID
(
groupID
),
"platform"
,
platform
,
"count"
,
len
(
accounts
))
for
_
,
acc
:=
range
accounts
{
slog
.
Debug
(
"account_scheduling_account_detail"
,
"account_id"
,
acc
.
ID
,
"name"
,
acc
.
Name
,
"platform"
,
acc
.
Platform
,
"type"
,
acc
.
Type
,
"status"
,
acc
.
Status
,
"tls_fingerprint"
,
acc
.
IsTLSFingerprintEnabled
())
}
return
accounts
,
useMixed
,
nil
}
...
...
@@ -1559,12 +1884,8 @@ func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context,
// 缓存未命中,从数据库查询
{
var
startTime
time
.
Time
if
account
.
SessionWindowStart
!=
nil
{
startTime
=
*
account
.
SessionWindowStart
}
else
{
startTime
=
time
.
Now
()
.
Add
(
-
5
*
time
.
Hour
)
}
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
startTime
:=
account
.
GetCurrentWindowStartTime
()
stats
,
err
:=
s
.
usageLogRepo
.
GetAccountWindowStats
(
ctx
,
account
.
ID
,
startTime
)
if
err
!=
nil
{
...
...
@@ -1597,15 +1918,16 @@ checkSchedulability:
// checkAndRegisterSession 检查并注册会话,用于会话数量限制
// 仅适用于 Anthropic OAuth/SetupToken 账号
// sessionID: 会话标识符(使用粘性会话的 hash)
// 返回 true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话)
func
(
s
*
GatewayService
)
checkAndRegisterSession
(
ctx
context
.
Context
,
account
*
Account
,
session
UU
ID
string
)
bool
{
func
(
s
*
GatewayService
)
checkAndRegisterSession
(
ctx
context
.
Context
,
account
*
Account
,
sessionID
string
)
bool
{
// 只检查 Anthropic OAuth/SetupToken 账号
if
!
account
.
IsAnthropicOAuthOrSetupToken
()
{
return
true
}
maxSessions
:=
account
.
GetMaxSessions
()
if
maxSessions
<=
0
||
session
UU
ID
==
""
{
if
maxSessions
<=
0
||
sessionID
==
""
{
return
true
// 未启用会话限制或无会话ID
}
...
...
@@ -1615,7 +1937,7 @@ func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *A
idleTimeout
:=
time
.
Duration
(
account
.
GetSessionIdleTimeoutMinutes
())
*
time
.
Minute
allowed
,
err
:=
s
.
sessionLimitCache
.
RegisterSession
(
ctx
,
account
.
ID
,
session
UU
ID
,
maxSessions
,
idleTimeout
)
allowed
,
err
:=
s
.
sessionLimitCache
.
RegisterSession
(
ctx
,
account
.
ID
,
sessionID
,
maxSessions
,
idleTimeout
)
if
err
!=
nil
{
// 失败开放:缓存错误时允许通过
return
true
...
...
@@ -1623,18 +1945,6 @@ func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *A
return
allowed
}
// extractSessionUUID 从 metadata.user_id 中提取会话 UUID
// 格式: user_{64位hex}_account__session_{uuid}
func
extractSessionUUID
(
metadataUserID
string
)
string
{
if
metadataUserID
==
""
{
return
""
}
if
match
:=
sessionIDRegex
.
FindStringSubmatch
(
metadataUserID
);
len
(
match
)
>
1
{
return
match
[
1
]
}
return
""
}
func
(
s
*
GatewayService
)
getSchedulableAccount
(
ctx
context
.
Context
,
accountID
int64
)
(
*
Account
,
error
)
{
if
s
.
schedulerSnapshot
!=
nil
{
return
s
.
schedulerSnapshot
.
GetAccount
(
ctx
,
accountID
)
...
...
@@ -1664,6 +1974,56 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
})
}
// sortCandidatesForFallback 根据配置选择排序策略
// mode: "last_used"(按最后使用时间) 或 "random"(随机)
func
(
s
*
GatewayService
)
sortCandidatesForFallback
(
accounts
[]
*
Account
,
preferOAuth
bool
,
mode
string
)
{
if
mode
==
"random"
{
// 先按优先级排序,然后在同优先级内随机打乱
sortAccountsByPriorityOnly
(
accounts
,
preferOAuth
)
shuffleWithinPriority
(
accounts
)
}
else
{
// 默认按最后使用时间排序
sortAccountsByPriorityAndLastUsed
(
accounts
,
preferOAuth
)
}
}
// sortAccountsByPriorityOnly 仅按优先级排序
func
sortAccountsByPriorityOnly
(
accounts
[]
*
Account
,
preferOAuth
bool
)
{
sort
.
SliceStable
(
accounts
,
func
(
i
,
j
int
)
bool
{
a
,
b
:=
accounts
[
i
],
accounts
[
j
]
if
a
.
Priority
!=
b
.
Priority
{
return
a
.
Priority
<
b
.
Priority
}
if
preferOAuth
&&
a
.
Type
!=
b
.
Type
{
return
a
.
Type
==
AccountTypeOAuth
}
return
false
})
}
// shuffleWithinPriority 在同优先级内随机打乱顺序
func
shuffleWithinPriority
(
accounts
[]
*
Account
)
{
if
len
(
accounts
)
<=
1
{
return
}
r
:=
mathrand
.
New
(
mathrand
.
NewSource
(
time
.
Now
()
.
UnixNano
()))
start
:=
0
for
start
<
len
(
accounts
)
{
priority
:=
accounts
[
start
]
.
Priority
end
:=
start
+
1
for
end
<
len
(
accounts
)
&&
accounts
[
end
]
.
Priority
==
priority
{
end
++
}
// 对 [start, end) 范围内的账户随机打乱
if
end
-
start
>
1
{
r
.
Shuffle
(
end
-
start
,
func
(
i
,
j
int
)
{
accounts
[
start
+
i
],
accounts
[
start
+
j
]
=
accounts
[
start
+
j
],
accounts
[
start
+
i
]
})
}
start
=
end
}
}
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func
(
s
*
GatewayService
)
selectAccountForModelWithPlatform
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{},
platform
string
)
(
*
Account
,
error
)
{
preferOAuth
:=
platform
==
PlatformGemini
...
...
@@ -1687,7 +2047,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if
_
,
excluded
:=
excludedIDs
[
accountID
];
!
excluded
{
account
,
err
:=
s
.
getSchedulableAccount
(
ctx
,
accountID
)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if
err
==
nil
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
.
Platform
==
platform
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
if
err
==
nil
{
clearSticky
:=
shouldClearStickySession
(
account
)
if
clearSticky
{
_
=
s
.
cache
.
DeleteSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
}
if
!
clearSticky
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
.
Platform
==
platform
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
if
err
:=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
);
err
!=
nil
{
log
.
Printf
(
"refresh session ttl failed: session=%s err=%v"
,
sessionHash
,
err
)
}
...
...
@@ -1699,6 +2064,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
}
}
}
}
// 2) Select an account from the routed candidates.
forcePlatform
,
hasForcePlatform
:=
ctx
.
Value
(
ctxkey
.
ForcePlatform
)
.
(
string
)
...
...
@@ -1784,7 +2150,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if
_
,
excluded
:=
excludedIDs
[
accountID
];
!
excluded
{
account
,
err
:=
s
.
getSchedulableAccount
(
ctx
,
accountID
)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if
err
==
nil
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
.
Platform
==
platform
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
if
err
==
nil
{
clearSticky
:=
shouldClearStickySession
(
account
)
if
clearSticky
{
_
=
s
.
cache
.
DeleteSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
}
if
!
clearSticky
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
.
Platform
==
platform
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
if
err
:=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
);
err
!=
nil
{
log
.
Printf
(
"refresh session ttl failed: session=%s err=%v"
,
sessionHash
,
err
)
}
...
...
@@ -1793,6 +2164,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
}
}
}
}
// 2. 获取可调度账号列表(单平台)
if
!
accountsLoaded
{
...
...
@@ -1888,7 +2260,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if
_
,
excluded
:=
excludedIDs
[
accountID
];
!
excluded
{
account
,
err
:=
s
.
getSchedulableAccount
(
ctx
,
accountID
)
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
if
err
==
nil
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
if
err
==
nil
{
clearSticky
:=
shouldClearStickySession
(
account
)
if
clearSticky
{
_
=
s
.
cache
.
DeleteSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
}
if
!
clearSticky
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
if
account
.
Platform
==
nativePlatform
||
(
account
.
Platform
==
PlatformAntigravity
&&
account
.
IsMixedSchedulingEnabled
())
{
if
err
:=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
);
err
!=
nil
{
log
.
Printf
(
"refresh session ttl failed: session=%s err=%v"
,
sessionHash
,
err
)
...
...
@@ -1902,6 +2279,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
}
}
}
// 2) Select an account from the routed candidates.
var
err
error
...
...
@@ -1987,7 +2365,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if
_
,
excluded
:=
excludedIDs
[
accountID
];
!
excluded
{
account
,
err
:=
s
.
getSchedulableAccount
(
ctx
,
accountID
)
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
if
err
==
nil
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
if
err
==
nil
{
clearSticky
:=
shouldClearStickySession
(
account
)
if
clearSticky
{
_
=
s
.
cache
.
DeleteSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
}
if
!
clearSticky
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
if
account
.
Platform
==
nativePlatform
||
(
account
.
Platform
==
PlatformAntigravity
&&
account
.
IsMixedSchedulingEnabled
())
{
if
err
:=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
);
err
!=
nil
{
log
.
Printf
(
"refresh session ttl failed: session=%s err=%v"
,
sessionHash
,
err
)
...
...
@@ -1998,6 +2381,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
}
}
}
// 2. 获取可调度账号列表
if
!
accountsLoaded
{
...
...
@@ -2247,6 +2631,10 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
"text"
:
claudeCodeSystemPrompt
,
"cache_control"
:
map
[
string
]
string
{
"type"
:
"ephemeral"
},
}
// Opencode plugin applies an extra safeguard: it not only prepends the Claude Code
// banner, it also prefixes the next system instruction with the same banner plus
// a blank line. This helps when upstream concatenates system instructions.
claudeCodePrefix
:=
strings
.
TrimSpace
(
claudeCodeSystemPrompt
)
var
newSystem
[]
any
...
...
@@ -2254,19 +2642,36 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
case
nil
:
newSystem
=
[]
any
{
claudeCodeBlock
}
case
string
:
if
v
==
""
||
v
==
claudeCodeSystemPrompt
{
// Be tolerant of older/newer clients that may differ only by trailing whitespace/newlines.
if
strings
.
TrimSpace
(
v
)
==
""
||
strings
.
TrimSpace
(
v
)
==
strings
.
TrimSpace
(
claudeCodeSystemPrompt
)
{
newSystem
=
[]
any
{
claudeCodeBlock
}
}
else
{
newSystem
=
[]
any
{
claudeCodeBlock
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
v
}}
// Mirror opencode behavior: keep the banner as a separate system entry,
// but also prefix the next system text with the banner.
merged
:=
v
if
!
strings
.
HasPrefix
(
v
,
claudeCodePrefix
)
{
merged
=
claudeCodePrefix
+
"
\n\n
"
+
v
}
newSystem
=
[]
any
{
claudeCodeBlock
,
map
[
string
]
any
{
"type"
:
"text"
,
"text"
:
merged
}}
}
case
[]
any
:
newSystem
=
make
([]
any
,
0
,
len
(
v
)
+
1
)
newSystem
=
append
(
newSystem
,
claudeCodeBlock
)
prefixedNext
:=
false
for
_
,
item
:=
range
v
{
if
m
,
ok
:=
item
.
(
map
[
string
]
any
);
ok
{
if
text
,
ok
:=
m
[
"text"
]
.
(
string
);
ok
&&
text
==
claudeCodeSystemPrompt
{
if
text
,
ok
:=
m
[
"text"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
text
)
==
strings
.
TrimSpace
(
claudeCodeSystemPrompt
)
{
continue
}
// Prefix the first subsequent text system block once.
if
!
prefixedNext
{
if
blockType
,
_
:=
m
[
"type"
]
.
(
string
);
blockType
==
"text"
{
if
text
,
ok
:=
m
[
"text"
]
.
(
string
);
ok
&&
strings
.
TrimSpace
(
text
)
!=
""
&&
!
strings
.
HasPrefix
(
text
,
claudeCodePrefix
)
{
m
[
"text"
]
=
claudeCodePrefix
+
"
\n\n
"
+
text
prefixedNext
=
true
}
}
}
}
newSystem
=
append
(
newSystem
,
item
)
}
...
...
@@ -2524,6 +2929,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
proxyURL
=
account
.
Proxy
.
URL
()
}
// 调试日志:记录即将转发的账号信息
log
.
Printf
(
"[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s"
,
account
.
ID
,
account
.
Name
,
account
.
Platform
,
account
.
Type
,
account
.
IsTLSFingerprintEnabled
(),
proxyURL
)
// 重试循环
var
resp
*
http
.
Response
retryStart
:=
time
.
Now
()
...
...
@@ -2537,7 +2946,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
// 发送请求
resp
,
err
=
s
.
httpUpstream
.
Do
(
upstreamReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
resp
,
err
=
s
.
httpUpstream
.
Do
WithTLS
(
upstreamReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
()
)
if
err
!=
nil
{
if
resp
!=
nil
&&
resp
.
Body
!=
nil
{
_
=
resp
.
Body
.
Close
()
...
...
@@ -2611,7 +3020,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
filteredBody
:=
FilterThinkingBlocksForRetry
(
body
)
retryReq
,
buildErr
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
filteredBody
,
token
,
tokenType
,
reqModel
,
reqStream
,
shouldMimicClaudeCode
)
if
buildErr
==
nil
{
retryResp
,
retryErr
:=
s
.
httpUpstream
.
Do
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
retryResp
,
retryErr
:=
s
.
httpUpstream
.
Do
WithTLS
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
()
)
if
retryErr
==
nil
{
if
retryResp
.
StatusCode
<
400
{
log
.
Printf
(
"Account %d: signature error retry succeeded (thinking downgraded)"
,
account
.
ID
)
...
...
@@ -2643,7 +3052,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
filteredBody2
:=
FilterSignatureSensitiveBlocksForRetry
(
body
)
retryReq2
,
buildErr2
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
filteredBody2
,
token
,
tokenType
,
reqModel
,
reqStream
,
shouldMimicClaudeCode
)
if
buildErr2
==
nil
{
retryResp2
,
retryErr2
:=
s
.
httpUpstream
.
Do
(
retryReq2
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
retryResp2
,
retryErr2
:=
s
.
httpUpstream
.
Do
WithTLS
(
retryReq2
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
()
)
if
retryErr2
==
nil
{
resp
=
retryResp2
break
...
...
@@ -2758,6 +3167,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
_
=
resp
.
Body
.
Close
()
resp
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
))
// 调试日志:打印重试耗尽后的错误响应
log
.
Printf
(
"[Forward] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s"
,
account
.
ID
,
account
.
Name
,
resp
.
StatusCode
,
resp
.
Header
.
Get
(
"x-request-id"
),
truncateString
(
string
(
respBody
),
1000
))
s
.
handleRetryExhaustedSideEffects
(
ctx
,
resp
,
account
)
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
...
...
@@ -2785,6 +3198,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
_
=
resp
.
Body
.
Close
()
resp
.
Body
=
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
))
// 调试日志:打印上游错误响应
log
.
Printf
(
"[Forward] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s"
,
account
.
ID
,
account
.
Name
,
resp
.
StatusCode
,
resp
.
Header
.
Get
(
"x-request-id"
),
truncateString
(
string
(
respBody
),
1000
))
s
.
handleFailoverSideEffects
(
ctx
,
resp
,
account
)
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
...
...
@@ -2902,11 +3319,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
clientHeaders
:=
http
.
Header
{}
if
c
!=
nil
&&
c
.
Request
!=
nil
{
clientHeaders
=
c
.
Request
.
Header
}
// OAuth账号:应用统一指纹
var
fingerprint
*
Fingerprint
if
account
.
IsOAuth
()
&&
s
.
identityService
!=
nil
{
// 1. 获取或创建指纹(包含随机生成的ClientID)
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
.
Request
.
Header
)
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
lient
Header
s
)
if
err
!=
nil
{
log
.
Printf
(
"Warning: failed to get fingerprint for account %d: %v"
,
account
.
ID
,
err
)
// 失败时降级为透传原始headers
...
...
@@ -2914,9 +3336,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
fingerprint
=
fp
// 2. 重写metadata.user_id(需要指纹中的ClientID和账号的account_uuid)
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
accountUUID
:=
account
.
GetExtraString
(
"account_uuid"
)
if
accountUUID
!=
""
&&
fp
.
ClientID
!=
""
{
if
newBody
,
err
:=
s
.
identityService
.
RewriteUserID
(
body
,
account
.
ID
,
accountUUID
,
fp
.
ClientID
);
err
==
nil
&&
len
(
newBody
)
>
0
{
if
newBody
,
err
:=
s
.
identityService
.
RewriteUserID
WithMasking
(
ctx
,
body
,
account
,
accountUUID
,
fp
.
ClientID
);
err
==
nil
&&
len
(
newBody
)
>
0
{
body
=
newBody
}
}
...
...
@@ -2936,7 +3359,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
// 白名单透传headers
for
key
,
values
:=
range
c
.
Request
.
Header
{
for
key
,
values
:=
range
c
lient
Header
s
{
lowerKey
:=
strings
.
ToLower
(
key
)
if
allowedHeaders
[
lowerKey
]
{
for
_
,
v
:=
range
values
{
...
...
@@ -2964,12 +3387,18 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 处理 anthropic-beta header(OAuth 账号需要包含 oauth beta)
if
tokenType
==
"oauth"
{
if
mimicClaudeCode
{
// 非 Claude Code 客户端:按 Claude Code 规则生成 beta header
if
requestHasTools
(
body
)
{
req
.
Header
.
Set
(
"anthropic-beta"
,
claude
.
MessageBetaHeaderWithTools
)
}
else
{
req
.
Header
.
Set
(
"anthropic-beta"
,
claude
.
MessageBetaHeaderNoTools
)
}
// 非 Claude Code 客户端:按 opencode 的策略处理:
// - 强制 Claude Code 指纹相关请求头(尤其是 user-agent/x-stainless/x-app)
// - 保留 incoming beta 的同时,确保 OAuth 所需 beta 存在
applyClaudeCodeMimicHeaders
(
req
,
reqStream
)
incomingBeta
:=
req
.
Header
.
Get
(
"anthropic-beta"
)
// Match real Claude CLI traffic (per mitmproxy reports):
// messages requests typically use only oauth + interleaved-thinking.
// Also drop claude-code beta if a downstream client added it.
requiredBetas
:=
[]
string
{
claude
.
BetaOAuth
,
claude
.
BetaInterleavedThinking
}
drop
:=
map
[
string
]
struct
{}{
claude
.
BetaClaudeCode
:
{}}
req
.
Header
.
Set
(
"anthropic-beta"
,
mergeAnthropicBetaDropping
(
requiredBetas
,
incomingBeta
,
drop
))
}
else
{
// Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta
clientBetaHeader
:=
req
.
Header
.
Get
(
"anthropic-beta"
)
...
...
@@ -2984,6 +3413,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
// Always capture a compact fingerprint line for later error diagnostics.
// We only print it when needed (or when the explicit debug flag is enabled).
if
c
!=
nil
&&
tokenType
==
"oauth"
{
c
.
Set
(
claudeMimicDebugInfoKey
,
buildClaudeMimicDebugLine
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
))
}
if
s
.
debugClaudeMimicEnabled
()
{
logClaudeMimicDebug
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
)
}
return
req
,
nil
}
...
...
@@ -3045,20 +3483,6 @@ func requestNeedsBetaFeatures(body []byte) bool {
return
false
}
func
requestHasTools
(
body
[]
byte
)
bool
{
tools
:=
gjson
.
GetBytes
(
body
,
"tools"
)
if
!
tools
.
Exists
()
{
return
false
}
if
tools
.
IsArray
()
{
return
len
(
tools
.
Array
())
>
0
}
if
tools
.
IsObject
()
{
return
len
(
tools
.
Map
())
>
0
}
return
false
}
func
defaultAPIKeyBetaHeader
(
body
[]
byte
)
string
{
modelID
:=
gjson
.
GetBytes
(
body
,
"model"
)
.
String
()
if
strings
.
Contains
(
strings
.
ToLower
(
modelID
),
"haiku"
)
{
...
...
@@ -3087,6 +3511,73 @@ func applyClaudeOAuthHeaderDefaults(req *http.Request, isStream bool) {
}
}
func
mergeAnthropicBeta
(
required
[]
string
,
incoming
string
)
string
{
seen
:=
make
(
map
[
string
]
struct
{},
len
(
required
)
+
8
)
out
:=
make
([]
string
,
0
,
len
(
required
)
+
8
)
add
:=
func
(
v
string
)
{
v
=
strings
.
TrimSpace
(
v
)
if
v
==
""
{
return
}
if
_
,
ok
:=
seen
[
v
];
ok
{
return
}
seen
[
v
]
=
struct
{}{}
out
=
append
(
out
,
v
)
}
for
_
,
r
:=
range
required
{
add
(
r
)
}
for
_
,
p
:=
range
strings
.
Split
(
incoming
,
","
)
{
add
(
p
)
}
return
strings
.
Join
(
out
,
","
)
}
func
mergeAnthropicBetaDropping
(
required
[]
string
,
incoming
string
,
drop
map
[
string
]
struct
{})
string
{
merged
:=
mergeAnthropicBeta
(
required
,
incoming
)
if
merged
==
""
||
len
(
drop
)
==
0
{
return
merged
}
out
:=
make
([]
string
,
0
,
8
)
for
_
,
p
:=
range
strings
.
Split
(
merged
,
","
)
{
p
=
strings
.
TrimSpace
(
p
)
if
p
==
""
{
continue
}
if
_
,
ok
:=
drop
[
p
];
ok
{
continue
}
out
=
append
(
out
,
p
)
}
return
strings
.
Join
(
out
,
","
)
}
// applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers.
// This mirrors opencode-anthropic-auth behavior: do not trust downstream
// headers when using Claude Code-scoped OAuth credentials.
func
applyClaudeCodeMimicHeaders
(
req
*
http
.
Request
,
isStream
bool
)
{
if
req
==
nil
{
return
}
// Start with the standard defaults (fill missing).
applyClaudeOAuthHeaderDefaults
(
req
,
isStream
)
// Then force key headers to match Claude Code fingerprint regardless of what the client sent.
for
key
,
value
:=
range
claude
.
DefaultHeaders
{
if
value
==
""
{
continue
}
req
.
Header
.
Set
(
key
,
value
)
}
// Real Claude CLI uses Accept: application/json (even for streaming).
req
.
Header
.
Set
(
"accept"
,
"application/json"
)
if
isStream
{
req
.
Header
.
Set
(
"x-stainless-helper-method"
,
"stream"
)
}
}
func
truncateForLog
(
b
[]
byte
,
maxBytes
int
)
string
{
if
maxBytes
<=
0
{
maxBytes
=
2048
...
...
@@ -3183,9 +3674,27 @@ func extractUpstreamErrorMessage(body []byte) string {
func
(
s
*
GatewayService
)
handleErrorResponse
(
ctx
context
.
Context
,
resp
*
http
.
Response
,
c
*
gin
.
Context
,
account
*
Account
)
(
*
ForwardResult
,
error
)
{
body
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
// 调试日志:打印上游错误响应
log
.
Printf
(
"[Forward] Upstream error (non-retryable): Account=%d(%s) Status=%d RequestID=%s Body=%s"
,
account
.
ID
,
account
.
Name
,
resp
.
StatusCode
,
resp
.
Header
.
Get
(
"x-request-id"
),
truncateString
(
string
(
body
),
1000
))
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
body
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
// Print a compact upstream request fingerprint when we hit the Claude Code OAuth
// credential scope error. This avoids requiring env-var tweaks in a fixed deploy.
if
isClaudeCodeCredentialScopeError
(
upstreamMsg
)
&&
c
!=
nil
{
if
v
,
ok
:=
c
.
Get
(
claudeMimicDebugInfoKey
);
ok
{
if
line
,
ok
:=
v
.
(
string
);
ok
&&
strings
.
TrimSpace
(
line
)
!=
""
{
log
.
Printf
(
"[ClaudeMimicDebugOnError] status=%d request_id=%s %s"
,
resp
.
StatusCode
,
resp
.
Header
.
Get
(
"x-request-id"
),
line
,
)
}
}
}
// Enrich Ops error logs with upstream status + message, and optionally a truncated body snippet.
upstreamDetail
:=
""
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
...
...
@@ -3315,6 +3824,19 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
if
isClaudeCodeCredentialScopeError
(
upstreamMsg
)
&&
c
!=
nil
{
if
v
,
ok
:=
c
.
Get
(
claudeMimicDebugInfoKey
);
ok
{
if
line
,
ok
:=
v
.
(
string
);
ok
&&
strings
.
TrimSpace
(
line
)
!=
""
{
log
.
Printf
(
"[ClaudeMimicDebugOnError] status=%d request_id=%s %s"
,
resp
.
StatusCode
,
resp
.
Header
.
Get
(
"x-request-id"
),
line
,
)
}
}
}
upstreamDetail
:=
""
if
s
.
cfg
!=
nil
&&
s
.
cfg
.
Gateway
.
LogUpstreamErrorBody
{
maxBytes
:=
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
...
...
@@ -3860,17 +4382,19 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
}
`json:"usage"`
}
if
json
.
Unmarshal
([]
byte
(
data
),
&
msgDelta
)
==
nil
&&
msgDelta
.
Type
==
"message_delta"
{
// output_tokens 总是从 message_delta 获取
usage
.
OutputTokens
=
msgDelta
.
Usage
.
OutputTokens
// 如果 message_start 中没有值,则从 message_delta 获取(兼容GLM等API)
if
usage
.
InputTokens
==
0
{
// message_delta 仅覆盖存在且非0的字段
// 避免覆盖 message_start 中已有的值(如 input_tokens)
// Claude API 的 message_delta 通常只包含 output_tokens
if
msgDelta
.
Usage
.
InputTokens
>
0
{
usage
.
InputTokens
=
msgDelta
.
Usage
.
InputTokens
}
if
usage
.
CacheCreationInputTokens
==
0
{
if
msgDelta
.
Usage
.
OutputTokens
>
0
{
usage
.
OutputTokens
=
msgDelta
.
Usage
.
OutputTokens
}
if
msgDelta
.
Usage
.
CacheCreationInputTokens
>
0
{
usage
.
CacheCreationInputTokens
=
msgDelta
.
Usage
.
CacheCreationInputTokens
}
if
u
sage
.
CacheReadInputTokens
==
0
{
if
msgDelta
.
U
sage
.
CacheReadInputTokens
>
0
{
usage
.
CacheReadInputTokens
=
msgDelta
.
Usage
.
CacheReadInputTokens
}
}
...
...
@@ -4171,7 +4695,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// 发送请求
resp
,
err
:=
s
.
httpUpstream
.
Do
(
upstreamReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
resp
,
err
:=
s
.
httpUpstream
.
Do
WithTLS
(
upstreamReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
()
)
if
err
!=
nil
{
setOpsUpstreamError
(
c
,
0
,
sanitizeUpstreamErrorMessage
(
err
.
Error
()),
""
)
s
.
countTokensError
(
c
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Request failed"
)
...
...
@@ -4193,7 +4717,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
filteredBody
:=
FilterThinkingBlocksForRetry
(
body
)
retryReq
,
buildErr
:=
s
.
buildCountTokensRequest
(
ctx
,
c
,
account
,
filteredBody
,
token
,
tokenType
,
reqModel
,
shouldMimicClaudeCode
)
if
buildErr
==
nil
{
retryResp
,
retryErr
:=
s
.
httpUpstream
.
Do
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
retryResp
,
retryErr
:=
s
.
httpUpstream
.
Do
WithTLS
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
,
account
.
IsTLSFingerprintEnabled
()
)
if
retryErr
==
nil
{
resp
=
retryResp
respBody
,
err
=
io
.
ReadAll
(
resp
.
Body
)
...
...
@@ -4270,13 +4794,19 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
clientHeaders
:=
http
.
Header
{}
if
c
!=
nil
&&
c
.
Request
!=
nil
{
clientHeaders
=
c
.
Request
.
Header
}
// OAuth 账号:应用统一指纹和重写 userID
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
if
account
.
IsOAuth
()
&&
s
.
identityService
!=
nil
{
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
.
Request
.
Header
)
fp
,
err
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
lient
Header
s
)
if
err
==
nil
{
accountUUID
:=
account
.
GetExtraString
(
"account_uuid"
)
if
accountUUID
!=
""
&&
fp
.
ClientID
!=
""
{
if
newBody
,
err
:=
s
.
identityService
.
RewriteUserID
(
body
,
account
.
ID
,
accountUUID
,
fp
.
ClientID
);
err
==
nil
&&
len
(
newBody
)
>
0
{
if
newBody
,
err
:=
s
.
identityService
.
RewriteUserID
WithMasking
(
ctx
,
body
,
account
,
accountUUID
,
fp
.
ClientID
);
err
==
nil
&&
len
(
newBody
)
>
0
{
body
=
newBody
}
}
...
...
@@ -4296,7 +4826,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
// 白名单透传 headers
for
key
,
values
:=
range
c
.
Request
.
Header
{
for
key
,
values
:=
range
c
lient
Header
s
{
lowerKey
:=
strings
.
ToLower
(
key
)
if
allowedHeaders
[
lowerKey
]
{
for
_
,
v
:=
range
values
{
...
...
@@ -4307,7 +4837,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:应用指纹到请求头
if
account
.
IsOAuth
()
&&
s
.
identityService
!=
nil
{
fp
,
_
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
.
Request
.
Header
)
fp
,
_
:=
s
.
identityService
.
GetOrCreateFingerprint
(
ctx
,
account
.
ID
,
c
lient
Header
s
)
if
fp
!=
nil
{
s
.
identityService
.
ApplyFingerprint
(
req
,
fp
)
}
...
...
@@ -4327,7 +4857,11 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:处理 anthropic-beta header
if
tokenType
==
"oauth"
{
if
mimicClaudeCode
{
req
.
Header
.
Set
(
"anthropic-beta"
,
claude
.
CountTokensBetaHeader
)
applyClaudeCodeMimicHeaders
(
req
,
false
)
incomingBeta
:=
req
.
Header
.
Get
(
"anthropic-beta"
)
requiredBetas
:=
[]
string
{
claude
.
BetaClaudeCode
,
claude
.
BetaOAuth
,
claude
.
BetaInterleavedThinking
,
claude
.
BetaTokenCounting
}
req
.
Header
.
Set
(
"anthropic-beta"
,
mergeAnthropicBeta
(
requiredBetas
,
incomingBeta
))
}
else
{
clientBetaHeader
:=
req
.
Header
.
Get
(
"anthropic-beta"
)
if
clientBetaHeader
==
""
{
...
...
@@ -4349,6 +4883,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
if
c
!=
nil
&&
tokenType
==
"oauth"
{
c
.
Set
(
claudeMimicDebugInfoKey
,
buildClaudeMimicDebugLine
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
))
}
if
s
.
debugClaudeMimicEnabled
()
{
logClaudeMimicDebug
(
req
,
body
,
account
,
tokenType
,
mimicClaudeCode
)
}
return
req
,
nil
}
...
...
backend/internal/service/gemini_messages_compat_service.go
View file @
2fe8932c
...
...
@@ -82,145 +82,276 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
}
func
(
s
*
GeminiMessagesCompatService
)
SelectAccountForModelWithExclusions
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{})
(
*
Account
,
error
)
{
// 1. 确定目标平台和调度模式
// Determine target platform and scheduling mode
platform
,
useMixedScheduling
,
hasForcePlatform
,
err
:=
s
.
resolvePlatformAndSchedulingMode
(
ctx
,
groupID
)
if
err
!=
nil
{
return
nil
,
err
}
cacheKey
:=
"gemini:"
+
sessionHash
// 2. 尝试粘性会话命中
// Try sticky session hit
if
account
:=
s
.
tryStickySessionHit
(
ctx
,
groupID
,
sessionHash
,
cacheKey
,
requestedModel
,
excludedIDs
,
platform
,
useMixedScheduling
);
account
!=
nil
{
return
account
,
nil
}
// 3. 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
// Query schedulable accounts (force platform mode: try group first, fallback to all)
accounts
,
err
:=
s
.
listSchedulableAccountsOnce
(
ctx
,
groupID
,
platform
,
hasForcePlatform
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
}
// 强制平台模式下,分组中找不到账户时回退查询全部
if
len
(
accounts
)
==
0
&&
groupID
!=
nil
&&
hasForcePlatform
{
accounts
,
err
=
s
.
listSchedulableAccountsOnce
(
ctx
,
nil
,
platform
,
hasForcePlatform
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
}
}
// 4. 按优先级 + LRU 选择最佳账号
// Select best account by priority + LRU
selected
:=
s
.
selectBestGeminiAccount
(
ctx
,
accounts
,
requestedModel
,
excludedIDs
,
platform
,
useMixedScheduling
)
if
selected
==
nil
{
if
requestedModel
!=
""
{
return
nil
,
fmt
.
Errorf
(
"no available Gemini accounts supporting model: %s"
,
requestedModel
)
}
return
nil
,
errors
.
New
(
"no available Gemini accounts"
)
}
// 5. 设置粘性会话绑定
// Set sticky session binding
if
sessionHash
!=
""
{
_
=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
cacheKey
,
selected
.
ID
,
geminiStickySessionTTL
)
}
return
selected
,
nil
}
// resolvePlatformAndSchedulingMode 解析目标平台和调度模式。
// 返回:平台名称、是否使用混合调度、是否强制平台、错误。
//
// resolvePlatformAndSchedulingMode resolves target platform and scheduling mode.
// Returns: platform name, whether to use mixed scheduling, whether force platform, error.
func
(
s
*
GeminiMessagesCompatService
)
resolvePlatformAndSchedulingMode
(
ctx
context
.
Context
,
groupID
*
int64
)
(
platform
string
,
useMixedScheduling
bool
,
hasForcePlatform
bool
,
err
error
)
{
// 优先检查 context 中的强制平台(/antigravity 路由)
var
platform
string
forcePlatform
,
hasForcePlatform
:=
ctx
.
Value
(
ctxkey
.
ForcePlatform
)
.
(
string
)
if
hasForcePlatform
&&
forcePlatform
!=
""
{
platform
=
forcePlatform
}
else
if
groupID
!=
nil
{
return
forcePlatform
,
false
,
true
,
nil
}
if
groupID
!=
nil
{
// 根据分组 platform 决定查询哪种账号
var
group
*
Group
if
ctxGroup
,
ok
:=
ctx
.
Value
(
ctxkey
.
Group
)
.
(
*
Group
);
ok
&&
IsGroupContextValid
(
ctxGroup
)
&&
ctxGroup
.
ID
==
*
groupID
{
group
=
ctxGroup
}
else
{
var
err
error
group
,
err
=
s
.
groupRepo
.
GetByIDLite
(
ctx
,
*
groupID
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get group failed: %w"
,
err
)
return
""
,
false
,
false
,
fmt
.
Errorf
(
"get group failed: %w"
,
err
)
}
}
platform
=
group
.
Platform
}
else
{
// 无分组时只使用原生 gemini 平台
platform
=
PlatformGemini
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
return
group
.
Platform
,
group
.
Platform
==
PlatformGemini
,
false
,
nil
}
//
gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
// 注意:强制平台模式不走混合调度
useMixedScheduling
:=
platform
==
PlatformGemini
&&
!
hasForcePlatform
//
无分组时只使用原生 gemini 平台
return
PlatformGemini
,
true
,
false
,
nil
}
cacheKey
:=
"gemini:"
+
sessionHash
// tryStickySessionHit 尝试从粘性会话获取账号。
// 如果命中且账号可用则返回账号;如果账号不可用则清理会话并返回 nil。
//
// tryStickySessionHit attempts to get account from sticky session.
// Returns account if hit and usable; clears session and returns nil if account unavailable.
func
(
s
*
GeminiMessagesCompatService
)
tryStickySessionHit
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
,
cacheKey
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{},
platform
string
,
useMixedScheduling
bool
,
)
*
Account
{
if
sessionHash
==
""
{
return
nil
}
if
sessionHash
!=
""
{
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
cacheKey
)
if
err
==
nil
&&
accountID
>
0
{
if
_
,
excluded
:=
excludedIDs
[
accountID
];
!
excluded
{
account
,
err
:=
s
.
getSchedulableAccount
(
ctx
,
accountID
)
// 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度
if
err
==
nil
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
valid
:=
false
if
account
.
Platform
==
platform
{
valid
=
true
}
else
if
useMixedScheduling
&&
account
.
Platform
==
PlatformAntigravity
&&
account
.
IsMixedSchedulingEnabled
()
{
valid
=
true
if
err
!=
nil
||
accountID
<=
0
{
return
nil
}
if
valid
{
usable
:=
true
if
s
.
rateLimitService
!=
nil
&&
requestedModel
!=
""
{
ok
,
err
:=
s
.
rateLimitService
.
PreCheckUsage
(
ctx
,
account
,
requestedModel
)
if
_
,
excluded
:=
excludedIDs
[
accountID
];
excluded
{
return
nil
}
account
,
err
:=
s
.
getSchedulableAccount
(
ctx
,
accountID
)
if
err
!=
nil
{
log
.
Printf
(
"[Gemini PreCheck] Account %d precheck error: %v"
,
account
.
ID
,
err
)
return
nil
}
if
!
ok
{
usable
=
false
// 检查账号是否需要清理粘性会话
// Check if sticky session should be cleared
if
shouldClearStickySession
(
account
)
{
_
=
s
.
cache
.
DeleteSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
cacheKey
)
return
nil
}
// 验证账号是否可用于当前请求
// Verify account is usable for current request
if
!
s
.
isAccountUsableForRequest
(
ctx
,
account
,
requestedModel
,
platform
,
useMixedScheduling
)
{
return
nil
}
if
usable
{
// 刷新会话 TTL 并返回账号
// Refresh session TTL and return account
_
=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
cacheKey
,
geminiStickySessionTTL
)
return
account
,
nil
return
account
}
// isAccountUsableForRequest 检查账号是否可用于当前请求。
// 验证:模型调度、模型支持、平台匹配、速率限制预检。
//
// isAccountUsableForRequest checks if account is usable for current request.
// Validates: model scheduling, model support, platform matching, rate limit precheck.
func
(
s
*
GeminiMessagesCompatService
)
isAccountUsableForRequest
(
ctx
context
.
Context
,
account
*
Account
,
requestedModel
,
platform
string
,
useMixedScheduling
bool
,
)
bool
{
// 检查模型调度能力
// Check model scheduling capability
if
!
account
.
IsSchedulableForModel
(
requestedModel
)
{
return
false
}
// 检查模型支持
// Check model support
if
requestedModel
!=
""
&&
!
s
.
isModelSupportedByAccount
(
account
,
requestedModel
)
{
return
false
}
// 检查平台匹配
// Check platform matching
if
!
s
.
isAccountValidForPlatform
(
account
,
platform
,
useMixedScheduling
)
{
return
false
}
// 速率限制预检
// Rate limit precheck
if
!
s
.
passesRateLimitPreCheck
(
ctx
,
account
,
requestedModel
)
{
return
false
}
return
true
}
// isAccountValidForPlatform 检查账号是否匹配目标平台。
// 原生平台直接匹配;混合调度模式下 antigravity 需要启用 mixed_scheduling。
//
// isAccountValidForPlatform checks if account matches target platform.
// Native platform matches directly; mixed scheduling mode requires antigravity to enable mixed_scheduling.
func
(
s
*
GeminiMessagesCompatService
)
isAccountValidForPlatform
(
account
*
Account
,
platform
string
,
useMixedScheduling
bool
)
bool
{
if
account
.
Platform
==
platform
{
return
true
}
if
useMixedScheduling
&&
account
.
Platform
==
PlatformAntigravity
&&
account
.
IsMixedSchedulingEnabled
()
{
return
true
}
return
false
}
// 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
accounts
,
err
:=
s
.
listSchedulableAccountsOnce
(
ctx
,
groupID
,
platform
,
hasForcePlatform
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
// passesRateLimitPreCheck 执行速率限制预检。
// 返回 true 表示通过预检或无需预检。
//
// passesRateLimitPreCheck performs rate limit precheck.
// Returns true if passed or precheck not required.
func
(
s
*
GeminiMessagesCompatService
)
passesRateLimitPreCheck
(
ctx
context
.
Context
,
account
*
Account
,
requestedModel
string
)
bool
{
if
s
.
rateLimitService
==
nil
||
requestedModel
==
""
{
return
true
}
// 强制平台模式下,分组中找不到账户时回退查询全部
if
len
(
accounts
)
==
0
&&
groupID
!=
nil
&&
hasForcePlatform
{
accounts
,
err
=
s
.
listSchedulableAccountsOnce
(
ctx
,
nil
,
platform
,
hasForcePlatform
)
ok
,
err
:=
s
.
rateLimitService
.
PreCheckUsage
(
ctx
,
account
,
requestedModel
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
}
log
.
Printf
(
"[Gemini PreCheck] Account %d precheck error: %v"
,
account
.
ID
,
err
)
}
return
ok
}
// selectBestGeminiAccount 从候选账号中选择最佳账号(优先级 + LRU + OAuth 优先)。
// 返回 nil 表示无可用账号。
//
// selectBestGeminiAccount selects best account from candidates (priority + LRU + OAuth preferred).
// Returns nil if no available account.
func
(
s
*
GeminiMessagesCompatService
)
selectBestGeminiAccount
(
ctx
context
.
Context
,
accounts
[]
Account
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{},
platform
string
,
useMixedScheduling
bool
,
)
*
Account
{
var
selected
*
Account
for
i
:=
range
accounts
{
acc
:=
&
accounts
[
i
]
// 跳过被排除的账号
if
_
,
excluded
:=
excludedIDs
[
acc
.
ID
];
excluded
{
continue
}
// 混合调度模式下:原生平台直接通过,antigravity 需要启用 mixed_scheduling
// 非混合调度模式(antigravity 分组):不需要过滤
if
useMixedScheduling
&&
acc
.
Platform
==
PlatformAntigravity
&&
!
acc
.
IsMixedSchedulingEnabled
()
{
continue
}
if
!
acc
.
IsSchedulableForModel
(
requestedModel
)
{
continue
}
if
requestedModel
!=
""
&&
!
s
.
isModelSupportedByAccount
(
acc
,
requestedModel
)
{
continue
}
if
s
.
rateLimitService
!=
nil
&&
requestedModel
!=
""
{
ok
,
err
:=
s
.
rateLimitService
.
PreCheckUsage
(
ctx
,
acc
,
requestedModel
)
if
err
!=
nil
{
log
.
Printf
(
"[Gemini PreCheck] Account %d precheck error: %v"
,
acc
.
ID
,
err
)
}
if
!
ok
{
// 检查账号是否可用于当前请求
if
!
s
.
isAccountUsableForRequest
(
ctx
,
acc
,
requestedModel
,
platform
,
useMixedScheduling
)
{
continue
}
}
// 选择最佳账号
if
selected
==
nil
{
selected
=
acc
continue
}
if
acc
.
Priority
<
selected
.
Priority
{
selected
=
acc
}
else
if
acc
.
Priority
==
selected
.
Priority
{
switch
{
case
acc
.
LastUsedAt
==
nil
&&
selected
.
LastUsedAt
!=
nil
:
selected
=
acc
case
acc
.
LastUsedAt
!=
nil
&&
selected
.
LastUsedAt
==
nil
:
// keep selected (never used is preferred)
case
acc
.
LastUsedAt
==
nil
&&
selected
.
LastUsedAt
==
nil
:
// Prefer OAuth accounts when both are unused (more compatible for Code Assist flows).
if
acc
.
Type
==
AccountTypeOAuth
&&
selected
.
Type
!=
AccountTypeOAuth
{
selected
=
acc
}
default
:
if
acc
.
LastUsedAt
.
Before
(
*
selected
.
LastUsedAt
)
{
if
s
.
isBetterGeminiAccount
(
acc
,
selected
)
{
selected
=
acc
}
}
}
}
if
selected
==
nil
{
if
requestedModel
!=
""
{
return
nil
,
fmt
.
Errorf
(
"no available Gemini accounts supporting model: %s"
,
requestedModel
)
return
selected
}
// isBetterGeminiAccount 判断 candidate 是否比 current 更优。
// 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先(OAuth > 非 OAuth),其次是最久未使用的。
//
// isBetterGeminiAccount checks if candidate is better than current.
// Rules: higher priority (lower value) wins; same priority: never used (OAuth > non-OAuth) > least recently used.
func
(
s
*
GeminiMessagesCompatService
)
isBetterGeminiAccount
(
candidate
,
current
*
Account
)
bool
{
// 优先级更高(数值更小)
if
candidate
.
Priority
<
current
.
Priority
{
return
true
}
return
nil
,
errors
.
New
(
"no available Gemini accounts"
)
if
candidate
.
Priority
>
current
.
Priority
{
return
false
}
if
sessionHash
!=
""
{
_
=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
cacheKey
,
selected
.
ID
,
geminiStickySessionTTL
)
// 同优先级,比较最后使用时间
switch
{
case
candidate
.
LastUsedAt
==
nil
&&
current
.
LastUsedAt
!=
nil
:
// candidate 从未使用,优先
return
true
case
candidate
.
LastUsedAt
!=
nil
&&
current
.
LastUsedAt
==
nil
:
// current 从未使用,保持
return
false
case
candidate
.
LastUsedAt
==
nil
&&
current
.
LastUsedAt
==
nil
:
// 都未使用,优先选择 OAuth 账号(更兼容 Code Assist 流程)
return
candidate
.
Type
==
AccountTypeOAuth
&&
current
.
Type
!=
AccountTypeOAuth
default
:
// 都使用过,选择最久未使用的
return
candidate
.
LastUsedAt
.
Before
(
*
current
.
LastUsedAt
)
}
return
selected
,
nil
}
// isModelSupportedByAccount 根据账户平台检查模型支持
...
...
@@ -800,6 +931,13 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
}
// 图片生成计费
imageCount
:=
0
imageSize
:=
s
.
extractImageSize
(
body
)
if
isImageGenerationModel
(
originalModel
)
{
imageCount
=
1
}
return
&
ForwardResult
{
RequestID
:
requestID
,
Usage
:
*
usage
,
...
...
@@ -807,6 +945,8 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
Stream
:
req
.
Stream
,
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
ImageCount
:
imageCount
,
ImageSize
:
imageSize
,
},
nil
}
...
...
@@ -1240,6 +1380,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
usage
=
&
ClaudeUsage
{}
}
// 图片生成计费
imageCount
:=
0
imageSize
:=
s
.
extractImageSize
(
body
)
if
isImageGenerationModel
(
originalModel
)
{
imageCount
=
1
}
return
&
ForwardResult
{
RequestID
:
requestID
,
Usage
:
*
usage
,
...
...
@@ -1247,6 +1394,8 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
Stream
:
stream
,
Duration
:
time
.
Since
(
startTime
),
FirstTokenMs
:
firstTokenMs
,
ImageCount
:
imageCount
,
ImageSize
:
imageSize
,
},
nil
}
...
...
@@ -1841,6 +1990,7 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
var
last
map
[
string
]
any
var
lastWithParts
map
[
string
]
any
var
collectedTextParts
[]
string
// Collect all text parts for aggregation
usage
:=
&
ClaudeUsage
{}
for
{
...
...
@@ -1852,7 +2002,7 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
switch
payload
{
case
""
,
"[DONE]"
:
if
payload
==
"[DONE]"
{
return
pickGeminiCollectResult
(
last
,
lastWithParts
),
usage
,
nil
return
mergeCollectedTextParts
(
pickGeminiCollectResult
(
last
,
lastWithParts
),
collectedTextParts
),
usage
,
nil
}
default
:
var
parsed
map
[
string
]
any
...
...
@@ -1871,6 +2021,12 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
}
if
parts
:=
extractGeminiParts
(
parsed
);
len
(
parts
)
>
0
{
lastWithParts
=
parsed
// Collect text from each part for aggregation
for
_
,
part
:=
range
parts
{
if
text
,
ok
:=
part
[
"text"
]
.
(
string
);
ok
&&
text
!=
""
{
collectedTextParts
=
append
(
collectedTextParts
,
text
)
}
}
}
}
}
...
...
@@ -1885,7 +2041,7 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
}
}
return
pickGeminiCollectResult
(
last
,
lastWithParts
),
usage
,
nil
return
mergeCollectedTextParts
(
pickGeminiCollectResult
(
last
,
lastWithParts
),
collectedTextParts
),
usage
,
nil
}
func
pickGeminiCollectResult
(
last
map
[
string
]
any
,
lastWithParts
map
[
string
]
any
)
map
[
string
]
any
{
...
...
@@ -1898,6 +2054,83 @@ func pickGeminiCollectResult(last map[string]any, lastWithParts map[string]any)
return
map
[
string
]
any
{}
}
// mergeCollectedTextParts merges all collected text chunks into the final response.
// This fixes the issue where non-streaming responses only returned the last chunk
// instead of the complete aggregated text.
func
mergeCollectedTextParts
(
response
map
[
string
]
any
,
textParts
[]
string
)
map
[
string
]
any
{
if
len
(
textParts
)
==
0
{
return
response
}
// Join all text parts
mergedText
:=
strings
.
Join
(
textParts
,
""
)
// Deep copy response
result
:=
make
(
map
[
string
]
any
)
for
k
,
v
:=
range
response
{
result
[
k
]
=
v
}
// Get or create candidates
candidates
,
ok
:=
result
[
"candidates"
]
.
([]
any
)
if
!
ok
||
len
(
candidates
)
==
0
{
candidates
=
[]
any
{
map
[
string
]
any
{}}
}
// Get first candidate
candidate
,
ok
:=
candidates
[
0
]
.
(
map
[
string
]
any
)
if
!
ok
{
candidate
=
make
(
map
[
string
]
any
)
candidates
[
0
]
=
candidate
}
// Get or create content
content
,
ok
:=
candidate
[
"content"
]
.
(
map
[
string
]
any
)
if
!
ok
{
content
=
map
[
string
]
any
{
"role"
:
"model"
}
candidate
[
"content"
]
=
content
}
// Get existing parts
existingParts
,
ok
:=
content
[
"parts"
]
.
([]
any
)
if
!
ok
{
existingParts
=
[]
any
{}
}
// Find and update first text part, or create new one
newParts
:=
make
([]
any
,
0
,
len
(
existingParts
)
+
1
)
textUpdated
:=
false
for
_
,
p
:=
range
existingParts
{
pm
,
ok
:=
p
.
(
map
[
string
]
any
)
if
!
ok
{
newParts
=
append
(
newParts
,
p
)
continue
}
if
_
,
hasText
:=
pm
[
"text"
];
hasText
&&
!
textUpdated
{
// Replace with merged text
newPart
:=
make
(
map
[
string
]
any
)
for
k
,
v
:=
range
pm
{
newPart
[
k
]
=
v
}
newPart
[
"text"
]
=
mergedText
newParts
=
append
(
newParts
,
newPart
)
textUpdated
=
true
}
else
{
newParts
=
append
(
newParts
,
pm
)
}
}
if
!
textUpdated
{
newParts
=
append
([]
any
{
map
[
string
]
any
{
"text"
:
mergedText
}},
newParts
...
)
}
content
[
"parts"
]
=
newParts
result
[
"candidates"
]
=
candidates
return
result
}
type
geminiNativeStreamResult
struct
{
usage
*
ClaudeUsage
firstTokenMs
*
int
...
...
@@ -2816,3 +3049,26 @@ func convertClaudeGenerationConfig(req map[string]any) map[string]any {
}
return
out
}
// extractImageSize 从 Gemini 请求中提取 image_size 参数
func
(
s
*
GeminiMessagesCompatService
)
extractImageSize
(
body
[]
byte
)
string
{
var
req
struct
{
GenerationConfig
*
struct
{
ImageConfig
*
struct
{
ImageSize
string
`json:"imageSize"`
}
`json:"imageConfig"`
}
`json:"generationConfig"`
}
if
err
:=
json
.
Unmarshal
(
body
,
&
req
);
err
!=
nil
{
return
"2K"
}
if
req
.
GenerationConfig
!=
nil
&&
req
.
GenerationConfig
.
ImageConfig
!=
nil
{
size
:=
strings
.
ToUpper
(
strings
.
TrimSpace
(
req
.
GenerationConfig
.
ImageConfig
.
ImageSize
))
if
size
==
"1K"
||
size
==
"2K"
||
size
==
"4K"
{
return
size
}
}
return
"2K"
}
backend/internal/service/gemini_multiplatform_test.go
View file @
2fe8932c
...
...
@@ -17,6 +17,8 @@ import (
type
mockAccountRepoForGemini
struct
{
accounts
[]
Account
accountsByID
map
[
int64
]
*
Account
listByGroupFunc
func
(
ctx
context
.
Context
,
groupID
int64
,
platforms
[]
string
)
([]
Account
,
error
)
listByPlatformFunc
func
(
ctx
context
.
Context
,
platforms
[]
string
)
([]
Account
,
error
)
}
func
(
m
*
mockAccountRepoForGemini
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Account
,
error
)
{
...
...
@@ -88,6 +90,9 @@ func (m *mockAccountRepoForGemini) BatchUpdateLastUsed(ctx context.Context, upda
func
(
m
*
mockAccountRepoForGemini
)
SetError
(
ctx
context
.
Context
,
id
int64
,
errorMsg
string
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
ClearError
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForGemini
)
SetSchedulable
(
ctx
context
.
Context
,
id
int64
,
schedulable
bool
)
error
{
return
nil
}
...
...
@@ -104,6 +109,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupID(ctx context.Context,
return
nil
,
nil
}
func
(
m
*
mockAccountRepoForGemini
)
ListSchedulableByPlatforms
(
ctx
context
.
Context
,
platforms
[]
string
)
([]
Account
,
error
)
{
if
m
.
listByPlatformFunc
!=
nil
{
return
m
.
listByPlatformFunc
(
ctx
,
platforms
)
}
var
result
[]
Account
platformSet
:=
make
(
map
[
string
]
bool
)
for
_
,
p
:=
range
platforms
{
...
...
@@ -117,6 +125,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Contex
return
result
,
nil
}
func
(
m
*
mockAccountRepoForGemini
)
ListSchedulableByGroupIDAndPlatforms
(
ctx
context
.
Context
,
groupID
int64
,
platforms
[]
string
)
([]
Account
,
error
)
{
if
m
.
listByGroupFunc
!=
nil
{
return
m
.
listByGroupFunc
(
ctx
,
groupID
,
platforms
)
}
return
m
.
ListSchedulableByPlatforms
(
ctx
,
platforms
)
}
func
(
m
*
mockAccountRepoForGemini
)
SetRateLimited
(
ctx
context
.
Context
,
id
int64
,
resetAt
time
.
Time
)
error
{
...
...
@@ -212,6 +223,7 @@ var _ GroupRepository = (*mockGroupRepoForGemini)(nil)
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
type
mockGatewayCacheForGemini
struct
{
sessionBindings
map
[
string
]
int64
deletedSessions
map
[
string
]
int
}
func
(
m
*
mockGatewayCacheForGemini
)
GetSessionAccountID
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
)
(
int64
,
error
)
{
...
...
@@ -233,6 +245,18 @@ func (m *mockGatewayCacheForGemini) RefreshSessionTTL(ctx context.Context, group
return
nil
}
func
(
m
*
mockGatewayCacheForGemini
)
DeleteSessionAccountID
(
ctx
context
.
Context
,
groupID
int64
,
sessionHash
string
)
error
{
if
m
.
sessionBindings
==
nil
{
return
nil
}
if
m
.
deletedSessions
==
nil
{
m
.
deletedSessions
=
make
(
map
[
string
]
int
)
}
m
.
deletedSessions
[
sessionHash
]
++
delete
(
m
.
sessionBindings
,
sessionHash
)
return
nil
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
...
...
@@ -523,6 +547,274 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickyS
// 粘性会话未命中,按优先级选择
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
,
"粘性会话未命中,应按优先级选择"
)
})
t
.
Run
(
"粘性会话不可调度-清理并回退选择"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
2
,
Status
:
StatusDisabled
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{
sessionBindings
:
map
[
string
]
int64
{
"gemini:session-123"
:
1
},
}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{}}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
"session-123"
,
"gemini-2.5-flash"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
)
require
.
Equal
(
t
,
1
,
cache
.
deletedSessions
[
"gemini:session-123"
])
require
.
Equal
(
t
,
int64
(
2
),
cache
.
sessionBindings
[
"gemini:session-123"
])
})
}
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ForcePlatformFallback
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
groupID
:=
int64
(
9
)
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
ForcePlatform
,
PlatformAntigravity
)
repo
:=
&
mockAccountRepoForGemini
{
listByGroupFunc
:
func
(
ctx
context
.
Context
,
groupID
int64
,
platforms
[]
string
)
([]
Account
,
error
)
{
return
nil
,
nil
},
listByPlatformFunc
:
func
(
ctx
context
.
Context
,
platforms
[]
string
)
([]
Account
,
error
)
{
return
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
nil
},
accountsByID
:
map
[
int64
]
*
Account
{
1
:
{
ID
:
1
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
}
cache
:=
&
mockGatewayCacheForGemini
{}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{}}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
&
groupID
,
""
,
"gemini-2.5-flash"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
1
),
acc
.
ID
)
}
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_NoModelSupport
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"gemini-1.0-pro"
:
"gemini-1.0-pro"
}},
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{}}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
""
,
"gemini-2.5-flash"
,
nil
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
acc
)
require
.
Contains
(
t
,
err
.
Error
(),
"supporting model"
)
}
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickyMixedScheduling
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Extra
:
map
[
string
]
any
{
"mixed_scheduling"
:
true
}},
{
ID
:
2
,
Platform
:
PlatformGemini
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{
sessionBindings
:
map
[
string
]
int64
{
"gemini:session-999"
:
1
},
}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{}}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
"session-999"
,
"gemini-2.5-flash"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
1
),
acc
.
ID
)
}
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_SkipDisabledMixedScheduling
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAntigravity
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformGemini
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{}}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
""
,
"gemini-2.5-flash"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
)
}
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ExcludedAccount
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
},
{
ID
:
2
,
Platform
:
PlatformGemini
,
Priority
:
2
,
Status
:
StatusActive
,
Schedulable
:
true
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{}}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
excluded
:=
map
[
int64
]
struct
{}{
1
:
{}}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
""
,
"gemini-2.5-flash"
,
excluded
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
)
}
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ListError
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForGemini
{
listByPlatformFunc
:
func
(
ctx
context
.
Context
,
platforms
[]
string
)
([]
Account
,
error
)
{
return
nil
,
errors
.
New
(
"query failed"
)
},
}
cache
:=
&
mockGatewayCacheForGemini
{}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{}}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
""
,
"gemini-2.5-flash"
,
nil
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
acc
)
require
.
Contains
(
t
,
err
.
Error
(),
"query accounts failed"
)
}
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_PreferOAuth
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Type
:
AccountTypeAPIKey
},
{
ID
:
2
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Type
:
AccountTypeOAuth
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{}}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
""
,
"gemini-2.5-pro"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
)
}
func
TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_PreferLeastRecentlyUsed
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
oldTime
:=
time
.
Now
()
.
Add
(
-
2
*
time
.
Hour
)
newTime
:=
time
.
Now
()
.
Add
(
-
1
*
time
.
Hour
)
repo
:=
&
mockAccountRepoForGemini
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
LastUsedAt
:
&
newTime
},
{
ID
:
2
,
Platform
:
PlatformGemini
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
LastUsedAt
:
&
oldTime
},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForGemini
{}
groupRepo
:=
&
mockGroupRepoForGemini
{
groups
:
map
[
int64
]
*
Group
{}}
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
}
acc
,
err
:=
svc
.
SelectAccountForModelWithExclusions
(
ctx
,
nil
,
""
,
"gemini-2.5-pro"
,
nil
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
acc
)
require
.
Equal
(
t
,
int64
(
2
),
acc
.
ID
)
}
// TestGeminiPlatformRouting_DocumentRouteDecision 测试平台路由决策逻辑
...
...
@@ -599,7 +891,7 @@ func TestGeminiMessagesCompatService_isModelSupportedByAccount(t *testing.T) {
name
:
"Gemini平台-有映射配置-只支持配置的模型"
,
account
:
&
Account
{
Platform
:
PlatformGemini
,
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"gemini-
1
.5-pro"
:
"x"
}},
Credentials
:
map
[
string
]
any
{
"model_mapping"
:
map
[
string
]
any
{
"gemini-
2
.5-pro"
:
"x"
}},
},
model
:
"gemini-2.5-flash"
,
expected
:
false
,
...
...
backend/internal/service/gemini_native_signature_cleaner.go
0 → 100644
View file @
2fe8932c
package
service
import
(
"encoding/json"
)
// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中移除 thoughtSignature 字段,
// 以避免跨账号签名验证错误。
//
// 当粘性会话切换账号时(例如原账号异常、不可调度等),旧账号返回的 thoughtSignature
// 会导致新账号的签名验证失败。通过移除这些签名,让新账号重新生成有效的签名。
//
// CleanGeminiNativeThoughtSignatures removes thoughtSignature fields from Gemini native API requests
// to avoid cross-account signature validation errors.
//
// When sticky session switches accounts (e.g., original account becomes unavailable),
// thoughtSignatures from the old account will cause validation failures on the new account.
// By removing these signatures, we allow the new account to generate valid signatures.
func
CleanGeminiNativeThoughtSignatures
(
body
[]
byte
)
[]
byte
{
if
len
(
body
)
==
0
{
return
body
}
// 解析 JSON
var
data
any
if
err
:=
json
.
Unmarshal
(
body
,
&
data
);
err
!=
nil
{
// 如果解析失败,返回原始 body(可能不是 JSON 或格式不正确)
return
body
}
// 递归清理 thoughtSignature
cleaned
:=
cleanThoughtSignaturesRecursive
(
data
)
// 重新序列化
result
,
err
:=
json
.
Marshal
(
cleaned
)
if
err
!=
nil
{
// 如果序列化失败,返回原始 body
return
body
}
return
result
}
// cleanThoughtSignaturesRecursive 递归遍历数据结构,移除所有 thoughtSignature 字段
func
cleanThoughtSignaturesRecursive
(
data
any
)
any
{
switch
v
:=
data
.
(
type
)
{
case
map
[
string
]
any
:
// 创建新的 map,移除 thoughtSignature
result
:=
make
(
map
[
string
]
any
,
len
(
v
))
for
key
,
value
:=
range
v
{
// 跳过 thoughtSignature 字段
if
key
==
"thoughtSignature"
{
continue
}
// 递归处理嵌套结构
result
[
key
]
=
cleanThoughtSignaturesRecursive
(
value
)
}
return
result
case
[]
any
:
// 递归处理数组中的每个元素
result
:=
make
([]
any
,
len
(
v
))
for
i
,
item
:=
range
v
{
result
[
i
]
=
cleanThoughtSignaturesRecursive
(
item
)
}
return
result
default
:
// 基本类型(string, number, bool, null)直接返回
return
v
}
}
backend/internal/service/gemini_token_provider.go
View file @
2fe8932c
...
...
@@ -4,6 +4,7 @@ import (
"context"
"errors"
"log"
"log/slog"
"strconv"
"strings"
"time"
...
...
@@ -131,8 +132,18 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
}
}
// 3) Populate cache with TTL
.
// 3) Populate cache with TTL
(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
if
p
.
tokenCache
!=
nil
{
latestAccount
,
isStale
:=
CheckTokenVersion
(
ctx
,
account
,
p
.
accountRepo
)
if
isStale
&&
latestAccount
!=
nil
{
// 版本过时,使用 DB 中的最新 token
slog
.
Debug
(
"gemini_token_version_stale_use_latest"
,
"account_id"
,
account
.
ID
)
accessToken
=
latestAccount
.
GetCredential
(
"access_token"
)
if
strings
.
TrimSpace
(
accessToken
)
==
""
{
return
""
,
errors
.
New
(
"access_token not found after version check"
)
}
// 不写入缓存,让下次请求重新处理
}
else
{
ttl
:=
30
*
time
.
Minute
if
expiresAt
!=
nil
{
until
:=
time
.
Until
(
*
expiresAt
)
...
...
@@ -147,6 +158,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
}
_
=
p
.
tokenCache
.
SetAccessToken
(
ctx
,
cacheKey
,
accessToken
,
ttl
)
}
}
return
accessToken
,
nil
}
...
...
Prev
1
…
4
5
6
7
8
9
10
11
12
…
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment