Unverified Commit 2fe8932c authored by Call White's avatar Call White Committed by GitHub
Browse files

Merge pull request #3 from cyhhao/main

merge to main
parents 2f2e76f9 adb77af1
......@@ -153,8 +153,8 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, ErrServiceUnavailable
}
// 应用优惠码(如果提供)
if promoCode != "" && s.promoService != nil {
// 应用优惠码(如果提供且功能已启用
if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) {
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
// 优惠码应用失败不影响注册,只记录日志
log.Printf("[Auth] Failed to apply promo code for user %d: %v", user.ID, err)
......@@ -580,3 +580,149 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
// 生成新token
return s.GenerateToken(user)
}
// IsPasswordResetEnabled 检查是否启用密码重置功能
// 要求:必须同时开启邮件验证且 SMTP 配置正确
func (s *AuthService) IsPasswordResetEnabled(ctx context.Context) bool {
if s.settingService == nil {
return false
}
// Must have email verification enabled and SMTP configured
if !s.settingService.IsEmailVerifyEnabled(ctx) {
return false
}
return s.settingService.IsPasswordResetEnabled(ctx)
}
// preparePasswordReset validates the password reset request and returns necessary data
// Returns (siteName, resetURL, shouldProceed)
// shouldProceed is false when we should silently return success (to prevent enumeration)
func (s *AuthService) preparePasswordReset(ctx context.Context, email, frontendBaseURL string) (string, string, bool) {
// Check if user exists (but don't reveal this to the caller)
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
if errors.Is(err, ErrUserNotFound) {
// Security: Log but don't reveal that user doesn't exist
log.Printf("[Auth] Password reset requested for non-existent email: %s", email)
return "", "", false
}
log.Printf("[Auth] Database error checking email for password reset: %v", err)
return "", "", false
}
// Check if user is active
if !user.IsActive() {
log.Printf("[Auth] Password reset requested for inactive user: %s", email)
return "", "", false
}
// Get site name
siteName := "Sub2API"
if s.settingService != nil {
siteName = s.settingService.GetSiteName(ctx)
}
// Build reset URL base
resetURL := fmt.Sprintf("%s/reset-password", strings.TrimSuffix(frontendBaseURL, "/"))
return siteName, resetURL, true
}
// RequestPasswordReset 请求密码重置(同步发送)
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendBaseURL string) error {
if !s.IsPasswordResetEnabled(ctx) {
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
}
if s.emailService == nil {
return ErrServiceUnavailable
}
siteName, resetURL, shouldProceed := s.preparePasswordReset(ctx, email, frontendBaseURL)
if !shouldProceed {
return nil // Silent success to prevent enumeration
}
if err := s.emailService.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil {
log.Printf("[Auth] Failed to send password reset email to %s: %v", email, err)
return nil // Silent success to prevent enumeration
}
log.Printf("[Auth] Password reset email sent to: %s", email)
return nil
}
// RequestPasswordResetAsync 异步请求密码重置(队列发送)
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, frontendBaseURL string) error {
if !s.IsPasswordResetEnabled(ctx) {
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
}
if s.emailQueueService == nil {
return ErrServiceUnavailable
}
siteName, resetURL, shouldProceed := s.preparePasswordReset(ctx, email, frontendBaseURL)
if !shouldProceed {
return nil // Silent success to prevent enumeration
}
if err := s.emailQueueService.EnqueuePasswordReset(email, siteName, resetURL); err != nil {
log.Printf("[Auth] Failed to enqueue password reset email for %s: %v", email, err)
return nil // Silent success to prevent enumeration
}
log.Printf("[Auth] Password reset email enqueued for: %s", email)
return nil
}
// ResetPassword 重置密码
// Security: Increments TokenVersion to invalidate all existing JWT tokens
func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPassword string) error {
// Check if password reset is enabled
if !s.IsPasswordResetEnabled(ctx) {
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
}
if s.emailService == nil {
return ErrServiceUnavailable
}
// Verify and consume the reset token (one-time use)
if err := s.emailService.ConsumePasswordResetToken(ctx, email, token); err != nil {
return err
}
// Get user
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
if errors.Is(err, ErrUserNotFound) {
return ErrInvalidResetToken // Token was valid but user was deleted
}
log.Printf("[Auth] Database error getting user for password reset: %v", err)
return ErrServiceUnavailable
}
// Check if user is active
if !user.IsActive() {
return ErrUserNotActive
}
// Hash new password
hashedPassword, err := s.HashPassword(newPassword)
if err != nil {
return fmt.Errorf("hash password: %w", err)
}
// Update password and increment TokenVersion
user.PasswordHash = hashedPassword
user.TokenVersion++ // Invalidate all existing tokens
if err := s.userRepo.Update(ctx, user); err != nil {
log.Printf("[Auth] Database error updating password for user %d: %v", user.ID, err)
return ErrServiceUnavailable
}
log.Printf("[Auth] Password reset successful for user: %s", email)
return nil
}
......@@ -71,6 +71,26 @@ func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email strin
return nil
}
func (s *emailCacheStub) GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error) {
return nil, nil
}
func (s *emailCacheStub) SetPasswordResetToken(ctx context.Context, email string, data *PasswordResetTokenData, ttl time.Duration) error {
return nil
}
func (s *emailCacheStub) DeletePasswordResetToken(ctx context.Context, email string) error {
return nil
}
func (s *emailCacheStub) IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool {
return false
}
func (s *emailCacheStub) SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error {
return nil
}
func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache) *AuthService {
cfg := &config.Config{
JWT: config.JWTConfig{
......
......@@ -181,8 +181,18 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return "", errors.New("access_token not found in credentials")
}
// 3. 存入缓存
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
if p.tokenCache != nil {
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
if isStale && latestAccount != nil {
// 版本过时,使用 DB 中的最新 token
slog.Debug("claude_token_version_stale_use_latest", "account_id", account.ID)
accessToken = latestAccount.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found after version check")
}
// 不写入缓存,让下次请求重新处理
} else {
ttl := 30 * time.Minute
if refreshFailed {
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
......@@ -203,6 +213,7 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
slog.Warn("claude_token_cache_set_failed", "account_id", account.ID, "error", err)
}
}
}
return accessToken, nil
}
......@@ -21,11 +21,15 @@ var (
ErrDashboardBackfillDisabled = errors.New("仪表盘聚合回填已禁用")
// ErrDashboardBackfillTooLarge 当回填跨度超过限制时返回。
ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大")
errDashboardAggregationRunning = errors.New("聚合作业正在运行")
)
// DashboardAggregationRepository 定义仪表盘预聚合仓储接口。
type DashboardAggregationRepository interface {
AggregateRange(ctx context.Context, start, end time.Time) error
// RecomputeRange 重新计算指定时间范围内的聚合数据(包含活跃用户等派生表)。
// 设计目的:当 usage_logs 被批量删除/回滚后,确保聚合表可恢复一致性。
RecomputeRange(ctx context.Context, start, end time.Time) error
GetAggregationWatermark(ctx context.Context) (time.Time, error)
UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
......@@ -112,6 +116,41 @@ func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) erro
return nil
}
// TriggerRecomputeRange 触发指定范围的重新计算(异步)。
// 与 TriggerBackfill 不同:
// - 不依赖 backfill_enabled(这是内部一致性修复)
// - 不更新 watermark(避免影响正常增量聚合游标)
func (s *DashboardAggregationService) TriggerRecomputeRange(start, end time.Time) error {
if s == nil || s.repo == nil {
return errors.New("聚合服务未初始化")
}
if !s.cfg.Enabled {
return errors.New("聚合服务已禁用")
}
if !end.After(start) {
return errors.New("重新计算时间范围无效")
}
go func() {
const maxRetries = 3
for i := 0; i < maxRetries; i++ {
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
err := s.recomputeRange(ctx, start, end)
cancel()
if err == nil {
return
}
if !errors.Is(err, errDashboardAggregationRunning) {
log.Printf("[DashboardAggregation] 重新计算失败: %v", err)
return
}
time.Sleep(5 * time.Second)
}
log.Printf("[DashboardAggregation] 重新计算放弃: 聚合作业持续占用")
}()
return nil
}
func (s *DashboardAggregationService) recomputeRecentDays() {
days := s.cfg.RecomputeDays
if days <= 0 {
......@@ -128,6 +167,24 @@ func (s *DashboardAggregationService) recomputeRecentDays() {
}
}
func (s *DashboardAggregationService) recomputeRange(ctx context.Context, start, end time.Time) error {
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
return errDashboardAggregationRunning
}
defer atomic.StoreInt32(&s.running, 0)
jobStart := time.Now().UTC()
if err := s.repo.RecomputeRange(ctx, start, end); err != nil {
return err
}
log.Printf("[DashboardAggregation] 重新计算完成 (start=%s end=%s duration=%s)",
start.UTC().Format(time.RFC3339),
end.UTC().Format(time.RFC3339),
time.Since(jobStart).String(),
)
return nil
}
func (s *DashboardAggregationService) runScheduledAggregation() {
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
return
......@@ -179,7 +236,7 @@ func (s *DashboardAggregationService) runScheduledAggregation() {
func (s *DashboardAggregationService) backfillRange(ctx context.Context, start, end time.Time) error {
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
return errors.New("聚合作业正在运行")
return errDashboardAggregationRunning
}
defer atomic.StoreInt32(&s.running, 0)
......
......@@ -27,6 +27,10 @@ func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, s
return s.aggregateErr
}
func (s *dashboardAggregationRepoTestStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
return s.AggregateRange(ctx, start, end)
}
func (s *dashboardAggregationRepoTestStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
return s.watermark, nil
}
......
......@@ -124,16 +124,16 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D
return stats, nil
}
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) {
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream)
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
if err != nil {
return nil, fmt.Errorf("get usage trend with filters: %w", err)
}
return trend, nil
}
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream)
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
if err != nil {
return nil, fmt.Errorf("get model stats with filters: %w", err)
}
......
......@@ -101,6 +101,10 @@ func (s *dashboardAggregationRepoStub) AggregateRange(ctx context.Context, start
return nil
}
func (s *dashboardAggregationRepoStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
return nil
}
func (s *dashboardAggregationRepoStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
if s.err != nil {
return time.Time{}, s.err
......
......@@ -71,6 +71,8 @@ const (
// 注册设置
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
// 邮件服务设置
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
......@@ -86,6 +88,9 @@ const (
SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key
SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key
// TOTP 双因素认证设置
SettingKeyTotpEnabled = "totp_enabled" // 是否启用 TOTP 2FA 功能
// LinuxDo Connect OAuth 登录设置
SettingKeyLinuxDoConnectEnabled = "linuxdo_connect_enabled"
SettingKeyLinuxDoConnectClientID = "linuxdo_connect_client_id"
......@@ -100,6 +105,9 @@ const (
SettingKeyContactInfo = "contact_info" // 客服联系方式
SettingKeyDocURL = "doc_url" // 文档链接
SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src)
SettingKeyHideCcsImportButton = "hide_ccs_import_button" // 是否隐藏 API Keys 页面的导入 CCS 按钮
SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示“购买订阅”页面入口
SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // “购买订阅”页面 URL(作为 iframe src)
// 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
......
......@@ -8,11 +8,18 @@ import (
"time"
)
// Task type constants
const (
TaskTypeVerifyCode = "verify_code"
TaskTypePasswordReset = "password_reset"
)
// EmailTask 邮件发送任务
type EmailTask struct {
Email string
SiteName string
TaskType string // "verify_code"
TaskType string // "verify_code" or "password_reset"
ResetURL string // Only used for password_reset task type
}
// EmailQueueService 异步邮件队列服务
......@@ -73,12 +80,18 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) {
defer cancel()
switch task.TaskType {
case "verify_code":
case TaskTypeVerifyCode:
if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName); err != nil {
log.Printf("[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err)
} else {
log.Printf("[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email)
}
case TaskTypePasswordReset:
if err := s.emailService.SendPasswordResetEmailWithCooldown(ctx, task.Email, task.SiteName, task.ResetURL); err != nil {
log.Printf("[EmailQueue] Worker %d failed to send password reset to %s: %v", workerID, task.Email, err)
} else {
log.Printf("[EmailQueue] Worker %d sent password reset to %s", workerID, task.Email)
}
default:
log.Printf("[EmailQueue] Worker %d unknown task type: %s", workerID, task.TaskType)
}
......@@ -89,7 +102,7 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
task := EmailTask{
Email: email,
SiteName: siteName,
TaskType: "verify_code",
TaskType: TaskTypeVerifyCode,
}
select {
......@@ -101,6 +114,24 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
}
}
// EnqueuePasswordReset 将密码重置邮件任务加入队列
func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL string) error {
task := EmailTask{
Email: email,
SiteName: siteName,
TaskType: TaskTypePasswordReset,
ResetURL: resetURL,
}
select {
case s.taskChan <- task:
log.Printf("[EmailQueue] Enqueued password reset task for %s", email)
return nil
default:
return fmt.Errorf("email queue is full")
}
}
// Stop 停止队列服务
func (s *EmailQueueService) Stop() {
close(s.stopChan)
......
......@@ -3,11 +3,14 @@ package service
import (
"context"
"crypto/rand"
"crypto/subtle"
"crypto/tls"
"encoding/hex"
"fmt"
"log"
"math/big"
"net/smtp"
"net/url"
"strconv"
"time"
......@@ -19,6 +22,9 @@ var (
ErrInvalidVerifyCode = infraerrors.BadRequest("INVALID_VERIFY_CODE", "invalid or expired verification code")
ErrVerifyCodeTooFrequent = infraerrors.TooManyRequests("VERIFY_CODE_TOO_FREQUENT", "please wait before requesting a new code")
ErrVerifyCodeMaxAttempts = infraerrors.TooManyRequests("VERIFY_CODE_MAX_ATTEMPTS", "too many failed attempts, please request a new code")
// Password reset errors
ErrInvalidResetToken = infraerrors.BadRequest("INVALID_RESET_TOKEN", "invalid or expired password reset token")
)
// EmailCache defines cache operations for email service
......@@ -26,6 +32,16 @@ type EmailCache interface {
GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error)
SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error
DeleteVerificationCode(ctx context.Context, email string) error
// Password reset token methods
GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error)
SetPasswordResetToken(ctx context.Context, email string, data *PasswordResetTokenData, ttl time.Duration) error
DeletePasswordResetToken(ctx context.Context, email string) error
// Password reset email cooldown methods
// Returns true if in cooldown period (email was sent recently)
IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool
SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error
}
// VerificationCodeData represents verification code data
......@@ -35,10 +51,22 @@ type VerificationCodeData struct {
CreatedAt time.Time
}
// PasswordResetTokenData represents password reset token data
type PasswordResetTokenData struct {
Token string
CreatedAt time.Time
}
const (
verifyCodeTTL = 15 * time.Minute
verifyCodeCooldown = 1 * time.Minute
maxVerifyCodeAttempts = 5
// Password reset token settings
passwordResetTokenTTL = 30 * time.Minute
// Password reset email cooldown (prevent email bombing)
passwordResetEmailCooldown = 30 * time.Second
)
// SMTPConfig SMTP配置
......@@ -254,8 +282,8 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
return ErrVerifyCodeMaxAttempts
}
// 验证码不匹配
if data.Code != code {
// 验证码不匹配 (constant-time comparison to prevent timing attacks)
if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 {
data.Attempts++
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
log.Printf("[Email] Failed to update verification attempt count: %v", err)
......@@ -357,3 +385,157 @@ func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
return client.Quit()
}
// GeneratePasswordResetToken generates a secure 32-byte random token (64 hex characters)
func (s *EmailService) GeneratePasswordResetToken() (string, error) {
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
// SendPasswordResetEmail sends a password reset email with a reset link
func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteName, resetURL string) error {
var token string
var needSaveToken bool
// Check if token already exists
existing, err := s.cache.GetPasswordResetToken(ctx, email)
if err == nil && existing != nil {
// Token exists, reuse it (allows resending email without generating new token)
token = existing.Token
needSaveToken = false
} else {
// Generate new token
token, err = s.GeneratePasswordResetToken()
if err != nil {
return fmt.Errorf("generate token: %w", err)
}
needSaveToken = true
}
// Save token to Redis (only if new token generated)
if needSaveToken {
data := &PasswordResetTokenData{
Token: token,
CreatedAt: time.Now(),
}
if err := s.cache.SetPasswordResetToken(ctx, email, data, passwordResetTokenTTL); err != nil {
return fmt.Errorf("save reset token: %w", err)
}
}
// Build full reset URL with URL-encoded token and email
fullResetURL := fmt.Sprintf("%s?email=%s&token=%s", resetURL, url.QueryEscape(email), url.QueryEscape(token))
// Build email content
subject := fmt.Sprintf("[%s] 密码重置请求", siteName)
body := s.buildPasswordResetEmailBody(fullResetURL, siteName)
// Send email
if err := s.SendEmail(ctx, email, subject, body); err != nil {
return fmt.Errorf("send email: %w", err)
}
return nil
}
// SendPasswordResetEmailWithCooldown sends password reset email with cooldown check (called by queue worker)
// This method wraps SendPasswordResetEmail with email cooldown to prevent email bombing
func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string) error {
// Check email cooldown to prevent email bombing
if s.cache.IsPasswordResetEmailInCooldown(ctx, email) {
log.Printf("[Email] Password reset email skipped (cooldown): %s", email)
return nil // Silent success to prevent revealing cooldown to attackers
}
// Send email using core method
if err := s.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil {
return err
}
// Set cooldown marker (Redis TTL handles expiration)
if err := s.cache.SetPasswordResetEmailCooldown(ctx, email, passwordResetEmailCooldown); err != nil {
log.Printf("[Email] Failed to set password reset cooldown for %s: %v", email, err)
}
return nil
}
// VerifyPasswordResetToken verifies the password reset token without consuming it
func (s *EmailService) VerifyPasswordResetToken(ctx context.Context, email, token string) error {
data, err := s.cache.GetPasswordResetToken(ctx, email)
if err != nil || data == nil {
return ErrInvalidResetToken
}
// Use constant-time comparison to prevent timing attacks
if subtle.ConstantTimeCompare([]byte(data.Token), []byte(token)) != 1 {
return ErrInvalidResetToken
}
return nil
}
// ConsumePasswordResetToken verifies and deletes the token (one-time use)
func (s *EmailService) ConsumePasswordResetToken(ctx context.Context, email, token string) error {
// Verify first
if err := s.VerifyPasswordResetToken(ctx, email, token); err != nil {
return err
}
// Delete after verification (one-time use)
if err := s.cache.DeletePasswordResetToken(ctx, email); err != nil {
log.Printf("[Email] Failed to delete password reset token after consumption: %v", err)
}
return nil
}
// buildPasswordResetEmailBody builds the HTML content for password reset email
func (s *EmailService) buildPasswordResetEmailBody(resetURL, siteName string) string {
return fmt.Sprintf(`
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 600px; margin: 0 auto; background-color: #ffffff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
.header { background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 30px; text-align: center; }
.header h1 { margin: 0; font-size: 24px; }
.content { padding: 40px 30px; text-align: center; }
.button { display: inline-block; background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 14px 32px; text-decoration: none; border-radius: 8px; font-size: 16px; font-weight: 600; margin: 20px 0; }
.button:hover { opacity: 0.9; }
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; }
.link-fallback { color: #666; font-size: 12px; word-break: break-all; margin-top: 20px; padding: 15px; background-color: #f8f9fa; border-radius: 4px; }
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
.warning { color: #e74c3c; font-weight: 500; }
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>%s</h1>
</div>
<div class="content">
<p style="font-size: 18px; color: #333;">密码重置请求</p>
<p style="color: #666;">您已请求重置密码。请点击下方按钮设置新密码:</p>
<a href="%s" class="button">重置密码</a>
<div class="info">
<p>此链接将在 <strong>30 分钟</strong>后失效。</p>
<p class="warning">如果您没有请求重置密码,请忽略此邮件。您的密码将保持不变。</p>
</div>
<div class="link-fallback">
<p>如果按钮无法点击,请复制以下链接到浏览器中打开:</p>
<p>%s</p>
</div>
</div>
<div class="footer">
<p>这是一封自动发送的邮件,请勿回复。</p>
</div>
</div>
</body>
</html>
`, siteName, resetURL, resetURL)
}
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestMergeAnthropicBeta(t *testing.T) {
got := mergeAnthropicBeta(
[]string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"},
"foo, oauth-2025-04-20,bar, foo",
)
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo,bar", got)
}
func TestMergeAnthropicBeta_EmptyIncoming(t *testing.T) {
got := mergeAnthropicBeta(
[]string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"},
"",
)
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14", got)
}
......@@ -105,6 +105,9 @@ func (m *mockAccountRepoForPlatform) BatchUpdateLastUsed(ctx context.Context, up
func (m *mockAccountRepoForPlatform) SetError(ctx context.Context, id int64, errorMsg string) error {
return nil
}
func (m *mockAccountRepoForPlatform) ClearError(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForPlatform) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return nil
}
......@@ -179,6 +182,7 @@ var _ AccountRepository = (*mockAccountRepoForPlatform)(nil)
// mockGatewayCacheForPlatform 单平台测试用的 cache mock
type mockGatewayCacheForPlatform struct {
sessionBindings map[string]int64
deletedSessions map[string]int
}
func (m *mockGatewayCacheForPlatform) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
......@@ -200,6 +204,18 @@ func (m *mockGatewayCacheForPlatform) RefreshSessionTTL(ctx context.Context, gro
return nil
}
func (m *mockGatewayCacheForPlatform) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
if m.sessionBindings == nil {
return nil
}
if m.deletedSessions == nil {
m.deletedSessions = make(map[string]int)
}
m.deletedSessions[sessionHash]++
delete(m.sessionBindings, sessionHash)
return nil
}
type mockGroupRepoForGateway struct {
groups map[int64]*Group
getByIDCalls int
......@@ -623,76 +639,96 @@ func TestGatewayService_SelectAccountForModelWithPlatform_StickySession(t *testi
})
}
func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
svc := &GatewayService{}
func TestGatewayService_SelectAccountForModelWithExclusions_ForcePlatform(t *testing.T) {
ctx := context.Background()
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, PlatformAntigravity)
tests := []struct {
name string
account *Account
model string
expected bool
}{
{
name: "Antigravity平台-支持claude模型",
account: &Account{Platform: PlatformAntigravity},
model: "claude-3-5-sonnet-20241022",
expected: true,
},
{
name: "Antigravity平台-支持gemini模型",
account: &Account{Platform: PlatformAntigravity},
model: "gemini-2.5-flash",
expected: true,
},
{
name: "Antigravity平台-不支持gpt模型",
account: &Account{Platform: PlatformAntigravity},
model: "gpt-4",
expected: false,
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true},
},
{
name: "Anthropic平台-无映射配置-支持所有模型",
account: &Account{Platform: PlatformAnthropic},
model: "claude-3-5-sonnet-20241022",
expected: true,
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
require.Equal(t, PlatformAntigravity, acc.Platform)
}
func TestGatewayService_SelectAccountForModelWithPlatform_RoutedStickySessionClears(t *testing.T) {
ctx := context.Background()
groupID := int64(10)
requestedModel := "claude-3-5-sonnet-20241022"
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusDisabled, Schedulable: true},
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
{
name: "Anthropic平台-有映射配置-只支持配置的模型",
account: &Account{
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{"session-123": 1},
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Name: "route-group",
Platform: PlatformAnthropic,
Credentials: map[string]any{"model_mapping": map[string]any{"claude-opus-4": "x"}},
},
model: "claude-3-5-sonnet-20241022",
expected: false,
Status: StatusActive,
Hydrated: true,
ModelRoutingEnabled: true,
ModelRouting: map[string][]int64{
requestedModel: {1, 2},
},
{
name: "Anthropic平台-有映射配置-支持配置的模型",
account: &Account{
Platform: PlatformAnthropic,
Credentials: map[string]any{"model_mapping": map[string]any{"claude-3-5-sonnet-20241022": "x"}},
},
model: "claude-3-5-sonnet-20241022",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := svc.isModelSupportedByAccount(tt.account, tt.model)
require.Equal(t, tt.expected, got)
})
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
groupRepo: groupRepo,
}
acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "session-123", requestedModel, nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
require.Equal(t, 1, cache.deletedSessions["session-123"])
require.Equal(t, int64(2), cache.sessionBindings["session-123"])
}
// TestGatewayService_selectAccountWithMixedScheduling 测试混合调度
func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
func TestGatewayService_SelectAccountForModelWithPlatform_RoutedStickySessionHit(t *testing.T) {
ctx := context.Background()
groupID := int64(11)
requestedModel := "claude-3-5-sonnet-20241022"
t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
......@@ -700,25 +736,48 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{"session-456": 1},
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Name: "route-group-hit",
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
ModelRoutingEnabled: true,
ModelRouting: map[string][]int64{
requestedModel: {1, 2},
},
},
},
}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
groupRepo: groupRepo,
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini)
acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "session-456", requestedModel, nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "同优先级且未使用时应优先选择OAuth账户")
})
require.Equal(t, int64(1), acc.ID)
}
func TestGatewayService_SelectAccountForModelWithPlatform_RoutedFallbackToNormal(t *testing.T) {
ctx := context.Background()
groupID := int64(12)
requestedModel := "claude-3-5-sonnet-20241022"
t.Run("混合调度-包含启用mixed_scheduling的antigravity账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}},
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
......@@ -728,23 +787,48 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
cache := &mockGatewayCacheForPlatform{}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Name: "route-fallback",
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
ModelRoutingEnabled: true,
ModelRouting: map[string][]int64{
requestedModel: {99},
},
},
},
}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
groupRepo: groupRepo,
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "", requestedModel, nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "应选择优先级最高的账户(包含启用混合调度的antigravity)")
})
require.Equal(t, int64(1), acc.ID)
}
func TestGatewayService_SelectAccountForModelWithPlatform_NoModelSupport(t *testing.T) {
ctx := context.Background()
t.Run("混合调度-过滤未启用mixed_scheduling的antigravity账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true}, // 未启用 mixed_scheduling
{
ID: 1,
Platform: PlatformAnthropic,
Priority: 1,
Status: StatusActive,
Schedulable: true,
Credentials: map[string]any{"model_mapping": map[string]any{"claude-3-5-haiku-20241022": "claude-3-5-haiku-20241022"}},
},
},
accountsByID: map[int64]*Account{},
}
......@@ -760,18 +844,19 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID, "未启用mixed_scheduling的antigravity账户应被过滤")
require.Equal(t, PlatformAnthropic, acc.Platform)
})
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
require.Error(t, err)
require.Nil(t, acc)
require.Contains(t, err.Error(), "supporting model")
}
func TestGatewayService_SelectAccountForModelWithPlatform_GeminiPreferOAuth(t *testing.T) {
ctx := context.Background()
t.Run("混合调度-粘性会话命中启用mixed_scheduling的antigravity账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}},
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},
}
......@@ -779,9 +864,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{"session-123": 2},
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
......@@ -789,17 +872,20 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "应返回粘性会话绑定的启用mixed_scheduling的antigravity账户")
})
require.Equal(t, int64(2), acc.ID)
}
func TestGatewayService_SelectAccountForModelWithPlatform_StickyInGroup(t *testing.T) {
ctx := context.Background()
groupID := int64(50)
t.Run("混合调度-粘性会话命中未启用mixed_scheduling的antigravity账户-降级选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true}, // 未启用 mixed_scheduling
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, AccountGroups: []AccountGroup{{GroupID: groupID}}},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, AccountGroups: []AccountGroup{{GroupID: groupID}}},
},
accountsByID: map[int64]*Account{},
}
......@@ -808,7 +894,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{"session-123": 2},
sessionBindings: map[string]int64{"session-group": 1},
}
svc := &GatewayService{
......@@ -817,16 +903,26 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "session-group", "", nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID, "粘性会话绑定的账户未启用mixed_scheduling,应降级选择anthropic账户")
})
require.Equal(t, int64(1), acc.ID)
}
func TestGatewayService_SelectAccountForModelWithPlatform_StickyModelMismatchFallback(t *testing.T) {
ctx := context.Background()
t.Run("混合调度-仅有启用mixed_scheduling的antigravity账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}},
{
ID: 1,
Platform: PlatformAnthropic,
Priority: 1,
Status: StatusActive,
Schedulable: true,
Credentials: map[string]any{"model_mapping": map[string]any{"claude-3-5-haiku-20241022": "claude-3-5-haiku-20241022"}},
},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
......@@ -834,7 +930,9 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{"session-miss": 1},
}
svc := &GatewayService{
accountRepo: repo,
......@@ -842,17 +940,20 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "session-miss", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID)
require.Equal(t, PlatformAntigravity, acc.Platform)
})
require.Equal(t, int64(2), acc.ID)
}
func TestGatewayService_SelectAccountForModelWithPlatform_PreferNeverUsed(t *testing.T) {
ctx := context.Background()
lastUsed := time.Now().Add(-1 * time.Hour)
t.Run("混合调度-无可用账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true}, // 未启用 mixed_scheduling
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: &lastUsed},
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
......@@ -868,171 +969,1505 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
require.Error(t, err)
require.Nil(t, acc)
require.Contains(t, err.Error(), "no available accounts")
})
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
}
// TestAccount_IsMixedSchedulingEnabled 测试混合调度开关检查
func TestAccount_IsMixedSchedulingEnabled(t *testing.T) {
tests := []struct {
name string
account Account
func TestGatewayService_SelectAccountForModelWithPlatform_NoAccounts(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForPlatform{
accounts: []Account{},
accountsByID: map[int64]*Account{},
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformAnthropic)
require.Error(t, err)
require.Nil(t, acc)
require.Contains(t, err.Error(), "no available accounts")
}
func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
svc := &GatewayService{}
tests := []struct {
name string
account *Account
model string
expected bool
}{
{
name: "非antigravity平台-返回false",
account: Account{Platform: PlatformAnthropic},
expected: false,
},
{
name: "antigravity平台-无extra-返回false",
account: Account{Platform: PlatformAntigravity},
expected: false,
name: "Antigravity平台-支持claude模型",
account: &Account{Platform: PlatformAntigravity},
model: "claude-3-5-sonnet-20241022",
expected: true,
},
{
name: "antigravity平台-extra无mixed_scheduling-返回false",
account: Account{Platform: PlatformAntigravity, Extra: map[string]any{}},
expected: false,
name: "Antigravity平台-支持gemini模型",
account: &Account{Platform: PlatformAntigravity},
model: "gemini-2.5-flash",
expected: true,
},
{
name: "antigravity平台-mixed_scheduling=false-返回false",
account: Account{Platform: PlatformAntigravity, Extra: map[string]any{"mixed_scheduling": false}},
name: "Antigravity平台-不支持gpt模型",
account: &Account{Platform: PlatformAntigravity},
model: "gpt-4",
expected: false,
},
{
name: "antigravity平台-mixed_scheduling=true-返回true",
account: Account{Platform: PlatformAntigravity, Extra: map[string]any{"mixed_scheduling": true}},
name: "Anthropic平台-无映射配置-支持所有模型",
account: &Account{Platform: PlatformAnthropic},
model: "claude-3-5-sonnet-20241022",
expected: true,
},
{
name: "antigravity平台-mixed_scheduling非bool类型-返回false",
account: Account{Platform: PlatformAntigravity, Extra: map[string]any{"mixed_scheduling": "true"}},
name: "Anthropic平台-有映射配置-只支持配置的模型",
account: &Account{
Platform: PlatformAnthropic,
Credentials: map[string]any{"model_mapping": map[string]any{"claude-opus-4": "x"}},
},
model: "claude-3-5-sonnet-20241022",
expected: false,
},
{
name: "Anthropic平台-有映射配置-支持配置的模型",
account: &Account{
Platform: PlatformAnthropic,
Credentials: map[string]any{"model_mapping": map[string]any{"claude-3-5-sonnet-20241022": "x"}},
},
model: "claude-3-5-sonnet-20241022",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.account.IsMixedSchedulingEnabled()
got := svc.isModelSupportedByAccount(tt.account, tt.model)
require.Equal(t, tt.expected, got)
})
}
}
// mockConcurrencyService for testing
type mockConcurrencyService struct {
accountLoads map[int64]*AccountLoadInfo
accountWaitCounts map[int64]int
acquireResults map[int64]bool
}
// TestGatewayService_selectAccountWithMixedScheduling 测试混合调度
func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
ctx := context.Background()
func (m *mockConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
if m.accountLoads == nil {
return map[int64]*AccountLoadInfo{}, nil
}
result := make(map[int64]*AccountLoadInfo)
for _, acc := range accounts {
if load, ok := m.accountLoads[acc.ID]; ok {
result[acc.ID] = load
} else {
result[acc.ID] = &AccountLoadInfo{
AccountID: acc.ID,
CurrentConcurrency: 0,
WaitingCount: 0,
LoadRate: 0,
}
t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
return result, nil
}
func (m *mockConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
if m.accountWaitCounts == nil {
return 0, nil
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
return m.accountWaitCounts[accountID], nil
}
type mockConcurrencyCache struct {
acquireAccountCalls int
loadBatchCalls int
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "同优先级且未使用时应优先选择OAuth账户")
})
func (m *mockConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
m.acquireAccountCalls++
return true, nil
}
t.Run("混合调度-包含启用mixed_scheduling的antigravity账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
func (m *mockConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
return nil
}
cache := &mockGatewayCacheForPlatform{}
func (m *mockConcurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
return 0, nil
}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
func (m *mockConcurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
return true, nil
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "应选择优先级最高的账户(包含启用混合调度的antigravity)")
})
func (m *mockConcurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
return nil
}
t.Run("混合调度-路由优先选择路由账号", func(t *testing.T) {
groupID := int64(30)
requestedModel := "claude-3-5-sonnet-20241022"
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
func (m *mockConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
return 0, nil
}
cache := &mockGatewayCacheForPlatform{}
func (m *mockConcurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
return true, nil
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Name: "route-mixed-select",
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
ModelRoutingEnabled: true,
ModelRouting: map[string][]int64{
requestedModel: {2},
},
},
},
}
func (m *mockConcurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
return nil
}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
groupRepo: groupRepo,
}
func (m *mockConcurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
return 0, nil
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, &groupID, "", requestedModel, nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
})
func (m *mockConcurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
return true, nil
}
t.Run("混合调度-路由粘性命中", func(t *testing.T) {
groupID := int64(31)
requestedModel := "claude-3-5-sonnet-20241022"
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}, AccountGroups: []AccountGroup{{GroupID: groupID}}},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
func (m *mockConcurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
return nil
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{"session-777": 2},
}
func (m *mockConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
m.loadBatchCalls++
result := make(map[int64]*AccountLoadInfo, len(accounts))
for _, acc := range accounts {
result[acc.ID] = &AccountLoadInfo{
AccountID: acc.ID,
CurrentConcurrency: 0,
WaitingCount: 0,
LoadRate: 0,
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Name: "route-mixed-sticky",
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
ModelRoutingEnabled: true,
ModelRouting: map[string][]int64{
requestedModel: {2},
},
},
},
}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
groupRepo: groupRepo,
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, &groupID, "session-777", requestedModel, nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
})
t.Run("混合调度-路由账号缺失回退", func(t *testing.T) {
groupID := int64(32)
requestedModel := "claude-3-5-sonnet-20241022"
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Name: "route-mixed-miss",
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
ModelRoutingEnabled: true,
ModelRouting: map[string][]int64{
requestedModel: {99},
},
},
},
}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
groupRepo: groupRepo,
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, &groupID, "", requestedModel, nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID)
})
t.Run("混合调度-路由账号未启用mixed_scheduling回退", func(t *testing.T) {
groupID := int64(33)
requestedModel := "claude-3-5-sonnet-20241022"
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true}, // 未启用 mixed_scheduling
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Name: "route-mixed-disabled",
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
ModelRoutingEnabled: true,
ModelRouting: map[string][]int64{
requestedModel: {2},
},
},
},
}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
groupRepo: groupRepo,
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, &groupID, "", requestedModel, nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID)
})
t.Run("混合调度-路由过滤覆盖", func(t *testing.T) {
groupID := int64(35)
requestedModel := "claude-3-5-sonnet-20241022"
resetAt := time.Now().Add(10 * time.Minute)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: false},
{ID: 3, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
{
ID: 4,
Platform: PlatformAnthropic,
Priority: 1,
Status: StatusActive,
Schedulable: true,
Extra: map[string]any{
"model_rate_limits": map[string]any{
"claude_sonnet": map[string]any{
"rate_limit_reset_at": resetAt.Format(time.RFC3339),
},
},
},
},
{
ID: 5,
Platform: PlatformAnthropic,
Priority: 1,
Status: StatusActive,
Schedulable: true,
Credentials: map[string]any{"model_mapping": map[string]any{"claude-3-5-haiku-20241022": "claude-3-5-haiku-20241022"}},
},
{ID: 6, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
{ID: 7, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Name: "route-mixed-filter",
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
ModelRoutingEnabled: true,
ModelRouting: map[string][]int64{
requestedModel: {1, 2, 3, 4, 5, 6, 7},
},
},
},
}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
groupRepo: groupRepo,
}
excluded := map[int64]struct{}{1: {}}
acc, err := svc.selectAccountWithMixedScheduling(ctx, &groupID, "", requestedModel, excluded, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(7), acc.ID)
})
t.Run("混合调度-粘性命中分组账号", func(t *testing.T) {
groupID := int64(34)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, AccountGroups: []AccountGroup{{GroupID: groupID}}},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, AccountGroups: []AccountGroup{{GroupID: groupID}}},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{"session-group": 1},
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
},
},
}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
groupRepo: groupRepo,
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, &groupID, "session-group", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID)
})
t.Run("混合调度-过滤未启用mixed_scheduling的antigravity账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true}, // 未启用 mixed_scheduling
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID, "未启用mixed_scheduling的antigravity账户应被过滤")
require.Equal(t, PlatformAnthropic, acc.Platform)
})
t.Run("混合调度-粘性会话命中启用mixed_scheduling的antigravity账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{"session-123": 2},
}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "应返回粘性会话绑定的启用mixed_scheduling的antigravity账户")
})
t.Run("混合调度-粘性会话命中未启用mixed_scheduling的antigravity账户-降级选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true}, // 未启用 mixed_scheduling
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{"session-123": 2},
}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID, "粘性会话绑定的账户未启用mixed_scheduling,应降级选择anthropic账户")
})
t.Run("混合调度-粘性会话不可调度-清理并回退", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusDisabled, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{"session-123": 1},
}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
require.Equal(t, 1, cache.deletedSessions["session-123"])
require.Equal(t, int64(2), cache.sessionBindings["session-123"])
})
t.Run("混合调度-路由粘性不可调度-清理并回退", func(t *testing.T) {
groupID := int64(12)
requestedModel := "claude-3-5-sonnet-20241022"
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusDisabled, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{"session-123": 1},
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Name: "route-mixed",
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
ModelRoutingEnabled: true,
ModelRouting: map[string][]int64{
requestedModel: {1, 2},
},
},
},
}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
groupRepo: groupRepo,
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, &groupID, "session-123", requestedModel, nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
require.Equal(t, 1, cache.deletedSessions["session-123"])
require.Equal(t, int64(2), cache.sessionBindings["session-123"])
})
t.Run("混合调度-仅有启用mixed_scheduling的antigravity账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID)
require.Equal(t, PlatformAntigravity, acc.Platform)
})
t.Run("混合调度-无可用账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true}, // 未启用 mixed_scheduling
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
require.Error(t, err)
require.Nil(t, acc)
require.Contains(t, err.Error(), "no available accounts")
})
t.Run("混合调度-不支持模型返回错误", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{
ID: 1,
Platform: PlatformAnthropic,
Priority: 1,
Status: StatusActive,
Schedulable: true,
Credentials: map[string]any{"model_mapping": map[string]any{"claude-3-5-haiku-20241022": "claude-3-5-haiku-20241022"}},
},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
require.Error(t, err)
require.Nil(t, acc)
require.Contains(t, err.Error(), "supporting model")
})
t.Run("混合调度-优先未使用账号", func(t *testing.T) {
lastUsed := time.Now().Add(-2 * time.Hour)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: &lastUsed},
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
})
}
// TestAccount_IsMixedSchedulingEnabled 测试混合调度开关检查
func TestAccount_IsMixedSchedulingEnabled(t *testing.T) {
tests := []struct {
name string
account Account
expected bool
}{
{
name: "非antigravity平台-返回false",
account: Account{Platform: PlatformAnthropic},
expected: false,
},
{
name: "antigravity平台-无extra-返回false",
account: Account{Platform: PlatformAntigravity},
expected: false,
},
{
name: "antigravity平台-extra无mixed_scheduling-返回false",
account: Account{Platform: PlatformAntigravity, Extra: map[string]any{}},
expected: false,
},
{
name: "antigravity平台-mixed_scheduling=false-返回false",
account: Account{Platform: PlatformAntigravity, Extra: map[string]any{"mixed_scheduling": false}},
expected: false,
},
{
name: "antigravity平台-mixed_scheduling=true-返回true",
account: Account{Platform: PlatformAntigravity, Extra: map[string]any{"mixed_scheduling": true}},
expected: true,
},
{
name: "antigravity平台-mixed_scheduling非bool类型-返回false",
account: Account{Platform: PlatformAntigravity, Extra: map[string]any{"mixed_scheduling": "true"}},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.account.IsMixedSchedulingEnabled()
require.Equal(t, tt.expected, got)
})
}
}
// mockConcurrencyService for testing
type mockConcurrencyService struct {
accountLoads map[int64]*AccountLoadInfo
accountWaitCounts map[int64]int
acquireResults map[int64]bool
}
func (m *mockConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
if m.accountLoads == nil {
return map[int64]*AccountLoadInfo{}, nil
}
result := make(map[int64]*AccountLoadInfo)
for _, acc := range accounts {
if load, ok := m.accountLoads[acc.ID]; ok {
result[acc.ID] = load
} else {
result[acc.ID] = &AccountLoadInfo{
AccountID: acc.ID,
CurrentConcurrency: 0,
WaitingCount: 0,
LoadRate: 0,
}
}
}
return result, nil
}
func (m *mockConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
if m.accountWaitCounts == nil {
return 0, nil
}
return m.accountWaitCounts[accountID], nil
}
type mockConcurrencyCache struct {
acquireAccountCalls int
loadBatchCalls int
acquireResults map[int64]bool
loadBatchErr error
loadMap map[int64]*AccountLoadInfo
waitCounts map[int64]int
skipDefaultLoad bool
}
func (m *mockConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
m.acquireAccountCalls++
if m.acquireResults != nil {
if result, ok := m.acquireResults[accountID]; ok {
return result, nil
}
}
return true, nil
}
func (m *mockConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
return nil
}
func (m *mockConcurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
return 0, nil
}
func (m *mockConcurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
return true, nil
}
func (m *mockConcurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
return nil
}
func (m *mockConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
if m.waitCounts != nil {
if count, ok := m.waitCounts[accountID]; ok {
return count, nil
}
}
return 0, nil
}
func (m *mockConcurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
return true, nil
}
func (m *mockConcurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
return nil
}
func (m *mockConcurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
return 0, nil
}
func (m *mockConcurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
return true, nil
}
func (m *mockConcurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
return nil
}
func (m *mockConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
m.loadBatchCalls++
if m.loadBatchErr != nil {
return nil, m.loadBatchErr
}
result := make(map[int64]*AccountLoadInfo, len(accounts))
if m.skipDefaultLoad && m.loadMap != nil {
for _, acc := range accounts {
if load, ok := m.loadMap[acc.ID]; ok {
result[acc.ID] = load
}
}
return result, nil
}
for _, acc := range accounts {
if m.loadMap != nil {
if load, ok := m.loadMap[acc.ID]; ok {
result[acc.ID] = load
continue
}
}
result[acc.ID] = &AccountLoadInfo{
AccountID: acc.ID,
CurrentConcurrency: 0,
WaitingCount: 0,
LoadRate: 0,
}
}
return result, nil
}
func (m *mockConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
return nil
}
// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection
func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
ctx := context.Background()
t.Run("禁用负载批量查询-降级到传统选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil, // No concurrency service
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(1), result.Account.ID, "应选择优先级最高的账号")
})
t.Run("模型路由-无ConcurrencyService也生效", func(t *testing.T) {
groupID := int64(1)
sessionHash := "sticky"
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, AccountGroups: []AccountGroup{{GroupID: groupID}}},
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, AccountGroups: []AccountGroup{{GroupID: groupID}}},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{sessionHash: 1},
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
ModelRoutingEnabled: true,
ModelRouting: map[string][]int64{
"claude-a": {1},
"claude-b": {2},
},
},
},
}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
svc := &GatewayService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
cfg: cfg,
concurrencyService: nil, // legacy path
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "切换到 claude-b 时应按模型路由切换账号")
require.Equal(t, int64(2), cache.sessionBindings[sessionHash], "粘性绑定应更新为路由选择的账号")
})
t.Run("无ConcurrencyService-降级到传统选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "应选择优先级最高的账号")
})
t.Run("排除账号-不选择被排除的账号", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
excludedIDs := map[int64]struct{}{1: {}}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "不应选择被排除的账号")
})
t.Run("粘性命中-不调用GetByID", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{"sticky": 1},
}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
concurrencyCache := &mockConcurrencyCache{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(1), result.Account.ID)
require.Equal(t, 0, repo.getByIDCalls, "粘性命中不应调用GetByID")
require.Equal(t, 0, concurrencyCache.loadBatchCalls, "粘性命中应在负载批量查询前返回")
})
t.Run("粘性账号不在候选集-回退负载感知选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{"sticky": 1},
}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
concurrencyCache := &mockConcurrencyCache{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "粘性账号不在候选集时应回退到可用账号")
require.Equal(t, 0, repo.getByIDCalls, "粘性账号缺失不应回退到GetByID")
require.Equal(t, 1, concurrencyCache.loadBatchCalls, "应继续进行负载批量查询")
})
t.Run("粘性账号禁用-清理会话并回退选择", func(t *testing.T) {
testCtx := context.WithValue(ctx, ctxkey.ForcePlatform, PlatformAnthropic)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: false, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
repo.listPlatformFunc = func(ctx context.Context, platform string) ([]Account, error) {
return repo.accounts, nil
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{"sticky": 1},
}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
concurrencyCache := &mockConcurrencyCache{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(testCtx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "粘性账号禁用时应回退到可用账号")
updatedID, ok := cache.sessionBindings["sticky"]
require.True(t, ok, "粘性会话应更新绑定")
require.Equal(t, int64(2), updatedID, "粘性会话应绑定到新账号")
})
t.Run("无可用账号-返回错误", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{},
accountsByID: map[int64]*Account{},
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.Error(t, err)
require.Nil(t, result)
require.Contains(t, err.Error(), "no available accounts")
})
t.Run("过滤不可调度账号-限流账号被跳过", func(t *testing.T) {
now := time.Now()
resetAt := now.Add(10 * time.Minute)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, RateLimitResetAt: &resetAt},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "应跳过限流账号,选择可用账号")
})
t.Run("过滤不可调度账号-过载账号被跳过", func(t *testing.T) {
now := time.Now()
overloadUntil := now.Add(10 * time.Minute)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, OverloadUntil: &overloadUntil},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "应跳过过载账号,选择可用账号")
})
t.Run("粘性账号槽位满-返回粘性等待计划", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{"sticky": 1},
}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
cfg.Gateway.Scheduling.StickySessionMaxWaiting = 1
concurrencyCache := &mockConcurrencyCache{
acquireResults: map[int64]bool{1: false},
waitCounts: map[int64]int{1: 0},
}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.WaitPlan)
require.Equal(t, int64(1), result.Account.ID)
require.Equal(t, 0, concurrencyCache.loadBatchCalls)
})
t.Run("负载批量查询失败-降级旧顺序选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
concurrencyCache := &mockConcurrencyCache{
loadBatchErr: errors.New("load batch failed"),
}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "legacy", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID)
require.Equal(t, int64(2), cache.sessionBindings["legacy"])
})
t.Run("模型路由-粘性账号等待计划", func(t *testing.T) {
groupID := int64(20)
sessionHash := "route-sticky"
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{sessionHash: 1},
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
ModelRoutingEnabled: true,
ModelRouting: map[string][]int64{
"claude-3-5-sonnet-20241022": {1, 2},
},
},
},
}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
cfg.Gateway.Scheduling.StickySessionMaxWaiting = 1
concurrencyCache := &mockConcurrencyCache{
acquireResults: map[int64]bool{1: false},
waitCounts: map[int64]int{1: 0},
}
svc := &GatewayService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.WaitPlan)
require.Equal(t, int64(1), result.Account.ID)
})
t.Run("模型路由-粘性账号命中", func(t *testing.T) {
groupID := int64(20)
sessionHash := "route-hit"
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{sessionHash: 1},
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
ModelRoutingEnabled: true,
ModelRouting: map[string][]int64{
"claude-3-5-sonnet-20241022": {1, 2},
},
},
},
}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
concurrencyCache := &mockConcurrencyCache{}
svc := &GatewayService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(1), result.Account.ID)
require.Equal(t, 0, concurrencyCache.loadBatchCalls)
})
t.Run("模型路由-粘性账号缺失-清理并回退", func(t *testing.T) {
groupID := int64(22)
sessionHash := "route-missing"
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{sessionHash: 1},
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
ModelRoutingEnabled: true,
ModelRouting: map[string][]int64{
"claude-3-5-sonnet-20241022": {1, 2},
},
},
},
}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
concurrencyCache := &mockConcurrencyCache{}
svc := &GatewayService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
return result, nil
}
func (m *mockConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
return nil
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID)
require.Equal(t, 1, cache.deletedSessions[sessionHash])
require.Equal(t, int64(2), cache.sessionBindings[sessionHash])
})
// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection
func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
ctx := context.Background()
t.Run("模型路由-按负载选择账号", func(t *testing.T) {
groupID := int64(21)
t.Run("禁用负载批量查询-降级到传统选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
......@@ -1042,31 +2477,54 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
cache := &mockGatewayCacheForPlatform{}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
ModelRoutingEnabled: true,
ModelRouting: map[string][]int64{
"claude-3-5-sonnet-20241022": {1, 2},
},
},
},
}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
cfg.Gateway.Scheduling.LoadBatchEnabled = true
concurrencyCache := &mockConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
1: {AccountID: 1, LoadRate: 80},
2: {AccountID: 2, LoadRate: 20},
},
}
svc := &GatewayService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
cfg: cfg,
concurrencyService: nil, // No concurrency service
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "route", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(1), result.Account.ID, "应选择优先级最高的账号")
require.Equal(t, int64(2), result.Account.ID)
require.Equal(t, int64(2), cache.sessionBindings["route"])
})
t.Run("模型路由-无ConcurrencyService也生效", func(t *testing.T) {
groupID := int64(1)
sessionHash := "sticky"
t.Run("模型路由-路由账号全满返回等待计划", func(t *testing.T) {
groupID := int64(23)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, AccountGroups: []AccountGroup{{GroupID: groupID}}},
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, AccountGroups: []AccountGroup{{GroupID: groupID}}},
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
......@@ -1074,9 +2532,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{sessionHash: 1},
}
cache := &mockGatewayCacheForPlatform{}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
......@@ -1087,8 +2543,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
Hydrated: true,
ModelRoutingEnabled: true,
ModelRouting: map[string][]int64{
"claude-a": {1},
"claude-b": {2},
"claude-3-5-sonnet-20241022": {1, 2},
},
},
},
......@@ -1097,27 +2552,37 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
concurrencyCache := &mockConcurrencyCache{
acquireResults: map[int64]bool{1: false, 2: false},
loadMap: map[int64]*AccountLoadInfo{
1: {AccountID: 1, LoadRate: 10},
2: {AccountID: 2, LoadRate: 20},
},
}
svc := &GatewayService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
cfg: cfg,
concurrencyService: nil, // legacy path
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "route-full", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "切换到 claude-b 时应按模型路由切换账号")
require.Equal(t, int64(2), cache.sessionBindings[sessionHash], "粘性绑定应更新为路由选择的账号")
require.NotNil(t, result.WaitPlan)
require.Equal(t, int64(1), result.Account.ID)
})
t.Run("无ConcurrencyService-降级到传统选择", func(t *testing.T) {
t.Run("模型路由-路由账号全满-回退普通选择", func(t *testing.T) {
groupID := int64(22)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 3, Platform: PlatformAnthropic, Priority: 0, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
......@@ -1127,24 +2592,49 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
cache := &mockGatewayCacheForPlatform{}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
ModelRoutingEnabled: true,
ModelRouting: map[string][]int64{
"claude-3-5-sonnet-20241022": {1, 2},
},
},
},
}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
concurrencyCache := &mockConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
1: {AccountID: 1, LoadRate: 100},
2: {AccountID: 2, LoadRate: 100},
3: {AccountID: 3, LoadRate: 0},
},
}
svc := &GatewayService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "fallback", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "应选择优先级最高的账号")
require.Equal(t, int64(3), result.Account.ID)
require.Equal(t, int64(3), cache.sessionBindings["fallback"])
})
t.Run("排除账号-不选择被排除的账号", func(t *testing.T) {
t.Run("负载批量失败且无法获取-兜底等待", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
......@@ -1159,27 +2649,34 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
cfg.Gateway.Scheduling.LoadBatchEnabled = true
concurrencyCache := &mockConcurrencyCache{
loadBatchErr: errors.New("load batch failed"),
acquireResults: map[int64]bool{1: false, 2: false},
}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
excludedIDs := map[int64]struct{}{1: {}}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "不应选择被排除的账号")
require.NotNil(t, result.WaitPlan)
require.Equal(t, int64(1), result.Account.ID)
})
t.Run("粘性命中-不调用GetByID", func(t *testing.T) {
t.Run("Gemini负载排序-优先OAuth", func(t *testing.T) {
groupID := int64(24)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, Type: AccountTypeAPIKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},
}
......@@ -1187,35 +2684,77 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{"sticky": 1},
cache := &mockGatewayCacheForPlatform{}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Platform: PlatformGemini,
Status: StatusActive,
Hydrated: true,
},
},
}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
concurrencyCache := &mockConcurrencyCache{}
concurrencyCache := &mockConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
1: {AccountID: 1, LoadRate: 10},
2: {AccountID: 2, LoadRate: 10},
},
}
svc := &GatewayService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "gemini", "gemini-2.5-pro", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(1), result.Account.ID)
require.Equal(t, 0, repo.getByIDCalls, "粘性命中不应调用GetByID")
require.Equal(t, 0, concurrencyCache.loadBatchCalls, "粘性命中应在负载批量查询前返回")
require.Equal(t, int64(2), result.Account.ID)
})
t.Run("粘性账号不在候选集-回退负载感知选择", func(t *testing.T) {
t.Run("模型路由-过滤路径覆盖", func(t *testing.T) {
groupID := int64(70)
now := time.Now().Add(10 * time.Minute)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 3, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: false, Concurrency: 5},
{ID: 4, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
{
ID: 5,
Platform: PlatformAnthropic,
Priority: 1,
Status: StatusActive,
Schedulable: true,
Concurrency: 5,
Extra: map[string]any{
"model_rate_limits": map[string]any{
"claude_sonnet": map[string]any{
"rate_limit_reset_at": now.Format(time.RFC3339),
},
},
},
},
{
ID: 6,
Platform: PlatformAnthropic,
Priority: 1,
Status: StatusActive,
Schedulable: true,
Concurrency: 5,
Credentials: map[string]any{"model_mapping": map[string]any{"claude-3-5-haiku-20241022": "claude-3-5-haiku-20241022"}},
},
{ID: 7, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
......@@ -1223,8 +2762,21 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{"sticky": 1},
cache := &mockGatewayCacheForPlatform{}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
ModelRoutingEnabled: true,
ModelRouting: map[string][]int64{
"claude-3-5-sonnet-20241022": {1, 2, 3, 4, 5, 6},
},
},
},
}
cfg := testConfig()
......@@ -1234,51 +2786,110 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
svc := &GatewayService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
excluded := map[int64]struct{}{1: {}}
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-3-5-sonnet-20241022", excluded, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "粘性账号不在候选集时应回退到可用账号")
require.Equal(t, 0, repo.getByIDCalls, "粘性账号缺失不应回退到GetByID")
require.Equal(t, 1, concurrencyCache.loadBatchCalls, "应继续进行负载批量查询")
require.Equal(t, int64(7), result.Account.ID)
})
t.Run("无可用账号-返回错误", func(t *testing.T) {
t.Run("ClaudeCode限制-回退分组", func(t *testing.T) {
groupID := int64(60)
fallbackID := int64(61)
repo := &mockAccountRepoForPlatform{
accounts: []Account{},
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
ClaudeCodeOnly: true,
FallbackGroupID: func() *int64 {
v := fallbackID
return &v
}(),
},
fallbackID: {
ID: fallbackID,
Platform: PlatformGemini,
Status: StatusActive,
Hydrated: true,
},
},
}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
groupRepo: groupRepo,
cache: &mockGatewayCacheForPlatform{},
cfg: cfg,
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "gemini-2.5-pro", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(1), result.Account.ID)
})
t.Run("ClaudeCode限制-无降级返回错误", func(t *testing.T) {
groupID := int64(62)
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
ClaudeCodeOnly: true,
},
},
}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: &mockAccountRepoForPlatform{},
groupRepo: groupRepo,
cache: &mockGatewayCacheForPlatform{},
cfg: cfg,
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil, "")
require.Error(t, err)
require.Nil(t, result)
require.Contains(t, err.Error(), "no available accounts")
require.ErrorIs(t, err, ErrClaudeCodeOnly)
})
t.Run("过滤不可调度账号-限流账号被跳过", func(t *testing.T) {
now := time.Now()
resetAt := now.Add(10 * time.Minute)
t.Run("负载可用但无法获取槽位-兜底等待", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, RateLimitResetAt: &resetAt},
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
......@@ -1288,31 +2899,37 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
cfg.Gateway.Scheduling.LoadBatchEnabled = true
concurrencyCache := &mockConcurrencyCache{
acquireResults: map[int64]bool{1: false, 2: false},
loadMap: map[int64]*AccountLoadInfo{
1: {AccountID: 1, LoadRate: 10},
2: {AccountID: 2, LoadRate: 20},
},
}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "wait", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "应跳过限流账号,选择可用账号")
require.NotNil(t, result.WaitPlan)
require.Equal(t, int64(1), result.Account.ID)
})
t.Run("过滤不可调度账号-过载账号被跳过", func(t *testing.T) {
now := time.Now()
overloadUntil := now.Add(10 * time.Minute)
t.Run("负载信息缺失-使用默认负载", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, OverloadUntil: &overloadUntil},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
......@@ -1321,21 +2938,29 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
cfg.Gateway.Scheduling.LoadBatchEnabled = true
concurrencyCache := &mockConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
1: {AccountID: 1, LoadRate: 50},
},
skipDefaultLoad: true,
}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
concurrencyService: NewConcurrencyService(concurrencyCache),
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "missing-load", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "应跳过过载账号,选择可用账号")
require.Equal(t, int64(2), result.Account.ID)
})
}
......
package service
import (
"regexp"
"testing"
"github.com/stretchr/testify/require"
)
func TestBuildOAuthMetadataUserID_FallbackWithoutAccountUUID(t *testing.T) {
svc := &GatewayService{}
parsed := &ParsedRequest{
Model: "claude-sonnet-4-5",
Stream: true,
MetadataUserID: "",
System: nil,
Messages: nil,
}
account := &Account{
ID: 123,
Type: AccountTypeOAuth,
Extra: map[string]any{}, // intentionally missing account_uuid / claude_user_id
}
fp := &Fingerprint{ClientID: "deadbeef"} // should be used as user id in legacy format
got := svc.buildOAuthMetadataUserID(parsed, account, fp)
require.NotEmpty(t, got)
// Legacy format: user_{client}_account__session_{uuid}
re := regexp.MustCompile(`^user_[a-zA-Z0-9]+_account__session_[a-f0-9-]{36}$`)
require.True(t, re.MatchString(got), "unexpected user_id format: %s", got)
}
func TestBuildOAuthMetadataUserID_UsesAccountUUIDWhenPresent(t *testing.T) {
svc := &GatewayService{}
parsed := &ParsedRequest{
Model: "claude-sonnet-4-5",
Stream: true,
MetadataUserID: "",
}
account := &Account{
ID: 123,
Type: AccountTypeOAuth,
Extra: map[string]any{
"account_uuid": "acc-uuid",
"claude_user_id": "clientid123",
"anthropic_user_id": "",
},
}
got := svc.buildOAuthMetadataUserID(parsed, account, nil)
require.NotEmpty(t, got)
// New format: user_{client}_account_{account_uuid}_session_{uuid}
re := regexp.MustCompile(`^user_clientid123_account_acc-uuid_session_[a-f0-9-]{36}$`)
require.True(t, re.MatchString(got), "unexpected user_id format: %s", got)
}
......@@ -2,6 +2,7 @@ package service
import (
"encoding/json"
"strings"
"testing"
"github.com/stretchr/testify/require"
......@@ -134,6 +135,8 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) {
}
func TestInjectClaudeCodePrompt(t *testing.T) {
claudePrefix := strings.TrimSpace(claudeCodeSystemPrompt)
tests := []struct {
name string
body string
......@@ -162,7 +165,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
system: "Custom prompt",
wantSystemLen: 2,
wantFirstText: claudeCodeSystemPrompt,
wantSecondText: "Custom prompt",
wantSecondText: claudePrefix + "\n\nCustom prompt",
},
{
name: "string system equals Claude Code prompt",
......@@ -178,7 +181,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
// Claude Code + Custom = 2
wantSystemLen: 2,
wantFirstText: claudeCodeSystemPrompt,
wantSecondText: "Custom",
wantSecondText: claudePrefix + "\n\nCustom",
},
{
name: "array system with existing Claude Code prompt (should dedupe)",
......@@ -190,7 +193,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
// Claude Code at start + Other = 2 (deduped)
wantSystemLen: 2,
wantFirstText: claudeCodeSystemPrompt,
wantSecondText: "Other",
wantSecondText: claudePrefix + "\n\nOther",
},
{
name: "empty array",
......
package service
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestSanitizeOpenCodeText_RewritesCanonicalSentence(t *testing.T) {
in := "You are OpenCode, the best coding agent on the planet."
got := sanitizeSystemText(in)
require.Equal(t, strings.TrimSpace(claudeCodeSystemPrompt), got)
}
func TestSanitizeToolDescription_DoesNotRewriteKeywords(t *testing.T) {
in := "OpenCode and opencode are mentioned."
got := sanitizeToolDescription(in)
// We no longer rewrite tool descriptions; only redact obvious path leaks.
require.Equal(t, in, got)
}
......@@ -11,6 +11,8 @@ import (
"fmt"
"io"
"log"
"log/slog"
mathrand "math/rand"
"net/http"
"os"
"regexp"
......@@ -37,15 +39,27 @@ const (
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL = time.Hour // 粘性会话TTL
defaultMaxLineSize = 40 * 1024 * 1024
// Canonical Claude Code banner. Keep it EXACT (no trailing whitespace/newlines)
// to match real Claude CLI traffic as closely as possible. When we need a visual
// separator between system blocks, we add "\n\n" at concatenation time.
claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量
)
const (
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
)
func (s *GatewayService) debugModelRoutingEnabled() bool {
v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING")))
return v == "1" || v == "true" || v == "yes" || v == "on"
}
func (s *GatewayService) debugClaudeMimicEnabled() bool {
v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC")))
return v == "1" || v == "true" || v == "yes" || v == "on"
}
func shortSessionHash(sessionHash string) string {
if sessionHash == "" {
return ""
......@@ -56,6 +70,138 @@ func shortSessionHash(sessionHash string) string {
return sessionHash[:8]
}
func redactAuthHeaderValue(v string) string {
v = strings.TrimSpace(v)
if v == "" {
return ""
}
// Keep scheme for debugging, redact secret.
if strings.HasPrefix(strings.ToLower(v), "bearer ") {
return "Bearer [redacted]"
}
return "[redacted]"
}
func safeHeaderValueForLog(key string, v string) string {
key = strings.ToLower(strings.TrimSpace(key))
switch key {
case "authorization", "x-api-key":
return redactAuthHeaderValue(v)
default:
return strings.TrimSpace(v)
}
}
func extractSystemPreviewFromBody(body []byte) string {
if len(body) == 0 {
return ""
}
sys := gjson.GetBytes(body, "system")
if !sys.Exists() {
return ""
}
switch {
case sys.IsArray():
for _, item := range sys.Array() {
if !item.IsObject() {
continue
}
if strings.EqualFold(item.Get("type").String(), "text") {
if t := item.Get("text").String(); strings.TrimSpace(t) != "" {
return t
}
}
}
return ""
case sys.Type == gjson.String:
return sys.String()
default:
return ""
}
}
func buildClaudeMimicDebugLine(req *http.Request, body []byte, account *Account, tokenType string, mimicClaudeCode bool) string {
if req == nil {
return ""
}
// Only log a minimal fingerprint to avoid leaking user content.
interesting := []string{
"user-agent",
"x-app",
"anthropic-dangerous-direct-browser-access",
"anthropic-version",
"anthropic-beta",
"x-stainless-lang",
"x-stainless-package-version",
"x-stainless-os",
"x-stainless-arch",
"x-stainless-runtime",
"x-stainless-runtime-version",
"x-stainless-retry-count",
"x-stainless-timeout",
"authorization",
"x-api-key",
"content-type",
"accept",
"x-stainless-helper-method",
}
h := make([]string, 0, len(interesting))
for _, k := range interesting {
if v := req.Header.Get(k); v != "" {
h = append(h, fmt.Sprintf("%s=%q", k, safeHeaderValueForLog(k, v)))
}
}
metaUserID := strings.TrimSpace(gjson.GetBytes(body, "metadata.user_id").String())
sysPreview := strings.TrimSpace(extractSystemPreviewFromBody(body))
// Truncate preview to keep logs sane.
if len(sysPreview) > 300 {
sysPreview = sysPreview[:300] + "..."
}
sysPreview = strings.ReplaceAll(sysPreview, "\n", "\\n")
sysPreview = strings.ReplaceAll(sysPreview, "\r", "\\r")
aid := int64(0)
aname := ""
if account != nil {
aid = account.ID
aname = account.Name
}
return fmt.Sprintf(
"url=%s account=%d(%s) tokenType=%s mimic=%t meta.user_id=%q system.preview=%q headers={%s}",
req.URL.String(),
aid,
aname,
tokenType,
mimicClaudeCode,
metaUserID,
sysPreview,
strings.Join(h, " "),
)
}
func logClaudeMimicDebug(req *http.Request, body []byte, account *Account, tokenType string, mimicClaudeCode bool) {
line := buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode)
if line == "" {
return
}
log.Printf("[ClaudeMimicDebug] %s", line)
}
func isClaudeCodeCredentialScopeError(msg string) bool {
m := strings.ToLower(strings.TrimSpace(msg))
if m == "" {
return false
}
return strings.Contains(m, "only authorized for use with claude code") &&
strings.Contains(m, "cannot be used for other api requests")
}
// sseDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var (
......@@ -69,7 +215,6 @@ var (
modelFieldRe = regexp.MustCompile(`"model"\s*:\s*"([^"]+)"`)
toolDescAbsPathRe = regexp.MustCompile(`/\/?(?:home|Users|tmp|var|opt|usr|etc)\/[^\s,\)"'\]]+`)
toolDescWinPathRe = regexp.MustCompile(`(?i)[A-Z]:\\[^\s,\)"'\]]+`)
opencodeTextRe = regexp.MustCompile(`(?i)opencode`)
claudeToolNameOverrides = map[string]string{
"bash": "Bash",
......@@ -134,11 +279,24 @@ var allowedHeaders = map[string]bool{
"content-type": true,
}
// GatewayCache defines cache operations for gateway service
// GatewayCache 定义网关服务的缓存操作接口。
// 提供粘性会话(Sticky Session)的存储、查询、刷新和删除功能。
//
// GatewayCache defines cache operations for gateway service.
// Provides sticky session storage, retrieval, refresh and deletion capabilities.
type GatewayCache interface {
// GetSessionAccountID 获取粘性会话绑定的账号 ID
// Get the account ID bound to a sticky session
GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error)
// SetSessionAccountID 设置粘性会话与账号的绑定关系
// Set the binding between sticky session and account
SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error
// RefreshSessionTTL 刷新粘性会话的过期时间
// Refresh the expiration time of a sticky session
RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
// Delete sticky session binding, used to proactively clean up when account becomes unavailable
DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
}
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
......@@ -149,6 +307,28 @@ func derefGroupID(groupID *int64) int64 {
return *groupID
}
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
// 当账号状态为错误、禁用、不可调度,或处于临时不可调度期间时,返回 true。
// 这确保后续请求不会继续使用不可用的账号。
//
// shouldClearStickySession checks if an account is in an unschedulable state
// and the sticky session binding should be cleared.
// Returns true when account status is error/disabled, schedulable is false,
// or within temporary unschedulable period.
// This ensures subsequent requests won't continue using unavailable accounts.
func shouldClearStickySession(account *Account) bool {
if account == nil {
return false
}
if account.Status == StatusError || account.Status == StatusDisabled || !account.Schedulable {
return true
}
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
return true
}
return false
}
type AccountWaitPlan struct {
AccountID int64
MaxConcurrency int
......@@ -305,6 +485,19 @@ func (s *GatewayService) BindStickySession(ctx context.Context, groupID *int64,
return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, accountID, stickySessionTTL)
}
// GetCachedSessionAccountID retrieves the account ID bound to a sticky session.
// Returns 0 if no binding exists or on error.
func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID *int64, sessionHash string) (int64, error) {
if sessionHash == "" || s.cache == nil {
return 0, nil
}
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err != nil {
return 0, err
}
return accountID, nil
}
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
if parsed == nil {
return ""
......@@ -503,12 +696,21 @@ func normalizeParamNameForOpenCode(name string, cache map[string]string) string
return name
}
func sanitizeOpenCodeText(text string) string {
// sanitizeSystemText rewrites only the fixed OpenCode identity sentence (if present).
// We intentionally avoid broad keyword replacement in system prompts to prevent
// accidentally changing user-provided instructions.
func sanitizeSystemText(text string) string {
if text == "" {
return text
}
text = strings.ReplaceAll(text, "OpenCode", "Claude Code")
text = opencodeTextRe.ReplaceAllString(text, "Claude")
// Some clients include a fixed OpenCode identity sentence. Anthropic may treat
// this as a non-Claude-Code fingerprint, so rewrite it to the canonical
// Claude Code banner before generic "OpenCode"/"opencode" replacements.
text = strings.ReplaceAll(
text,
"You are OpenCode, the best coding agent on the planet.",
strings.TrimSpace(claudeCodeSystemPrompt),
)
return text
}
......@@ -518,7 +720,9 @@ func sanitizeToolDescription(description string) string {
}
description = toolDescAbsPathRe.ReplaceAllString(description, "[path]")
description = toolDescWinPathRe.ReplaceAllString(description, "[path]")
return sanitizeOpenCodeText(description)
// Intentionally do NOT rewrite tool descriptions (OpenCode/Claude strings).
// Tool names/skill names may rely on exact wording, and rewriting can be misleading.
return description
}
func normalizeToolInputSchema(inputSchema any, cache map[string]string) {
......@@ -593,7 +797,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
if system, ok := req["system"]; ok {
switch v := system.(type) {
case string:
sanitized := sanitizeOpenCodeText(v)
sanitized := sanitizeSystemText(v)
if sanitized != v {
req["system"] = sanitized
}
......@@ -610,7 +814,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
if !ok || text == "" {
continue
}
sanitized := sanitizeOpenCodeText(text)
sanitized := sanitizeSystemText(text)
if sanitized != text {
block["text"] = sanitized
}
......@@ -743,17 +947,15 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account
if parsed.MetadataUserID != "" {
return ""
}
accountUUID := account.GetExtraString("account_uuid")
if accountUUID == "" {
return ""
}
userID := strings.TrimSpace(account.GetClaudeUserID())
if userID == "" && fp != nil {
userID = fp.ClientID
}
if userID == "" {
return ""
// Fall back to a random, well-formed client id so we can still satisfy
// Claude Code OAuth requirements when account metadata is incomplete.
userID = generateClientID()
}
sessionHash := s.GenerateSessionHash(parsed)
......@@ -762,7 +964,14 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account
seed := fmt.Sprintf("%d::%s", account.ID, sessionHash)
sessionID = generateSessionUUID(seed)
}
// Prefer the newer format that includes account_uuid (if present),
// otherwise fall back to the legacy Claude Code format.
accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid"))
if accountUUID != "" {
return fmt.Sprintf("user_%s_account_%s_session_%s", userID, accountUUID, sessionID)
}
return fmt.Sprintf("user_%s_account__session_%s", userID, sessionID)
}
func generateSessionUUID(seed string) string {
......@@ -819,11 +1028,20 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
}
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
// metadataUserID: 原始 metadata.user_id 字段(用于提取会话 UUID 进行会话数量限制)
// metadataUserID: 已废弃参数,会话限制现在统一使用 sessionHash
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) {
// 调试日志:记录调度入口参数
excludedIDsList := make([]int64, 0, len(excludedIDs))
for id := range excludedIDs {
excludedIDsList = append(excludedIDsList, id)
}
slog.Debug("account_scheduling_starting",
"group_id", derefGroupID(groupID),
"model", requestedModel,
"session", shortSessionHash(sessionHash),
"excluded_ids", excludedIDsList)
cfg := s.schedulingConfig()
// 提取会话 UUID(用于会话数量限制)
sessionUUID := extractSessionUUID(metadataUserID)
var stickyAccountID int64
if sessionHash != "" && s.cache != nil {
......@@ -849,18 +1067,39 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
// 复制排除列表,用于会话限制拒绝时的重试
localExcluded := make(map[int64]struct{})
for k, v := range excludedIDs {
localExcluded[k] = v
}
for {
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, localExcluded)
if err != nil {
return nil, err
}
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err == nil && result.Acquired {
// 获取槽位后检查会话限制(使用 sessionHash 作为会话标识符)
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位
localExcluded[account.ID] = struct{}{} // 排除此账号
continue // 重新选择
}
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
// 对于等待计划的情况,也需要先检查会话限制
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
localExcluded[account.ID] = struct{}{}
continue
}
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
if waitingCount < cfg.StickySessionMaxWaiting {
......@@ -885,6 +1124,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
},
}, nil
}
}
platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID, group)
if err != nil {
......@@ -999,7 +1239,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionUUID) {
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
result.ReleaseFunc() // 释放槽位
// 继续到负载感知选择
} else {
......@@ -1017,6 +1257,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
if waitingCount < cfg.StickySessionMaxWaiting {
// 会话数量限制检查(等待计划也需要占用会话配额)
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
// 会话限制已满,继续到负载感知选择
} else {
return &AccountSelectionResult{
Account: stickyAccount,
WaitPlan: &AccountWaitPlan{
......@@ -1027,8 +1271,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
},
}, nil
}
}
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
}
} else {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
}
}
......@@ -1086,7 +1333,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) {
if !s.checkAndRegisterSession(ctx, item.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
continue
}
......@@ -1104,21 +1351,27 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
}
// 5. 所有路由账号槽位满,返回等待计划(选择负载最低的)
acc := routingAvailable[0].account
// 5. 所有路由账号槽位满,尝试返回等待计划(选择负载最低的)
// 遍历找到第一个满足会话限制的账号
for _, item := range routingAvailable {
if !s.checkAndRegisterSession(ctx, item.account, sessionHash) {
continue // 会话限制已满,尝试下一个
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), acc.ID)
log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
}
return &AccountSelectionResult{
Account: acc,
Account: item.account,
WaitPlan: &AccountWaitPlan{
AccountID: acc.ID,
MaxConcurrency: acc.Concurrency,
AccountID: item.account.ID,
MaxConcurrency: item.account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
// 所有路由账号会话限制都已满,继续到 Layer 2 回退
}
// 路由列表中的账号都不可用(负载率 >= 100),继续到 Layer 2 回退
log.Printf("[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", requestedModel)
}
......@@ -1129,7 +1382,14 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) {
account, ok := accountByID[accountID]
if ok && s.isAccountInGroup(account, groupID) &&
if ok {
// 检查账户是否需要清理粘性会话绑定
// Check if the account needs sticky session cleanup
clearSticky := shouldClearStickySession(account)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) &&
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
account.IsSchedulableForModel(requestedModel) &&
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) &&
......@@ -1137,7 +1397,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, account, sessionUUID) {
// Session count limit check
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
} else {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
......@@ -1151,6 +1412,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting {
// 会话数量限制检查(等待计划也需要占用会话配额)
// Session count limit check (wait plan also requires session quota)
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
// 会话限制已满,继续到 Layer 2
// Session limit full, continue to Layer 2
} else {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
......@@ -1164,6 +1431,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
}
}
}
}
// ============ Layer 2: 负载感知选择 ============
candidates := make([]*Account, 0, len(accounts))
......@@ -1208,7 +1477,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil {
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth, sessionUUID); ok {
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok {
return result, nil
}
} else {
......@@ -1258,7 +1527,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) {
if !s.checkAndRegisterSession(ctx, item.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
continue
}
......@@ -1276,8 +1545,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
// ============ Layer 3: 兜底排队 ============
sortAccountsByPriorityAndLastUsed(candidates, preferOAuth)
s.sortCandidatesForFallback(candidates, preferOAuth, cfg.FallbackSelectionMode)
for _, acc := range candidates {
// 会话数量限制检查(等待计划也需要占用会话配额)
if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
continue // 会话限制已满,尝试下一个账号
}
return &AccountSelectionResult{
Account: acc,
WaitPlan: &AccountWaitPlan{
......@@ -1291,7 +1564,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return nil, errors.New("no available accounts")
}
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool, sessionUUID string) (*AccountSelectionResult, bool) {
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
ordered := append([]*Account(nil), candidates...)
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
......@@ -1299,7 +1572,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, acc, sessionUUID) {
if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
continue
}
......@@ -1456,7 +1729,24 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
if s.schedulerSnapshot != nil {
return s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err == nil {
slog.Debug("account_scheduling_list_snapshot",
"group_id", derefGroupID(groupID),
"platform", platform,
"use_mixed", useMixed,
"count", len(accounts))
for _, acc := range accounts {
slog.Debug("account_scheduling_account_detail",
"account_id", acc.ID,
"name", acc.Name,
"platform", acc.Platform,
"type", acc.Type,
"status", acc.Status,
"tls_fingerprint", acc.IsTLSFingerprintEnabled())
}
}
return accounts, useMixed, err
}
useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform
if useMixed {
......@@ -1469,6 +1759,10 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
}
if err != nil {
slog.Debug("account_scheduling_list_failed",
"group_id", derefGroupID(groupID),
"platform", platform,
"error", err)
return nil, useMixed, err
}
filtered := make([]Account, 0, len(accounts))
......@@ -1478,6 +1772,20 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
}
filtered = append(filtered, acc)
}
slog.Debug("account_scheduling_list_mixed",
"group_id", derefGroupID(groupID),
"platform", platform,
"raw_count", len(accounts),
"filtered_count", len(filtered))
for _, acc := range filtered {
slog.Debug("account_scheduling_account_detail",
"account_id", acc.ID,
"name", acc.Name,
"platform", acc.Platform,
"type", acc.Type,
"status", acc.Status,
"tls_fingerprint", acc.IsTLSFingerprintEnabled())
}
return filtered, useMixed, nil
}
......@@ -1492,8 +1800,25 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
}
if err != nil {
slog.Debug("account_scheduling_list_failed",
"group_id", derefGroupID(groupID),
"platform", platform,
"error", err)
return nil, useMixed, err
}
slog.Debug("account_scheduling_list_single",
"group_id", derefGroupID(groupID),
"platform", platform,
"count", len(accounts))
for _, acc := range accounts {
slog.Debug("account_scheduling_account_detail",
"account_id", acc.ID,
"name", acc.Name,
"platform", acc.Platform,
"type", acc.Type,
"status", acc.Status,
"tls_fingerprint", acc.IsTLSFingerprintEnabled())
}
return accounts, useMixed, nil
}
......@@ -1559,12 +1884,8 @@ func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context,
// 缓存未命中,从数据库查询
{
var startTime time.Time
if account.SessionWindowStart != nil {
startTime = *account.SessionWindowStart
} else {
startTime = time.Now().Add(-5 * time.Hour)
}
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
startTime := account.GetCurrentWindowStartTime()
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
if err != nil {
......@@ -1597,15 +1918,16 @@ checkSchedulability:
// checkAndRegisterSession 检查并注册会话,用于会话数量限制
// 仅适用于 Anthropic OAuth/SetupToken 账号
// sessionID: 会话标识符(使用粘性会话的 hash)
// 返回 true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话)
func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionUUID string) bool {
func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionID string) bool {
// 只检查 Anthropic OAuth/SetupToken 账号
if !account.IsAnthropicOAuthOrSetupToken() {
return true
}
maxSessions := account.GetMaxSessions()
if maxSessions <= 0 || sessionUUID == "" {
if maxSessions <= 0 || sessionID == "" {
return true // 未启用会话限制或无会话ID
}
......@@ -1615,7 +1937,7 @@ func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *A
idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute
allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionUUID, maxSessions, idleTimeout)
allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionID, maxSessions, idleTimeout)
if err != nil {
// 失败开放:缓存错误时允许通过
return true
......@@ -1623,18 +1945,6 @@ func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *A
return allowed
}
// extractSessionUUID 从 metadata.user_id 中提取会话 UUID
// 格式: user_{64位hex}_account__session_{uuid}
func extractSessionUUID(metadataUserID string) string {
if metadataUserID == "" {
return ""
}
if match := sessionIDRegex.FindStringSubmatch(metadataUserID); len(match) > 1 {
return match[1]
}
return ""
}
func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
if s.schedulerSnapshot != nil {
return s.schedulerSnapshot.GetAccount(ctx, accountID)
......@@ -1664,6 +1974,56 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
})
}
// sortCandidatesForFallback 根据配置选择排序策略
// mode: "last_used"(按最后使用时间) 或 "random"(随机)
func (s *GatewayService) sortCandidatesForFallback(accounts []*Account, preferOAuth bool, mode string) {
if mode == "random" {
// 先按优先级排序,然后在同优先级内随机打乱
sortAccountsByPriorityOnly(accounts, preferOAuth)
shuffleWithinPriority(accounts)
} else {
// 默认按最后使用时间排序
sortAccountsByPriorityAndLastUsed(accounts, preferOAuth)
}
}
// sortAccountsByPriorityOnly 仅按优先级排序
func sortAccountsByPriorityOnly(accounts []*Account, preferOAuth bool) {
sort.SliceStable(accounts, func(i, j int) bool {
a, b := accounts[i], accounts[j]
if a.Priority != b.Priority {
return a.Priority < b.Priority
}
if preferOAuth && a.Type != b.Type {
return a.Type == AccountTypeOAuth
}
return false
})
}
// shuffleWithinPriority 在同优先级内随机打乱顺序
func shuffleWithinPriority(accounts []*Account) {
if len(accounts) <= 1 {
return
}
r := mathrand.New(mathrand.NewSource(time.Now().UnixNano()))
start := 0
for start < len(accounts) {
priority := accounts[start].Priority
end := start + 1
for end < len(accounts) && accounts[end].Priority == priority {
end++
}
// 对 [start, end) 范围内的账户随机打乱
if end-start > 1 {
r.Shuffle(end-start, func(i, j int) {
accounts[start+i], accounts[start+j] = accounts[start+j], accounts[start+i]
})
}
start = end
}
}
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
preferOAuth := platform == PlatformGemini
......@@ -1687,7 +2047,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err == nil {
clearSticky := shouldClearStickySession(account)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
......@@ -1699,6 +2064,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
}
}
}
}
// 2) Select an account from the routed candidates.
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
......@@ -1784,7 +2150,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err == nil {
clearSticky := shouldClearStickySession(account)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
......@@ -1793,6 +2164,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
}
}
}
}
// 2. 获取可调度账号列表(单平台)
if !accountsLoaded {
......@@ -1888,7 +2260,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err == nil {
clearSticky := shouldClearStickySession(account)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
......@@ -1902,6 +2279,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
}
}
}
// 2) Select an account from the routed candidates.
var err error
......@@ -1987,7 +2365,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err == nil {
clearSticky := shouldClearStickySession(account)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
......@@ -1998,6 +2381,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
}
}
}
// 2. 获取可调度账号列表
if !accountsLoaded {
......@@ -2247,6 +2631,10 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
"text": claudeCodeSystemPrompt,
"cache_control": map[string]string{"type": "ephemeral"},
}
// Opencode plugin applies an extra safeguard: it not only prepends the Claude Code
// banner, it also prefixes the next system instruction with the same banner plus
// a blank line. This helps when upstream concatenates system instructions.
claudeCodePrefix := strings.TrimSpace(claudeCodeSystemPrompt)
var newSystem []any
......@@ -2254,19 +2642,36 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
case nil:
newSystem = []any{claudeCodeBlock}
case string:
if v == "" || v == claudeCodeSystemPrompt {
// Be tolerant of older/newer clients that may differ only by trailing whitespace/newlines.
if strings.TrimSpace(v) == "" || strings.TrimSpace(v) == strings.TrimSpace(claudeCodeSystemPrompt) {
newSystem = []any{claudeCodeBlock}
} else {
newSystem = []any{claudeCodeBlock, map[string]any{"type": "text", "text": v}}
// Mirror opencode behavior: keep the banner as a separate system entry,
// but also prefix the next system text with the banner.
merged := v
if !strings.HasPrefix(v, claudeCodePrefix) {
merged = claudeCodePrefix + "\n\n" + v
}
newSystem = []any{claudeCodeBlock, map[string]any{"type": "text", "text": merged}}
}
case []any:
newSystem = make([]any, 0, len(v)+1)
newSystem = append(newSystem, claudeCodeBlock)
prefixedNext := false
for _, item := range v {
if m, ok := item.(map[string]any); ok {
if text, ok := m["text"].(string); ok && text == claudeCodeSystemPrompt {
if text, ok := m["text"].(string); ok && strings.TrimSpace(text) == strings.TrimSpace(claudeCodeSystemPrompt) {
continue
}
// Prefix the first subsequent text system block once.
if !prefixedNext {
if blockType, _ := m["type"].(string); blockType == "text" {
if text, ok := m["text"].(string); ok && strings.TrimSpace(text) != "" && !strings.HasPrefix(text, claudeCodePrefix) {
m["text"] = claudeCodePrefix + "\n\n" + text
prefixedNext = true
}
}
}
}
newSystem = append(newSystem, item)
}
......@@ -2524,6 +2929,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
proxyURL = account.Proxy.URL()
}
// 调试日志:记录即将转发的账号信息
log.Printf("[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s",
account.ID, account.Name, account.Platform, account.Type, account.IsTLSFingerprintEnabled(), proxyURL)
// 重试循环
var resp *http.Response
retryStart := time.Now()
......@@ -2537,7 +2946,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
// 发送请求
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
......@@ -2611,7 +3020,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr == nil {
if retryResp.StatusCode < 400 {
log.Printf("Account %d: signature error retry succeeded (thinking downgraded)", account.ID)
......@@ -2643,7 +3052,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
if buildErr2 == nil {
retryResp2, retryErr2 := s.httpUpstream.Do(retryReq2, proxyURL, account.ID, account.Concurrency)
retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr2 == nil {
resp = retryResp2
break
......@@ -2758,6 +3167,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
// 调试日志:打印重试耗尽后的错误响应
log.Printf("[Forward] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000))
s.handleRetryExhaustedSideEffects(ctx, resp, account)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
......@@ -2785,6 +3198,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
// 调试日志:打印上游错误响应
log.Printf("[Forward] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000))
s.handleFailoverSideEffects(ctx, resp, account)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
......@@ -2902,11 +3319,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
clientHeaders := http.Header{}
if c != nil && c.Request != nil {
clientHeaders = c.Request.Header
}
// OAuth账号:应用统一指纹
var fingerprint *Fingerprint
if account.IsOAuth() && s.identityService != nil {
// 1. 获取或创建指纹(包含随机生成的ClientID)
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders)
if err != nil {
log.Printf("Warning: failed to get fingerprint for account %d: %v", account.ID, err)
// 失败时降级为透传原始headers
......@@ -2914,9 +3336,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
fingerprint = fp
// 2. 重写metadata.user_id(需要指纹中的ClientID和账号的account_uuid)
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" {
if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
body = newBody
}
}
......@@ -2936,7 +3359,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
// 白名单透传headers
for key, values := range c.Request.Header {
for key, values := range clientHeaders {
lowerKey := strings.ToLower(key)
if allowedHeaders[lowerKey] {
for _, v := range values {
......@@ -2964,12 +3387,18 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 处理 anthropic-beta header(OAuth 账号需要包含 oauth beta)
if tokenType == "oauth" {
if mimicClaudeCode {
// 非 Claude Code 客户端:按 Claude Code 规则生成 beta header
if requestHasTools(body) {
req.Header.Set("anthropic-beta", claude.MessageBetaHeaderWithTools)
} else {
req.Header.Set("anthropic-beta", claude.MessageBetaHeaderNoTools)
}
// 非 Claude Code 客户端:按 opencode 的策略处理:
// - 强制 Claude Code 指纹相关请求头(尤其是 user-agent/x-stainless/x-app)
// - 保留 incoming beta 的同时,确保 OAuth 所需 beta 存在
applyClaudeCodeMimicHeaders(req, reqStream)
incomingBeta := req.Header.Get("anthropic-beta")
// Match real Claude CLI traffic (per mitmproxy reports):
// messages requests typically use only oauth + interleaved-thinking.
// Also drop claude-code beta if a downstream client added it.
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
drop := map[string]struct{}{claude.BetaClaudeCode: {}}
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop))
} else {
// Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta
clientBetaHeader := req.Header.Get("anthropic-beta")
......@@ -2984,6 +3413,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
// Always capture a compact fingerprint line for later error diagnostics.
// We only print it when needed (or when the explicit debug flag is enabled).
if c != nil && tokenType == "oauth" {
c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode))
}
if s.debugClaudeMimicEnabled() {
logClaudeMimicDebug(req, body, account, tokenType, mimicClaudeCode)
}
return req, nil
}
......@@ -3045,20 +3483,6 @@ func requestNeedsBetaFeatures(body []byte) bool {
return false
}
func requestHasTools(body []byte) bool {
tools := gjson.GetBytes(body, "tools")
if !tools.Exists() {
return false
}
if tools.IsArray() {
return len(tools.Array()) > 0
}
if tools.IsObject() {
return len(tools.Map()) > 0
}
return false
}
func defaultAPIKeyBetaHeader(body []byte) string {
modelID := gjson.GetBytes(body, "model").String()
if strings.Contains(strings.ToLower(modelID), "haiku") {
......@@ -3087,6 +3511,73 @@ func applyClaudeOAuthHeaderDefaults(req *http.Request, isStream bool) {
}
}
func mergeAnthropicBeta(required []string, incoming string) string {
seen := make(map[string]struct{}, len(required)+8)
out := make([]string, 0, len(required)+8)
add := func(v string) {
v = strings.TrimSpace(v)
if v == "" {
return
}
if _, ok := seen[v]; ok {
return
}
seen[v] = struct{}{}
out = append(out, v)
}
for _, r := range required {
add(r)
}
for _, p := range strings.Split(incoming, ",") {
add(p)
}
return strings.Join(out, ",")
}
func mergeAnthropicBetaDropping(required []string, incoming string, drop map[string]struct{}) string {
merged := mergeAnthropicBeta(required, incoming)
if merged == "" || len(drop) == 0 {
return merged
}
out := make([]string, 0, 8)
for _, p := range strings.Split(merged, ",") {
p = strings.TrimSpace(p)
if p == "" {
continue
}
if _, ok := drop[p]; ok {
continue
}
out = append(out, p)
}
return strings.Join(out, ",")
}
// applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers.
// This mirrors opencode-anthropic-auth behavior: do not trust downstream
// headers when using Claude Code-scoped OAuth credentials.
func applyClaudeCodeMimicHeaders(req *http.Request, isStream bool) {
if req == nil {
return
}
// Start with the standard defaults (fill missing).
applyClaudeOAuthHeaderDefaults(req, isStream)
// Then force key headers to match Claude Code fingerprint regardless of what the client sent.
for key, value := range claude.DefaultHeaders {
if value == "" {
continue
}
req.Header.Set(key, value)
}
// Real Claude CLI uses Accept: application/json (even for streaming).
req.Header.Set("accept", "application/json")
if isStream {
req.Header.Set("x-stainless-helper-method", "stream")
}
}
func truncateForLog(b []byte, maxBytes int) string {
if maxBytes <= 0 {
maxBytes = 2048
......@@ -3183,9 +3674,27 @@ func extractUpstreamErrorMessage(body []byte) string {
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
// 调试日志:打印上游错误响应
log.Printf("[Forward] Upstream error (non-retryable): Account=%d(%s) Status=%d RequestID=%s Body=%s",
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(body), 1000))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
// Print a compact upstream request fingerprint when we hit the Claude Code OAuth
// credential scope error. This avoids requiring env-var tweaks in a fixed deploy.
if isClaudeCodeCredentialScopeError(upstreamMsg) && c != nil {
if v, ok := c.Get(claudeMimicDebugInfoKey); ok {
if line, ok := v.(string); ok && strings.TrimSpace(line) != "" {
log.Printf("[ClaudeMimicDebugOnError] status=%d request_id=%s %s",
resp.StatusCode,
resp.Header.Get("x-request-id"),
line,
)
}
}
}
// Enrich Ops error logs with upstream status + message, and optionally a truncated body snippet.
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
......@@ -3315,6 +3824,19 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if isClaudeCodeCredentialScopeError(upstreamMsg) && c != nil {
if v, ok := c.Get(claudeMimicDebugInfoKey); ok {
if line, ok := v.(string); ok && strings.TrimSpace(line) != "" {
log.Printf("[ClaudeMimicDebugOnError] status=%d request_id=%s %s",
resp.StatusCode,
resp.Header.Get("x-request-id"),
line,
)
}
}
}
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
......@@ -3860,17 +4382,19 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
} `json:"usage"`
}
if json.Unmarshal([]byte(data), &msgDelta) == nil && msgDelta.Type == "message_delta" {
// output_tokens 总是从 message_delta 获取
usage.OutputTokens = msgDelta.Usage.OutputTokens
// 如果 message_start 中没有值,则从 message_delta 获取(兼容GLM等API)
if usage.InputTokens == 0 {
// message_delta 仅覆盖存在且非0的字段
// 避免覆盖 message_start 中已有的值(如 input_tokens)
// Claude API 的 message_delta 通常只包含 output_tokens
if msgDelta.Usage.InputTokens > 0 {
usage.InputTokens = msgDelta.Usage.InputTokens
}
if usage.CacheCreationInputTokens == 0 {
if msgDelta.Usage.OutputTokens > 0 {
usage.OutputTokens = msgDelta.Usage.OutputTokens
}
if msgDelta.Usage.CacheCreationInputTokens > 0 {
usage.CacheCreationInputTokens = msgDelta.Usage.CacheCreationInputTokens
}
if usage.CacheReadInputTokens == 0 {
if msgDelta.Usage.CacheReadInputTokens > 0 {
usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens
}
}
......@@ -4171,7 +4695,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// 发送请求
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil {
setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(err.Error()), "")
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
......@@ -4193,7 +4717,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, shouldMimicClaudeCode)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr == nil {
resp = retryResp
respBody, err = io.ReadAll(resp.Body)
......@@ -4270,13 +4794,19 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
clientHeaders := http.Header{}
if c != nil && c.Request != nil {
clientHeaders = c.Request.Header
}
// OAuth 账号:应用统一指纹和重写 userID
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
if account.IsOAuth() && s.identityService != nil {
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders)
if err == nil {
accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" {
if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
body = newBody
}
}
......@@ -4296,7 +4826,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
// 白名单透传 headers
for key, values := range c.Request.Header {
for key, values := range clientHeaders {
lowerKey := strings.ToLower(key)
if allowedHeaders[lowerKey] {
for _, v := range values {
......@@ -4307,7 +4837,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:应用指纹到请求头
if account.IsOAuth() && s.identityService != nil {
fp, _ := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
fp, _ := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders)
if fp != nil {
s.identityService.ApplyFingerprint(req, fp)
}
......@@ -4327,7 +4857,11 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" {
if mimicClaudeCode {
req.Header.Set("anthropic-beta", claude.CountTokensBetaHeader)
applyClaudeCodeMimicHeaders(req, false)
incomingBeta := req.Header.Get("anthropic-beta")
requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting}
req.Header.Set("anthropic-beta", mergeAnthropicBeta(requiredBetas, incomingBeta))
} else {
clientBetaHeader := req.Header.Get("anthropic-beta")
if clientBetaHeader == "" {
......@@ -4349,6 +4883,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
if c != nil && tokenType == "oauth" {
c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode))
}
if s.debugClaudeMimicEnabled() {
logClaudeMimicDebug(req, body, account, tokenType, mimicClaudeCode)
}
return req, nil
}
......
......@@ -82,145 +82,276 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
}
func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 1. 确定目标平台和调度模式
// Determine target platform and scheduling mode
platform, useMixedScheduling, hasForcePlatform, err := s.resolvePlatformAndSchedulingMode(ctx, groupID)
if err != nil {
return nil, err
}
cacheKey := "gemini:" + sessionHash
// 2. 尝试粘性会话命中
// Try sticky session hit
if account := s.tryStickySessionHit(ctx, groupID, sessionHash, cacheKey, requestedModel, excludedIDs, platform, useMixedScheduling); account != nil {
return account, nil
}
// 3. 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
// Query schedulable accounts (force platform mode: try group first, fallback to all)
accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, platform, hasForcePlatform)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
// 强制平台模式下,分组中找不到账户时回退查询全部
if len(accounts) == 0 && groupID != nil && hasForcePlatform {
accounts, err = s.listSchedulableAccountsOnce(ctx, nil, platform, hasForcePlatform)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
}
// 4. 按优先级 + LRU 选择最佳账号
// Select best account by priority + LRU
selected := s.selectBestGeminiAccount(ctx, accounts, requestedModel, excludedIDs, platform, useMixedScheduling)
if selected == nil {
if requestedModel != "" {
return nil, fmt.Errorf("no available Gemini accounts supporting model: %s", requestedModel)
}
return nil, errors.New("no available Gemini accounts")
}
// 5. 设置粘性会话绑定
// Set sticky session binding
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, geminiStickySessionTTL)
}
return selected, nil
}
// resolvePlatformAndSchedulingMode 解析目标平台和调度模式。
// 返回:平台名称、是否使用混合调度、是否强制平台、错误。
//
// resolvePlatformAndSchedulingMode resolves target platform and scheduling mode.
// Returns: platform name, whether to use mixed scheduling, whether force platform, error.
func (s *GeminiMessagesCompatService) resolvePlatformAndSchedulingMode(ctx context.Context, groupID *int64) (platform string, useMixedScheduling bool, hasForcePlatform bool, err error) {
// 优先检查 context 中的强制平台(/antigravity 路由)
var platform string
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" {
platform = forcePlatform
} else if groupID != nil {
return forcePlatform, false, true, nil
}
if groupID != nil {
// 根据分组 platform 决定查询哪种账号
var group *Group
if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID {
group = ctxGroup
} else {
var err error
group, err = s.groupRepo.GetByIDLite(ctx, *groupID)
if err != nil {
return nil, fmt.Errorf("get group failed: %w", err)
return "", false, false, fmt.Errorf("get group failed: %w", err)
}
}
platform = group.Platform
} else {
// 无分组时只使用原生 gemini 平台
platform = PlatformGemini
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
return group.Platform, group.Platform == PlatformGemini, false, nil
}
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
// 注意:强制平台模式不走混合调度
useMixedScheduling := platform == PlatformGemini && !hasForcePlatform
// 无分组时只使用原生 gemini 平台
return PlatformGemini, true, false, nil
}
cacheKey := "gemini:" + sessionHash
// tryStickySessionHit 尝试从粘性会话获取账号。
// 如果命中且账号可用则返回账号;如果账号不可用则清理会话并返回 nil。
//
// tryStickySessionHit attempts to get account from sticky session.
// Returns account if hit and usable; clears session and returns nil if account unavailable.
func (s *GeminiMessagesCompatService) tryStickySessionHit(
ctx context.Context,
groupID *int64,
sessionHash, cacheKey, requestedModel string,
excludedIDs map[int64]struct{},
platform string,
useMixedScheduling bool,
) *Account {
if sessionHash == "" {
return nil
}
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度
if err == nil && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
valid := false
if account.Platform == platform {
valid = true
} else if useMixedScheduling && account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() {
valid = true
if err != nil || accountID <= 0 {
return nil
}
if valid {
usable := true
if s.rateLimitService != nil && requestedModel != "" {
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
if _, excluded := excludedIDs[accountID]; excluded {
return nil
}
account, err := s.getSchedulableAccount(ctx, accountID)
if err != nil {
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
return nil
}
if !ok {
usable = false
// 检查账号是否需要清理粘性会话
// Check if sticky session should be cleared
if shouldClearStickySession(account) {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
return nil
}
// 验证账号是否可用于当前请求
// Verify account is usable for current request
if !s.isAccountUsableForRequest(ctx, account, requestedModel, platform, useMixedScheduling) {
return nil
}
if usable {
// 刷新会话 TTL 并返回账号
// Refresh session TTL and return account
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, geminiStickySessionTTL)
return account, nil
return account
}
// isAccountUsableForRequest 检查账号是否可用于当前请求。
// 验证:模型调度、模型支持、平台匹配、速率限制预检。
//
// isAccountUsableForRequest checks if account is usable for current request.
// Validates: model scheduling, model support, platform matching, rate limit precheck.
func (s *GeminiMessagesCompatService) isAccountUsableForRequest(
ctx context.Context,
account *Account,
requestedModel, platform string,
useMixedScheduling bool,
) bool {
// 检查模型调度能力
// Check model scheduling capability
if !account.IsSchedulableForModel(requestedModel) {
return false
}
// 检查模型支持
// Check model support
if requestedModel != "" && !s.isModelSupportedByAccount(account, requestedModel) {
return false
}
// 检查平台匹配
// Check platform matching
if !s.isAccountValidForPlatform(account, platform, useMixedScheduling) {
return false
}
// 速率限制预检
// Rate limit precheck
if !s.passesRateLimitPreCheck(ctx, account, requestedModel) {
return false
}
return true
}
// isAccountValidForPlatform 检查账号是否匹配目标平台。
// 原生平台直接匹配;混合调度模式下 antigravity 需要启用 mixed_scheduling。
//
// isAccountValidForPlatform checks if account matches target platform.
// Native platform matches directly; mixed scheduling mode requires antigravity to enable mixed_scheduling.
func (s *GeminiMessagesCompatService) isAccountValidForPlatform(account *Account, platform string, useMixedScheduling bool) bool {
if account.Platform == platform {
return true
}
if useMixedScheduling && account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() {
return true
}
return false
}
// 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, platform, hasForcePlatform)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
// passesRateLimitPreCheck 执行速率限制预检。
// 返回 true 表示通过预检或无需预检。
//
// passesRateLimitPreCheck performs rate limit precheck.
// Returns true if passed or precheck not required.
func (s *GeminiMessagesCompatService) passesRateLimitPreCheck(ctx context.Context, account *Account, requestedModel string) bool {
if s.rateLimitService == nil || requestedModel == "" {
return true
}
// 强制平台模式下,分组中找不到账户时回退查询全部
if len(accounts) == 0 && groupID != nil && hasForcePlatform {
accounts, err = s.listSchedulableAccountsOnce(ctx, nil, platform, hasForcePlatform)
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
}
return ok
}
// selectBestGeminiAccount 从候选账号中选择最佳账号(优先级 + LRU + OAuth 优先)。
// 返回 nil 表示无可用账号。
//
// selectBestGeminiAccount selects best account from candidates (priority + LRU + OAuth preferred).
// Returns nil if no available account.
func (s *GeminiMessagesCompatService) selectBestGeminiAccount(
ctx context.Context,
accounts []Account,
requestedModel string,
excludedIDs map[int64]struct{},
platform string,
useMixedScheduling bool,
) *Account {
var selected *Account
for i := range accounts {
acc := &accounts[i]
// 跳过被排除的账号
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// 混合调度模式下:原生平台直接通过,antigravity 需要启用 mixed_scheduling
// 非混合调度模式(antigravity 分组):不需要过滤
if useMixedScheduling && acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
if s.rateLimitService != nil && requestedModel != "" {
ok, err := s.rateLimitService.PreCheckUsage(ctx, acc, requestedModel)
if err != nil {
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", acc.ID, err)
}
if !ok {
// 检查账号是否可用于当前请求
if !s.isAccountUsableForRequest(ctx, acc, requestedModel, platform, useMixedScheduling) {
continue
}
}
// 选择最佳账号
if selected == nil {
selected = acc
continue
}
if acc.Priority < selected.Priority {
selected = acc
} else if acc.Priority == selected.Priority {
switch {
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
selected = acc
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
// Prefer OAuth accounts when both are unused (more compatible for Code Assist flows).
if acc.Type == AccountTypeOAuth && selected.Type != AccountTypeOAuth {
selected = acc
}
default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
if s.isBetterGeminiAccount(acc, selected) {
selected = acc
}
}
}
}
if selected == nil {
if requestedModel != "" {
return nil, fmt.Errorf("no available Gemini accounts supporting model: %s", requestedModel)
return selected
}
// isBetterGeminiAccount 判断 candidate 是否比 current 更优。
// 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先(OAuth > 非 OAuth),其次是最久未使用的。
//
// isBetterGeminiAccount checks if candidate is better than current.
// Rules: higher priority (lower value) wins; same priority: never used (OAuth > non-OAuth) > least recently used.
func (s *GeminiMessagesCompatService) isBetterGeminiAccount(candidate, current *Account) bool {
// 优先级更高(数值更小)
if candidate.Priority < current.Priority {
return true
}
return nil, errors.New("no available Gemini accounts")
if candidate.Priority > current.Priority {
return false
}
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, geminiStickySessionTTL)
// 同优先级,比较最后使用时间
switch {
case candidate.LastUsedAt == nil && current.LastUsedAt != nil:
// candidate 从未使用,优先
return true
case candidate.LastUsedAt != nil && current.LastUsedAt == nil:
// current 从未使用,保持
return false
case candidate.LastUsedAt == nil && current.LastUsedAt == nil:
// 都未使用,优先选择 OAuth 账号(更兼容 Code Assist 流程)
return candidate.Type == AccountTypeOAuth && current.Type != AccountTypeOAuth
default:
// 都使用过,选择最久未使用的
return candidate.LastUsedAt.Before(*current.LastUsedAt)
}
return selected, nil
}
// isModelSupportedByAccount 根据账户平台检查模型支持
......@@ -800,6 +931,13 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
}
// 图片生成计费
imageCount := 0
imageSize := s.extractImageSize(body)
if isImageGenerationModel(originalModel) {
imageCount = 1
}
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
......@@ -807,6 +945,8 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
Stream: req.Stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ImageCount: imageCount,
ImageSize: imageSize,
}, nil
}
......@@ -1240,6 +1380,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
usage = &ClaudeUsage{}
}
// 图片生成计费
imageCount := 0
imageSize := s.extractImageSize(body)
if isImageGenerationModel(originalModel) {
imageCount = 1
}
return &ForwardResult{
RequestID: requestID,
Usage: *usage,
......@@ -1247,6 +1394,8 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
Stream: stream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ImageCount: imageCount,
ImageSize: imageSize,
}, nil
}
......@@ -1841,6 +1990,7 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
var last map[string]any
var lastWithParts map[string]any
var collectedTextParts []string // Collect all text parts for aggregation
usage := &ClaudeUsage{}
for {
......@@ -1852,7 +2002,7 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
switch payload {
case "", "[DONE]":
if payload == "[DONE]" {
return pickGeminiCollectResult(last, lastWithParts), usage, nil
return mergeCollectedTextParts(pickGeminiCollectResult(last, lastWithParts), collectedTextParts), usage, nil
}
default:
var parsed map[string]any
......@@ -1871,6 +2021,12 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
}
if parts := extractGeminiParts(parsed); len(parts) > 0 {
lastWithParts = parsed
// Collect text from each part for aggregation
for _, part := range parts {
if text, ok := part["text"].(string); ok && text != "" {
collectedTextParts = append(collectedTextParts, text)
}
}
}
}
}
......@@ -1885,7 +2041,7 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
}
}
return pickGeminiCollectResult(last, lastWithParts), usage, nil
return mergeCollectedTextParts(pickGeminiCollectResult(last, lastWithParts), collectedTextParts), usage, nil
}
func pickGeminiCollectResult(last map[string]any, lastWithParts map[string]any) map[string]any {
......@@ -1898,6 +2054,83 @@ func pickGeminiCollectResult(last map[string]any, lastWithParts map[string]any)
return map[string]any{}
}
// mergeCollectedTextParts merges all collected text chunks into the final response.
// This fixes the issue where non-streaming responses only returned the last chunk
// instead of the complete aggregated text.
func mergeCollectedTextParts(response map[string]any, textParts []string) map[string]any {
if len(textParts) == 0 {
return response
}
// Join all text parts
mergedText := strings.Join(textParts, "")
// Deep copy response
result := make(map[string]any)
for k, v := range response {
result[k] = v
}
// Get or create candidates
candidates, ok := result["candidates"].([]any)
if !ok || len(candidates) == 0 {
candidates = []any{map[string]any{}}
}
// Get first candidate
candidate, ok := candidates[0].(map[string]any)
if !ok {
candidate = make(map[string]any)
candidates[0] = candidate
}
// Get or create content
content, ok := candidate["content"].(map[string]any)
if !ok {
content = map[string]any{"role": "model"}
candidate["content"] = content
}
// Get existing parts
existingParts, ok := content["parts"].([]any)
if !ok {
existingParts = []any{}
}
// Find and update first text part, or create new one
newParts := make([]any, 0, len(existingParts)+1)
textUpdated := false
for _, p := range existingParts {
pm, ok := p.(map[string]any)
if !ok {
newParts = append(newParts, p)
continue
}
if _, hasText := pm["text"]; hasText && !textUpdated {
// Replace with merged text
newPart := make(map[string]any)
for k, v := range pm {
newPart[k] = v
}
newPart["text"] = mergedText
newParts = append(newParts, newPart)
textUpdated = true
} else {
newParts = append(newParts, pm)
}
}
if !textUpdated {
newParts = append([]any{map[string]any{"text": mergedText}}, newParts...)
}
content["parts"] = newParts
result["candidates"] = candidates
return result
}
type geminiNativeStreamResult struct {
usage *ClaudeUsage
firstTokenMs *int
......@@ -2816,3 +3049,26 @@ func convertClaudeGenerationConfig(req map[string]any) map[string]any {
}
return out
}
// extractImageSize 从 Gemini 请求中提取 image_size 参数
func (s *GeminiMessagesCompatService) extractImageSize(body []byte) string {
var req struct {
GenerationConfig *struct {
ImageConfig *struct {
ImageSize string `json:"imageSize"`
} `json:"imageConfig"`
} `json:"generationConfig"`
}
if err := json.Unmarshal(body, &req); err != nil {
return "2K"
}
if req.GenerationConfig != nil && req.GenerationConfig.ImageConfig != nil {
size := strings.ToUpper(strings.TrimSpace(req.GenerationConfig.ImageConfig.ImageSize))
if size == "1K" || size == "2K" || size == "4K" {
return size
}
}
return "2K"
}
......@@ -17,6 +17,8 @@ import (
type mockAccountRepoForGemini struct {
accounts []Account
accountsByID map[int64]*Account
listByGroupFunc func(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
listByPlatformFunc func(ctx context.Context, platforms []string) ([]Account, error)
}
func (m *mockAccountRepoForGemini) GetByID(ctx context.Context, id int64) (*Account, error) {
......@@ -88,6 +90,9 @@ func (m *mockAccountRepoForGemini) BatchUpdateLastUsed(ctx context.Context, upda
func (m *mockAccountRepoForGemini) SetError(ctx context.Context, id int64, errorMsg string) error {
return nil
}
func (m *mockAccountRepoForGemini) ClearError(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForGemini) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return nil
}
......@@ -104,6 +109,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupID(ctx context.Context,
return nil, nil
}
func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
if m.listByPlatformFunc != nil {
return m.listByPlatformFunc(ctx, platforms)
}
var result []Account
platformSet := make(map[string]bool)
for _, p := range platforms {
......@@ -117,6 +125,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Contex
return result, nil
}
func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
if m.listByGroupFunc != nil {
return m.listByGroupFunc(ctx, groupID, platforms)
}
return m.ListSchedulableByPlatforms(ctx, platforms)
}
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
......@@ -212,6 +223,7 @@ var _ GroupRepository = (*mockGroupRepoForGemini)(nil)
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
type mockGatewayCacheForGemini struct {
sessionBindings map[string]int64
deletedSessions map[string]int
}
func (m *mockGatewayCacheForGemini) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
......@@ -233,6 +245,18 @@ func (m *mockGatewayCacheForGemini) RefreshSessionTTL(ctx context.Context, group
return nil
}
func (m *mockGatewayCacheForGemini) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
if m.sessionBindings == nil {
return nil
}
if m.deletedSessions == nil {
m.deletedSessions = make(map[string]int)
}
m.deletedSessions[sessionHash]++
delete(m.sessionBindings, sessionHash)
return nil
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) {
ctx := context.Background()
......@@ -523,6 +547,274 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickyS
// 粘性会话未命中,按优先级选择
require.Equal(t, int64(2), acc.ID, "粘性会话未命中,应按优先级选择")
})
t.Run("粘性会话不可调度-清理并回退选择", func(t *testing.T) {
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 2, Status: StatusDisabled, Schedulable: true},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{
sessionBindings: map[string]int64{"gemini:session-123": 1},
}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
require.Equal(t, 1, cache.deletedSessions["gemini:session-123"])
require.Equal(t, int64(2), cache.sessionBindings["gemini:session-123"])
})
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ForcePlatformFallback(t *testing.T) {
ctx := context.Background()
groupID := int64(9)
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, PlatformAntigravity)
repo := &mockAccountRepoForGemini{
listByGroupFunc: func(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
return nil, nil
},
listByPlatformFunc: func(ctx context.Context, platforms []string) ([]Account, error) {
return []Account{
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
}, nil
},
accountsByID: map[int64]*Account{
1: {ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
},
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID)
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_NoModelSupport(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{
ID: 1,
Platform: PlatformGemini,
Priority: 1,
Status: StatusActive,
Schedulable: true,
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-1.0-pro": "gemini-1.0-pro"}},
},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
require.Error(t, err)
require.Nil(t, acc)
require.Contains(t, err.Error(), "supporting model")
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickyMixedScheduling(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}},
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{
sessionBindings: map[string]int64{"gemini:session-999": 1},
}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-999", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(1), acc.ID)
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_SkipDisabledMixedScheduling(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ExcludedAccount(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
excluded := map[int64]struct{}{1: {}}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", excluded)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ListError(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
listByPlatformFunc: func(ctx context.Context, platforms []string) ([]Account, error) {
return nil, errors.New("query failed")
},
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
require.Error(t, err)
require.Nil(t, acc)
require.Contains(t, err.Error(), "query accounts failed")
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_PreferOAuth(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-pro", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
}
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_PreferLeastRecentlyUsed(t *testing.T) {
ctx := context.Background()
oldTime := time.Now().Add(-2 * time.Hour)
newTime := time.Now().Add(-1 * time.Hour)
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: &newTime},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: &oldTime},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-pro", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID)
}
// TestGeminiPlatformRouting_DocumentRouteDecision 测试平台路由决策逻辑
......@@ -599,7 +891,7 @@ func TestGeminiMessagesCompatService_isModelSupportedByAccount(t *testing.T) {
name: "Gemini平台-有映射配置-只支持配置的模型",
account: &Account{
Platform: PlatformGemini,
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-1.5-pro": "x"}},
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-2.5-pro": "x"}},
},
model: "gemini-2.5-flash",
expected: false,
......
package service
import (
"encoding/json"
)
// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中移除 thoughtSignature 字段,
// 以避免跨账号签名验证错误。
//
// 当粘性会话切换账号时(例如原账号异常、不可调度等),旧账号返回的 thoughtSignature
// 会导致新账号的签名验证失败。通过移除这些签名,让新账号重新生成有效的签名。
//
// CleanGeminiNativeThoughtSignatures removes thoughtSignature fields from Gemini native API requests
// to avoid cross-account signature validation errors.
//
// When sticky session switches accounts (e.g., original account becomes unavailable),
// thoughtSignatures from the old account will cause validation failures on the new account.
// By removing these signatures, we allow the new account to generate valid signatures.
func CleanGeminiNativeThoughtSignatures(body []byte) []byte {
if len(body) == 0 {
return body
}
// 解析 JSON
var data any
if err := json.Unmarshal(body, &data); err != nil {
// 如果解析失败,返回原始 body(可能不是 JSON 或格式不正确)
return body
}
// 递归清理 thoughtSignature
cleaned := cleanThoughtSignaturesRecursive(data)
// 重新序列化
result, err := json.Marshal(cleaned)
if err != nil {
// 如果序列化失败,返回原始 body
return body
}
return result
}
// cleanThoughtSignaturesRecursive 递归遍历数据结构,移除所有 thoughtSignature 字段
func cleanThoughtSignaturesRecursive(data any) any {
switch v := data.(type) {
case map[string]any:
// 创建新的 map,移除 thoughtSignature
result := make(map[string]any, len(v))
for key, value := range v {
// 跳过 thoughtSignature 字段
if key == "thoughtSignature" {
continue
}
// 递归处理嵌套结构
result[key] = cleanThoughtSignaturesRecursive(value)
}
return result
case []any:
// 递归处理数组中的每个元素
result := make([]any, len(v))
for i, item := range v {
result[i] = cleanThoughtSignaturesRecursive(item)
}
return result
default:
// 基本类型(string, number, bool, null)直接返回
return v
}
}
......@@ -4,6 +4,7 @@ import (
"context"
"errors"
"log"
"log/slog"
"strconv"
"strings"
"time"
......@@ -131,8 +132,18 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
}
}
// 3) Populate cache with TTL.
// 3) Populate cache with TTL(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
if p.tokenCache != nil {
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
if isStale && latestAccount != nil {
// 版本过时,使用 DB 中的最新 token
slog.Debug("gemini_token_version_stale_use_latest", "account_id", account.ID)
accessToken = latestAccount.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found after version check")
}
// 不写入缓存,让下次请求重新处理
} else {
ttl := 30 * time.Minute
if expiresAt != nil {
until := time.Until(*expiresAt)
......@@ -147,6 +158,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
}
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
}
}
return accessToken, nil
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment