Commit 7331220e authored by Edric Li's avatar Edric Li
Browse files

Merge remote-tracking branch 'upstream/main'

# Conflicts:
#	frontend/src/components/account/CreateAccountModal.vue
parents fb86002e 4f13c8de
......@@ -3,6 +3,7 @@ package service
import (
"encoding/json"
"strconv"
"strings"
"time"
)
......@@ -78,6 +79,36 @@ func (a *Account) IsGemini() bool {
return a.Platform == PlatformGemini
}
func (a *Account) GeminiOAuthType() string {
if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth {
return ""
}
oauthType := strings.TrimSpace(a.GetCredential("oauth_type"))
if oauthType == "" && strings.TrimSpace(a.GetCredential("project_id")) != "" {
return "code_assist"
}
return oauthType
}
func (a *Account) GeminiTierID() string {
tierID := strings.TrimSpace(a.GetCredential("tier_id"))
if tierID == "" {
return ""
}
return strings.ToUpper(tierID)
}
func (a *Account) IsGeminiCodeAssist() bool {
if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth {
return false
}
oauthType := a.GeminiOAuthType()
if oauthType == "" {
return strings.TrimSpace(a.GetCredential("project_id")) != ""
}
return oauthType == "code_assist"
}
func (a *Account) CanGetUsage() bool {
return a.Type == AccountTypeOAuth
}
......@@ -110,6 +141,28 @@ func (a *Account) GetCredential(key string) string {
}
}
// GetCredentialAsTime 解析凭证中的时间戳字段,支持多种格式
// 兼容以下格式:
// - RFC3339 字符串: "2025-01-01T00:00:00Z"
// - Unix 时间戳字符串: "1735689600"
// - Unix 时间戳数字: 1735689600 (float64/int64/json.Number)
func (a *Account) GetCredentialAsTime(key string) *time.Time {
s := a.GetCredential(key)
if s == "" {
return nil
}
// 尝试 RFC3339 格式
if t, err := time.Parse(time.RFC3339, s); err == nil {
return &t
}
// 尝试 Unix 时间戳(纯数字字符串)
if ts, err := strconv.ParseInt(s, 10, 64); err == nil {
t := time.Unix(ts, 0)
return &t
}
return nil
}
func (a *Account) GetModelMapping() map[string]string {
if a.Credentials == nil {
return nil
......@@ -324,19 +377,7 @@ func (a *Account) GetOpenAITokenExpiresAt() *time.Time {
if !a.IsOpenAIOAuth() {
return nil
}
expiresAtStr := a.GetCredential("expires_at")
if expiresAtStr == "" {
return nil
}
t, err := time.Parse(time.RFC3339, expiresAtStr)
if err != nil {
if v, ok := a.Credentials["expires_at"].(float64); ok {
tt := time.Unix(int64(v), 0)
return &tt
}
return nil
}
return &t
return a.GetCredentialAsTime("expires_at")
}
func (a *Account) IsOpenAITokenExpired() bool {
......
......@@ -5,12 +5,13 @@ import (
"fmt"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
var (
ErrAccountNotFound = infraerrors.NotFound("ACCOUNT_NOT_FOUND", "account not found")
ErrAccountNilInput = infraerrors.BadRequest("ACCOUNT_NIL_INPUT", "account input cannot be nil")
)
type AccountRepository interface {
......
......@@ -12,7 +12,6 @@ import (
"log"
"net/http"
"regexp"
"strconv"
"strings"
"time"
......@@ -187,9 +186,8 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
// Check if token needs refresh
needRefresh := false
if expiresAtStr := account.GetCredential("expires_at"); expiresAtStr != "" {
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
if err == nil && time.Now().Unix()+300 > expiresAt {
if expiresAt := account.GetCredentialAsTime("expires_at"); expiresAt != nil {
if time.Now().Add(5 * time.Minute).After(*expiresAt) {
needRefresh = true
}
}
......@@ -263,7 +261,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(req, proxyURL)
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
......@@ -378,7 +376,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(req, proxyURL)
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
......@@ -449,7 +447,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(req, proxyURL)
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
......
......@@ -52,6 +52,9 @@ type UsageLogRepository interface {
// Aggregated stats (optimized)
GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)
}
// apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at)
......@@ -94,6 +97,8 @@ type UsageInfo struct {
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额
GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额
}
// ClaudeUsageResponse Anthropic API返回的usage结构
......@@ -122,14 +127,16 @@ type AccountUsageService struct {
accountRepo AccountRepository
usageLogRepo UsageLogRepository
usageFetcher ClaudeUsageFetcher
geminiQuotaService *GeminiQuotaService
}
// NewAccountUsageService 创建AccountUsageService实例
func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher) *AccountUsageService {
func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher, geminiQuotaService *GeminiQuotaService) *AccountUsageService {
return &AccountUsageService{
accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
usageFetcher: usageFetcher,
geminiQuotaService: geminiQuotaService,
}
}
......@@ -143,6 +150,10 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return nil, fmt.Errorf("get account failed: %w", err)
}
if account.Platform == PlatformGemini {
return s.getGeminiUsage(ctx, account)
}
// 只有oauth类型账号可以通过API获取usage(有profile scope)
if account.CanGetUsage() {
var apiResp *ClaudeUsageResponse
......@@ -189,6 +200,36 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return nil, fmt.Errorf("account type %s does not support usage query", account.Type)
}
func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
now := time.Now()
usage := &UsageInfo{
UpdatedAt: &now,
}
if s.geminiQuotaService == nil || s.usageLogRepo == nil {
return usage, nil
}
quota, ok := s.geminiQuotaService.QuotaForAccount(ctx, account)
if !ok {
return usage, nil
}
start := geminiDailyWindowStart(now)
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
if err != nil {
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
}
totals := geminiAggregateUsage(stats)
resetAt := geminiDailyResetTime(now)
usage.GeminiProDaily = buildGeminiUsageProgress(totals.ProRequests, quota.ProRPD, resetAt, totals.ProTokens, totals.ProCost, now)
usage.GeminiFlashDaily = buildGeminiUsageProgress(totals.FlashRequests, quota.FlashRPD, resetAt, totals.FlashTokens, totals.FlashCost, now)
return usage, nil
}
// addWindowStats 为 usage 数据添加窗口期统计
// 使用独立缓存(1 分钟),与 API 缓存分离
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) {
......@@ -385,3 +426,25 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
// Setup Token无法获取7d数据
return info
}
func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64, cost float64, now time.Time) *UsageProgress {
if limit <= 0 {
return nil
}
utilization := (float64(used) / float64(limit)) * 100
remainingSeconds := int(resetAt.Sub(now).Seconds())
if remainingSeconds < 0 {
remainingSeconds = 0
}
resetCopy := resetAt
return &UsageProgress{
Utilization: utilization,
ResetsAt: &resetCopy,
RemainingSeconds: remainingSeconds,
WindowStats: &WindowStats{
Requests: used,
Tokens: tokens,
Cost: cost,
},
}
}
......@@ -488,6 +488,11 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
subscriptionType = SubscriptionTypeStandard
}
// 限额字段:0 和 nil 都表示"无限制"
dailyLimit := normalizeLimit(input.DailyLimitUSD)
weeklyLimit := normalizeLimit(input.WeeklyLimitUSD)
monthlyLimit := normalizeLimit(input.MonthlyLimitUSD)
group := &Group{
Name: input.Name,
Description: input.Description,
......@@ -496,9 +501,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
IsExclusive: input.IsExclusive,
Status: StatusActive,
SubscriptionType: subscriptionType,
DailyLimitUSD: input.DailyLimitUSD,
WeeklyLimitUSD: input.WeeklyLimitUSD,
MonthlyLimitUSD: input.MonthlyLimitUSD,
DailyLimitUSD: dailyLimit,
WeeklyLimitUSD: weeklyLimit,
MonthlyLimitUSD: monthlyLimit,
}
if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err
......@@ -506,6 +511,14 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
return group, nil
}
// normalizeLimit 将 0 或负数转换为 nil(表示无限制)
func normalizeLimit(limit *float64) *float64 {
if limit == nil || *limit <= 0 {
return nil
}
return limit
}
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
......@@ -535,15 +548,15 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.SubscriptionType != "" {
group.SubscriptionType = input.SubscriptionType
}
// 限额字段支持设置为nil(清除限额)或具体值
// 限额字段:0 和 nil 都表示"无限制",正数表示具体限额
if input.DailyLimitUSD != nil {
group.DailyLimitUSD = input.DailyLimitUSD
group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD)
}
if input.WeeklyLimitUSD != nil {
group.WeeklyLimitUSD = input.WeeklyLimitUSD
group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
}
if input.MonthlyLimitUSD != nil {
group.MonthlyLimitUSD = input.MonthlyLimitUSD
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
}
if err := s.groupRepo.Update(ctx, group); err != nil {
......
......@@ -25,7 +25,7 @@ const (
antigravityRetryMaxDelay = 16 * time.Second
)
// Antigravity 直接支持的模型
// Antigravity 直接支持的模型(精确匹配透传)
var antigravitySupportedModels = map[string]bool{
"claude-opus-4-5-thinking": true,
"claude-sonnet-4-5": true,
......@@ -36,23 +36,26 @@ var antigravitySupportedModels = map[string]bool{
"gemini-3-flash": true,
"gemini-3-pro-low": true,
"gemini-3-pro-high": true,
"gemini-3-pro-preview": true,
"gemini-3-pro-image": true,
}
// Antigravity 系统默认模型映射表(不支持 → 支持)
var antigravityModelMapping = map[string]string{
"claude-3-5-sonnet-20241022": "claude-sonnet-4-5",
"claude-3-5-sonnet-20240620": "claude-sonnet-4-5",
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5-thinking",
"claude-opus-4": "claude-opus-4-5-thinking",
"claude-opus-4-5-20251101": "claude-opus-4-5-thinking",
"claude-haiku-4": "gemini-3-flash",
"claude-haiku-4-5": "gemini-3-flash",
"claude-3-haiku-20240307": "gemini-3-flash",
"claude-haiku-4-5-20251001": "gemini-3-flash",
// 生图模型:官方名 → Antigravity 内部名
"gemini-3-pro-image-preview": "gemini-3-pro-image",
// Antigravity 前缀映射表(按前缀长度降序排列,确保最长匹配优先)
// 用于处理模型版本号变化(如 -20251111, -thinking, -preview 等后缀)
var antigravityPrefixMapping = []struct {
prefix string
target string
}{
// 长前缀优先
{"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等
{"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx
{"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
{"claude-haiku-4-5", "gemini-3-flash"}, // claude-haiku-4-5-xxx
{"claude-opus-4-5", "claude-opus-4-5-thinking"},
{"claude-3-haiku", "gemini-3-flash"}, // 旧版 claude-3-haiku-xxx
{"claude-sonnet-4", "claude-sonnet-4-5"},
{"claude-haiku-4", "gemini-3-flash"},
{"claude-opus-4", "claude-opus-4-5-thinking"},
{"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview 等
}
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
......@@ -84,24 +87,27 @@ func (s *AntigravityGatewayService) GetTokenProvider() *AntigravityTokenProvider
}
// getMappedModel 获取映射后的模型名
// 逻辑:账户映射 → 直接支持透传 → 前缀映射 → gemini透传 → 默认值
func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedModel string) string {
// 1. 优先使用账户级映射(复用现有方法
// 1. 账户级映射(用户自定义优先
if mapped := account.GetMappedModel(requestedModel); mapped != requestedModel {
return mapped
}
// 2. 系统默认映射
if mapped, ok := antigravityModelMapping[requestedModel]; ok {
return mapped
// 2. 直接支持的模型透传
if antigravitySupportedModels[requestedModel] {
return requestedModel
}
// 3. Gemini 模型透传
if strings.HasPrefix(requestedModel, "gemini-") {
return requestedModel
// 3. 前缀映射(处理版本号变化,如 -20251111, -thinking, -preview)
for _, pm := range antigravityPrefixMapping {
if strings.HasPrefix(requestedModel, pm.prefix) {
return pm.target
}
}
// 4. Claude 前缀透传直接支持的模型
if antigravitySupportedModels[requestedModel] {
// 4. Gemini 模型透传(未匹配到前缀的 gemini 模型
if strings.HasPrefix(requestedModel, "gemini-") {
return requestedModel
}
......@@ -110,24 +116,10 @@ func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedMo
}
// IsModelSupported 检查模型是否被支持
// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持
func (s *AntigravityGatewayService) IsModelSupported(requestedModel string) bool {
// 直接支持的模型
if antigravitySupportedModels[requestedModel] {
return true
}
// 可映射的模型
if _, ok := antigravityModelMapping[requestedModel]; ok {
return true
}
// Gemini 前缀透传
if strings.HasPrefix(requestedModel, "gemini-") {
return true
}
// Claude 模型支持(通过默认映射)
if strings.HasPrefix(requestedModel, "claude-") {
return true
}
return false
return strings.HasPrefix(requestedModel, "claude-") ||
strings.HasPrefix(requestedModel, "gemini-")
}
// TestConnectionResult 测试连接结果
......@@ -180,7 +172,7 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
}
// 发送请求
resp, err := s.httpUpstream.Do(req, proxyURL)
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
return nil, fmt.Errorf("请求失败: %w", err)
}
......@@ -358,6 +350,15 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return nil, fmt.Errorf("transform request: %w", err)
}
// 调试:记录转换后的请求体(仅记录前 2000 字符)
if bodyJSON, err := json.Marshal(geminiBody); err == nil {
truncated := string(bodyJSON)
if len(truncated) > 2000 {
truncated = truncated[:2000] + "..."
}
log.Printf("[Debug] Transformed Gemini request: %s", truncated)
}
// 构建上游 action
action := "generateContent"
if claudeReq.Stream {
......@@ -372,7 +373,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return nil, err
}
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
if attempt < antigravityMaxRetries {
log.Printf("Antigravity account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, antigravityMaxRetries, err)
......@@ -515,7 +516,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
return nil, err
}
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
if attempt < antigravityMaxRetries {
log.Printf("Antigravity account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, antigravityMaxRetries, err)
......
......@@ -131,7 +131,7 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
name: "系统映射 - claude-sonnet-4-5-20250929",
requestedModel: "claude-sonnet-4-5-20250929",
accountMapping: nil,
expected: "claude-sonnet-4-5-thinking",
expected: "claude-sonnet-4-5",
},
// 3. Gemini 透传
......
......@@ -191,7 +191,7 @@ func (r *AntigravityQuotaRefresher) refreshAccountQuota(ctx context.Context, acc
// isTokenExpired 检查 token 是否过期
func (r *AntigravityQuotaRefresher) isTokenExpired(account *Account) bool {
expiresAt := parseAntigravityExpiresAt(account)
expiresAt := account.GetCredentialAsTime("expires_at")
if expiresAt == nil {
return false
}
......
......@@ -55,7 +55,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
}
// 2. 如果即将过期则刷新
expiresAt := parseAntigravityExpiresAt(account)
expiresAt := account.GetCredentialAsTime("expires_at")
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew
if needsRefresh && p.tokenCache != nil {
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
......@@ -72,7 +72,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
if err == nil && fresh != nil {
account = fresh
}
expiresAt = parseAntigravityExpiresAt(account)
expiresAt = account.GetCredentialAsTime("expires_at")
if expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew {
if p.antigravityOAuthService == nil {
return "", errors.New("antigravity oauth service not configured")
......@@ -91,7 +91,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr)
}
expiresAt = parseAntigravityExpiresAt(account)
expiresAt = account.GetCredentialAsTime("expires_at")
}
}
}
......@@ -128,18 +128,3 @@ func antigravityTokenCacheKey(account *Account) string {
}
return "ag:account:" + strconv.FormatInt(account.ID, 10)
}
func parseAntigravityExpiresAt(account *Account) *time.Time {
raw := strings.TrimSpace(account.GetCredential("expires_at"))
if raw == "" {
return nil
}
if unixSec, err := strconv.ParseInt(raw, 10, 64); err == nil && unixSec > 0 {
t := time.Unix(unixSec, 0)
return &t
}
if t, err := time.Parse(time.RFC3339, raw); err == nil {
return &t
}
return nil
}
......@@ -2,7 +2,7 @@ package service
import (
"context"
"strconv"
"fmt"
"time"
)
......@@ -29,21 +29,22 @@ func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool {
}
// NeedsRefresh 检查账户是否需要刷新
// Antigravity 使用固定的10分钟刷新窗口,忽略全局配置
// Antigravity 使用固定的15分钟刷新窗口,忽略全局配置
func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, _ time.Duration) bool {
if !r.CanRefresh(account) {
return false
}
expiresAtStr := account.GetCredential("expires_at")
if expiresAtStr == "" {
expiresAt := account.GetCredentialAsTime("expires_at")
if expiresAt == nil {
return false
}
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
if err != nil {
return false
timeUntilExpiry := time.Until(*expiresAt)
needsRefresh := timeUntilExpiry < antigravityRefreshWindow
if needsRefresh {
fmt.Printf("[AntigravityTokenRefresher] Account %d needs refresh: expires_at=%s, time_until_expiry=%v, window=%v\n",
account.ID, expiresAt.Format("2006-01-02 15:04:05"), timeUntilExpiry, antigravityRefreshWindow)
}
expiryTime := time.Unix(expiresAt, 0)
return time.Until(expiryTime) < antigravityRefreshWindow
return needsRefresh
}
// Refresh 执行 token 刷新
......
......@@ -8,7 +8,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
)
......
......@@ -8,7 +8,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt"
......
......@@ -4,10 +4,12 @@ import (
"context"
"fmt"
"log"
"sync"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// 错误定义
......@@ -27,6 +29,46 @@ type subscriptionCacheData struct {
Version int64
}
// 缓存写入任务类型
type cacheWriteKind int
const (
cacheWriteSetBalance cacheWriteKind = iota
cacheWriteSetSubscription
cacheWriteUpdateSubscriptionUsage
cacheWriteDeductBalance
)
// 异步缓存写入工作池配置
//
// 性能优化说明:
// 原实现在请求热路径中使用 goroutine 异步更新缓存,存在以下问题:
// 1. 每次请求创建新 goroutine,高并发下产生大量短生命周期 goroutine
// 2. 无法控制并发数量,可能导致 Redis 连接耗尽
// 3. goroutine 创建/销毁带来额外开销
//
// 新实现使用固定大小的工作池:
// 1. 预创建 10 个 worker goroutine,避免频繁创建销毁
// 2. 使用带缓冲的 channel(1000)作为任务队列,平滑写入峰值
// 3. 非阻塞写入,队列满时关键任务同步回退,非关键任务丢弃并告警
// 4. 统一超时控制,避免慢操作阻塞工作池
const (
cacheWriteWorkerCount = 10 // 工作协程数量
cacheWriteBufferSize = 1000 // 任务队列缓冲大小
cacheWriteTimeout = 2 * time.Second // 单个写入操作超时
cacheWriteDropLogInterval = 5 * time.Second // 丢弃日志节流间隔
)
// cacheWriteTask 缓存写入任务
type cacheWriteTask struct {
kind cacheWriteKind
userID int64
groupID int64
balance float64
amount float64
subscriptionData *subscriptionCacheData
}
// BillingCacheService 计费缓存服务
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
type BillingCacheService struct {
......@@ -34,16 +76,151 @@ type BillingCacheService struct {
userRepo UserRepository
subRepo UserSubscriptionRepository
cfg *config.Config
cacheWriteChan chan cacheWriteTask
cacheWriteWg sync.WaitGroup
cacheWriteStopOnce sync.Once
// 丢弃日志节流计数器(减少高负载下日志噪音)
cacheWriteDropFullCount uint64
cacheWriteDropFullLastLog int64
cacheWriteDropClosedCount uint64
cacheWriteDropClosedLastLog int64
}
// NewBillingCacheService 创建计费缓存服务
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, cfg *config.Config) *BillingCacheService {
return &BillingCacheService{
svc := &BillingCacheService{
cache: cache,
userRepo: userRepo,
subRepo: subRepo,
cfg: cfg,
}
svc.startCacheWriteWorkers()
return svc
}
// Stop 关闭缓存写入工作池
func (s *BillingCacheService) Stop() {
s.cacheWriteStopOnce.Do(func() {
if s.cacheWriteChan == nil {
return
}
close(s.cacheWriteChan)
s.cacheWriteWg.Wait()
s.cacheWriteChan = nil
})
}
func (s *BillingCacheService) startCacheWriteWorkers() {
s.cacheWriteChan = make(chan cacheWriteTask, cacheWriteBufferSize)
for i := 0; i < cacheWriteWorkerCount; i++ {
s.cacheWriteWg.Add(1)
go s.cacheWriteWorker()
}
}
// enqueueCacheWrite 尝试将任务入队,队列满时返回 false(并记录告警)。
func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued bool) {
if s.cacheWriteChan == nil {
return false
}
defer func() {
if recovered := recover(); recovered != nil {
// 队列已关闭时可能触发 panic,记录后静默失败。
s.logCacheWriteDrop(task, "closed")
enqueued = false
}
}()
select {
case s.cacheWriteChan <- task:
return true
default:
// 队列满时不阻塞主流程,交由调用方决定是否同步回退。
s.logCacheWriteDrop(task, "full")
return false
}
}
func (s *BillingCacheService) cacheWriteWorker() {
defer s.cacheWriteWg.Done()
for task := range s.cacheWriteChan {
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
switch task.kind {
case cacheWriteSetBalance:
s.setBalanceCache(ctx, task.userID, task.balance)
case cacheWriteSetSubscription:
s.setSubscriptionCache(ctx, task.userID, task.groupID, task.subscriptionData)
case cacheWriteUpdateSubscriptionUsage:
if s.cache != nil {
if err := s.cache.UpdateSubscriptionUsage(ctx, task.userID, task.groupID, task.amount); err != nil {
log.Printf("Warning: update subscription cache failed for user %d group %d: %v", task.userID, task.groupID, err)
}
}
case cacheWriteDeductBalance:
if s.cache != nil {
if err := s.cache.DeductUserBalance(ctx, task.userID, task.amount); err != nil {
log.Printf("Warning: deduct balance cache failed for user %d: %v", task.userID, err)
}
}
}
cancel()
}
}
// cacheWriteKindName 用于日志中的任务类型标识,便于排查丢弃原因。
func cacheWriteKindName(kind cacheWriteKind) string {
switch kind {
case cacheWriteSetBalance:
return "set_balance"
case cacheWriteSetSubscription:
return "set_subscription"
case cacheWriteUpdateSubscriptionUsage:
return "update_subscription_usage"
case cacheWriteDeductBalance:
return "deduct_balance"
default:
return "unknown"
}
}
// logCacheWriteDrop 使用节流方式记录丢弃情况,并汇总丢弃数量。
func (s *BillingCacheService) logCacheWriteDrop(task cacheWriteTask, reason string) {
var (
countPtr *uint64
lastPtr *int64
)
switch reason {
case "full":
countPtr = &s.cacheWriteDropFullCount
lastPtr = &s.cacheWriteDropFullLastLog
case "closed":
countPtr = &s.cacheWriteDropClosedCount
lastPtr = &s.cacheWriteDropClosedLastLog
default:
return
}
atomic.AddUint64(countPtr, 1)
now := time.Now().UnixNano()
last := atomic.LoadInt64(lastPtr)
if now-last < int64(cacheWriteDropLogInterval) {
return
}
if !atomic.CompareAndSwapInt64(lastPtr, last, now) {
return
}
dropped := atomic.SwapUint64(countPtr, 0)
if dropped == 0 {
return
}
log.Printf("Warning: cache write queue %s, dropped %d tasks in last %s (latest kind=%s user %d group %d)",
reason,
dropped,
cacheWriteDropLogInterval,
cacheWriteKindName(task.kind),
task.userID,
task.groupID,
)
}
// ============================================
......@@ -70,11 +247,11 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64)
}
// 异步建立缓存
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
s.setBalanceCache(cacheCtx, userID, balance)
}()
_ = s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteSetBalance,
userID: userID,
balance: balance,
})
return balance, nil
}
......@@ -98,7 +275,7 @@ func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64,
}
}
// DeductBalanceCache 扣减余额缓存(步调用,用于扣费后更新缓存
// DeductBalanceCache 扣减余额缓存(步调用)
func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int64, amount float64) error {
if s.cache == nil {
return nil
......@@ -106,6 +283,26 @@ func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int
return s.cache.DeductUserBalance(ctx, userID, amount)
}
// QueueDeductBalance 异步扣减余额缓存
func (s *BillingCacheService) QueueDeductBalance(userID int64, amount float64) {
if s.cache == nil {
return
}
// 队列满时同步回退,避免关键扣减被静默丢弃。
if s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteDeductBalance,
userID: userID,
amount: amount,
}) {
return
}
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
defer cancel()
if err := s.DeductBalanceCache(ctx, userID, amount); err != nil {
log.Printf("Warning: deduct balance cache fallback failed for user %d: %v", userID, err)
}
}
// InvalidateUserBalance 失效用户余额缓存
func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID int64) error {
if s.cache == nil {
......@@ -141,11 +338,12 @@ func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID,
}
// 异步建立缓存
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
s.setSubscriptionCache(cacheCtx, userID, groupID, data)
}()
_ = s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteSetSubscription,
userID: userID,
groupID: groupID,
subscriptionData: data,
})
return data, nil
}
......@@ -199,7 +397,7 @@ func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID,
}
}
// UpdateSubscriptionUsage 更新订阅用量缓存(步调用,用于扣费后更新缓存
// UpdateSubscriptionUsage 更新订阅用量缓存(步调用)
func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, costUSD float64) error {
if s.cache == nil {
return nil
......@@ -207,6 +405,27 @@ func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userI
return s.cache.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD)
}
// QueueUpdateSubscriptionUsage 异步更新订阅用量缓存
func (s *BillingCacheService) QueueUpdateSubscriptionUsage(userID, groupID int64, costUSD float64) {
if s.cache == nil {
return
}
// 队列满时同步回退,确保订阅用量及时更新。
if s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteUpdateSubscriptionUsage,
userID: userID,
groupID: groupID,
amount: costUSD,
}) {
return
}
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
defer cancel()
if err := s.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD); err != nil {
log.Printf("Warning: update subscription cache fallback failed for user %d group %d: %v", userID, groupID, err)
}
}
// InvalidateSubscription 失效指定订阅缓存
func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID, groupID int64) error {
if s.cache == nil {
......
package service
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type billingCacheWorkerStub struct {
balanceUpdates int64
subscriptionUpdates int64
}
func (b *billingCacheWorkerStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
return 0, errors.New("not implemented")
}
func (b *billingCacheWorkerStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
atomic.AddInt64(&b.balanceUpdates, 1)
return nil
}
func (b *billingCacheWorkerStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
atomic.AddInt64(&b.balanceUpdates, 1)
return nil
}
func (b *billingCacheWorkerStub) InvalidateUserBalance(ctx context.Context, userID int64) error {
return nil
}
func (b *billingCacheWorkerStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) {
return nil, errors.New("not implemented")
}
func (b *billingCacheWorkerStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error {
atomic.AddInt64(&b.subscriptionUpdates, 1)
return nil
}
func (b *billingCacheWorkerStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
atomic.AddInt64(&b.subscriptionUpdates, 1)
return nil
}
func (b *billingCacheWorkerStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
return nil
}
func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
cache := &billingCacheWorkerStub{}
svc := NewBillingCacheService(cache, nil, nil, &config.Config{})
t.Cleanup(svc.Stop)
start := time.Now()
for i := 0; i < cacheWriteBufferSize*2; i++ {
svc.QueueDeductBalance(1, 1)
}
require.Less(t, time.Since(start), 2*time.Second)
svc.QueueUpdateSubscriptionUsage(1, 2, 1.5)
require.Eventually(t, func() bool {
return atomic.LoadInt64(&cache.balanceUpdates) > 0
}, 2*time.Second, 10*time.Millisecond)
require.Eventually(t, func() bool {
return atomic.LoadInt64(&cache.subscriptionUpdates) > 0
}, 2*time.Second, 10*time.Millisecond)
}
......@@ -9,24 +9,35 @@ import (
"time"
)
// ConcurrencyCache defines cache operations for concurrency service
// Uses independent keys per request slot with native Redis TTL for automatic cleanup
// ConcurrencyCache 定义并发控制的缓存接口
// 使用有序集合存储槽位,按时间戳清理过期条目
type ConcurrencyCache interface {
// Account slot management - each slot is a separate key with independent TTL
// Key format: concurrency:account:{accountID}:{requestID}
// 账号槽位管理
// 键格式: concurrency:account:{accountID}(有序集合,成员为 requestID
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
// User slot management - each slot is a separate key with independent TTL
// Key format: concurrency:user:{userID}:{requestID}
// 账号等待队列(账号级)
IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error)
DecrementAccountWaitCount(ctx context.Context, accountID int64) error
GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error)
// 用户槽位管理
// 键格式: concurrency:user:{userID}(有序集合,成员为 requestID)
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error
GetUserConcurrency(ctx context.Context, userID int64) (int, error)
// Wait queue - uses counter with TTL set only on creation
// 等待队列计数(只在首次创建时设置 TTL)
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
DecrementWaitCount(ctx context.Context, userID int64) error
// 批量负载查询(只读)
GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error)
// 清理过期槽位(后台任务)
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
}
// generateRequestID generates a unique request ID for concurrency slot tracking
......@@ -61,6 +72,18 @@ type AcquireResult struct {
ReleaseFunc func() // Must be called when done (typically via defer)
}
type AccountWithConcurrency struct {
ID int64
MaxConcurrency int
}
type AccountLoadInfo struct {
AccountID int64
CurrentConcurrency int
WaitingCount int
LoadRate int // 0-100+ (percent)
}
// AcquireAccountSlot attempts to acquire a concurrency slot for an account.
// If the account is at max concurrency, it waits until a slot is available or timeout.
// Returns a release function that MUST be called when the request completes.
......@@ -177,6 +200,42 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6
}
}
// IncrementAccountWaitCount increments the wait queue counter for an account.
func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
if s.cache == nil {
return true, nil
}
result, err := s.cache.IncrementAccountWaitCount(ctx, accountID, maxWait)
if err != nil {
log.Printf("Warning: increment wait count failed for account %d: %v", accountID, err)
return true, nil
}
return result, nil
}
// DecrementAccountWaitCount decrements the wait queue counter for an account.
func (s *ConcurrencyService) DecrementAccountWaitCount(ctx context.Context, accountID int64) {
if s.cache == nil {
return
}
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.cache.DecrementAccountWaitCount(bgCtx, accountID); err != nil {
log.Printf("Warning: decrement wait count failed for account %d: %v", accountID, err)
}
}
// GetAccountWaitingCount gets current wait queue count for an account.
func (s *ConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
if s.cache == nil {
return 0, nil
}
return s.cache.GetAccountWaitingCount(ctx, accountID)
}
// CalculateMaxWait calculates the maximum wait queue size for a user
// maxWait = userConcurrency + defaultExtraWaitSlots
func CalculateMaxWait(userConcurrency int) int {
......@@ -186,6 +245,57 @@ func CalculateMaxWait(userConcurrency int) int {
return userConcurrency + defaultExtraWaitSlots
}
// GetAccountsLoadBatch returns load info for multiple accounts.
func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
if s.cache == nil {
return map[int64]*AccountLoadInfo{}, nil
}
return s.cache.GetAccountsLoadBatch(ctx, accounts)
}
// CleanupExpiredAccountSlots removes expired slots for one account (background task).
func (s *ConcurrencyService) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
if s.cache == nil {
return nil
}
return s.cache.CleanupExpiredAccountSlots(ctx, accountID)
}
// StartSlotCleanupWorker starts a background cleanup worker for expired account slots.
func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepository, interval time.Duration) {
if s == nil || s.cache == nil || accountRepo == nil || interval <= 0 {
return
}
runCleanup := func() {
listCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
accounts, err := accountRepo.ListSchedulable(listCtx)
cancel()
if err != nil {
log.Printf("Warning: list schedulable accounts failed: %v", err)
return
}
for _, account := range accounts {
accountCtx, accountCancel := context.WithTimeout(context.Background(), 2*time.Second)
err := s.cache.CleanupExpiredAccountSlots(accountCtx, account.ID)
accountCancel()
if err != nil {
log.Printf("Warning: cleanup expired slots failed for account %d: %v", account.ID, err)
}
}
}
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
runCleanup()
for range ticker.C {
runCleanup()
}
}()
}
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
// Returns a map of accountID -> current concurrency count
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
......
......@@ -12,6 +12,8 @@ import (
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
)
type CRSSyncService struct {
......@@ -193,7 +195,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
return nil, errors.New("username and password are required")
}
client := &http.Client{Timeout: 20 * time.Second}
client, err := httpclient.GetClient(httpclient.Options{
Timeout: 20 * time.Second,
})
if err != nil {
client = &http.Client{Timeout: 20 * time.Second}
}
adminToken, err := crsLogin(ctx, client, baseURL, input.Username, input.Password)
if err != nil {
......
......@@ -91,6 +91,9 @@ const (
// 管理员 API Key
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
// Gemini 配额策略(JSON)
SettingKeyGeminiQuotaPolicy = "gemini_quota_policy"
)
// Admin API Key prefix (distinct from user "sk-" keys)
......
......@@ -10,7 +10,7 @@ import (
"strconv"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
var (
......
......@@ -261,6 +261,34 @@ func TestGatewayService_SelectAccountForModelWithPlatform_PriorityAndLastUsed(t
require.Equal(t, int64(2), acc.ID, "同优先级应选择最久未用的账户")
}
func TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "同优先级且未使用时应优先选择OAuth账户")
}
// TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts 测试无可用账户
func TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts(t *testing.T) {
ctx := context.Background()
......@@ -576,6 +604,32 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
ctx := context.Background()
t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "同优先级且未使用时应优先选择OAuth账户")
})
t.Run("混合调度-包含启用mixed_scheduling的antigravity账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
......@@ -783,3 +837,160 @@ func TestAccount_IsMixedSchedulingEnabled(t *testing.T) {
})
}
}
// mockConcurrencyService for testing
type mockConcurrencyService struct {
accountLoads map[int64]*AccountLoadInfo
accountWaitCounts map[int64]int
acquireResults map[int64]bool
}
func (m *mockConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
if m.accountLoads == nil {
return map[int64]*AccountLoadInfo{}, nil
}
result := make(map[int64]*AccountLoadInfo)
for _, acc := range accounts {
if load, ok := m.accountLoads[acc.ID]; ok {
result[acc.ID] = load
} else {
result[acc.ID] = &AccountLoadInfo{
AccountID: acc.ID,
CurrentConcurrency: 0,
WaitingCount: 0,
LoadRate: 0,
}
}
}
return result, nil
}
func (m *mockConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
if m.accountWaitCounts == nil {
return 0, nil
}
return m.accountWaitCounts[accountID], nil
}
// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection
func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
ctx := context.Background()
t.Run("禁用负载批量查询-降级到传统选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil, // No concurrency service
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(1), result.Account.ID, "应选择优先级最高的账号")
})
t.Run("无ConcurrencyService-降级到传统选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "应选择优先级最高的账号")
})
t.Run("排除账号-不选择被排除的账号", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
excludedIDs := map[int64]struct{}{1: {}}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "不应选择被排除的账号")
})
t.Run("无可用账号-返回错误", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{},
accountsByID: map[int64]*Account{},
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
require.Error(t, err)
require.Nil(t, result)
require.Contains(t, err.Error(), "no available accounts")
})
}
package service
import (
"encoding/json"
"fmt"
)
// ParsedRequest 保存网关请求的预解析结果
//
// 性能优化说明:
// 原实现在多个位置重复解析请求体(Handler、Service 各解析一次):
// 1. gateway_handler.go 解析获取 model 和 stream
// 2. gateway_service.go 再次解析获取 system、messages、metadata
// 3. GenerateSessionHash 又一次解析获取会话哈希所需字段
//
// 新实现一次解析,多处复用:
// 1. 在 Handler 层统一调用 ParseGatewayRequest 一次性解析
// 2. 将解析结果 ParsedRequest 传递给 Service 层
// 3. 避免重复 json.Unmarshal,减少 CPU 和内存开销
type ParsedRequest struct {
Body []byte // 原始请求体(保留用于转发)
Model string // 请求的模型名称
Stream bool // 是否为流式请求
MetadataUserID string // metadata.user_id(用于会话亲和)
System any // system 字段内容
Messages []any // messages 数组
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
}
// ParseGatewayRequest 解析网关请求体并返回结构化结果
// 性能优化:一次解析提取所有需要的字段,避免重复 Unmarshal
func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return nil, err
}
parsed := &ParsedRequest{
Body: body,
}
if rawModel, exists := req["model"]; exists {
model, ok := rawModel.(string)
if !ok {
return nil, fmt.Errorf("invalid model field type")
}
parsed.Model = model
}
if rawStream, exists := req["stream"]; exists {
stream, ok := rawStream.(bool)
if !ok {
return nil, fmt.Errorf("invalid stream field type")
}
parsed.Stream = stream
}
if metadata, ok := req["metadata"].(map[string]any); ok {
if userID, ok := metadata["user_id"].(string); ok {
parsed.MetadataUserID = userID
}
}
// system 字段只要存在就视为显式提供(即使为 null),
// 以避免客户端传 null 时被默认 system 误注入。
if system, ok := req["system"]; ok {
parsed.HasSystem = true
parsed.System = system
}
if messages, ok := req["messages"].([]any); ok {
parsed.Messages = messages
}
return parsed, nil
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment