Commit 7331220e authored by Edric Li's avatar Edric Li
Browse files

Merge remote-tracking branch 'upstream/main'

# Conflicts:
#	frontend/src/components/account/CreateAccountModal.vue
parents fb86002e 4f13c8de
...@@ -3,6 +3,7 @@ package service ...@@ -3,6 +3,7 @@ package service
import ( import (
"encoding/json" "encoding/json"
"strconv" "strconv"
"strings"
"time" "time"
) )
...@@ -78,6 +79,36 @@ func (a *Account) IsGemini() bool { ...@@ -78,6 +79,36 @@ func (a *Account) IsGemini() bool {
return a.Platform == PlatformGemini return a.Platform == PlatformGemini
} }
func (a *Account) GeminiOAuthType() string {
if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth {
return ""
}
oauthType := strings.TrimSpace(a.GetCredential("oauth_type"))
if oauthType == "" && strings.TrimSpace(a.GetCredential("project_id")) != "" {
return "code_assist"
}
return oauthType
}
func (a *Account) GeminiTierID() string {
tierID := strings.TrimSpace(a.GetCredential("tier_id"))
if tierID == "" {
return ""
}
return strings.ToUpper(tierID)
}
func (a *Account) IsGeminiCodeAssist() bool {
if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth {
return false
}
oauthType := a.GeminiOAuthType()
if oauthType == "" {
return strings.TrimSpace(a.GetCredential("project_id")) != ""
}
return oauthType == "code_assist"
}
func (a *Account) CanGetUsage() bool { func (a *Account) CanGetUsage() bool {
return a.Type == AccountTypeOAuth return a.Type == AccountTypeOAuth
} }
...@@ -110,6 +141,28 @@ func (a *Account) GetCredential(key string) string { ...@@ -110,6 +141,28 @@ func (a *Account) GetCredential(key string) string {
} }
} }
// GetCredentialAsTime 解析凭证中的时间戳字段,支持多种格式
// 兼容以下格式:
// - RFC3339 字符串: "2025-01-01T00:00:00Z"
// - Unix 时间戳字符串: "1735689600"
// - Unix 时间戳数字: 1735689600 (float64/int64/json.Number)
func (a *Account) GetCredentialAsTime(key string) *time.Time {
s := a.GetCredential(key)
if s == "" {
return nil
}
// 尝试 RFC3339 格式
if t, err := time.Parse(time.RFC3339, s); err == nil {
return &t
}
// 尝试 Unix 时间戳(纯数字字符串)
if ts, err := strconv.ParseInt(s, 10, 64); err == nil {
t := time.Unix(ts, 0)
return &t
}
return nil
}
func (a *Account) GetModelMapping() map[string]string { func (a *Account) GetModelMapping() map[string]string {
if a.Credentials == nil { if a.Credentials == nil {
return nil return nil
...@@ -324,19 +377,7 @@ func (a *Account) GetOpenAITokenExpiresAt() *time.Time { ...@@ -324,19 +377,7 @@ func (a *Account) GetOpenAITokenExpiresAt() *time.Time {
if !a.IsOpenAIOAuth() { if !a.IsOpenAIOAuth() {
return nil return nil
} }
expiresAtStr := a.GetCredential("expires_at") return a.GetCredentialAsTime("expires_at")
if expiresAtStr == "" {
return nil
}
t, err := time.Parse(time.RFC3339, expiresAtStr)
if err != nil {
if v, ok := a.Credentials["expires_at"].(float64); ok {
tt := time.Unix(int64(v), 0)
return &tt
}
return nil
}
return &t
} }
func (a *Account) IsOpenAITokenExpired() bool { func (a *Account) IsOpenAITokenExpired() bool {
......
...@@ -5,12 +5,13 @@ import ( ...@@ -5,12 +5,13 @@ import (
"fmt" "fmt"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
) )
var ( var (
ErrAccountNotFound = infraerrors.NotFound("ACCOUNT_NOT_FOUND", "account not found") ErrAccountNotFound = infraerrors.NotFound("ACCOUNT_NOT_FOUND", "account not found")
ErrAccountNilInput = infraerrors.BadRequest("ACCOUNT_NIL_INPUT", "account input cannot be nil")
) )
type AccountRepository interface { type AccountRepository interface {
......
...@@ -12,7 +12,6 @@ import ( ...@@ -12,7 +12,6 @@ import (
"log" "log"
"net/http" "net/http"
"regexp" "regexp"
"strconv"
"strings" "strings"
"time" "time"
...@@ -187,9 +186,8 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account ...@@ -187,9 +186,8 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
// Check if token needs refresh // Check if token needs refresh
needRefresh := false needRefresh := false
if expiresAtStr := account.GetCredential("expires_at"); expiresAtStr != "" { if expiresAt := account.GetCredentialAsTime("expires_at"); expiresAt != nil {
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64) if time.Now().Add(5 * time.Minute).After(*expiresAt) {
if err == nil && time.Now().Unix()+300 > expiresAt {
needRefresh = true needRefresh = true
} }
} }
...@@ -263,7 +261,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account ...@@ -263,7 +261,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
proxyURL = account.Proxy.URL() proxyURL = account.Proxy.URL()
} }
resp, err := s.httpUpstream.Do(req, proxyURL) resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil { if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
} }
...@@ -378,7 +376,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account ...@@ -378,7 +376,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
proxyURL = account.Proxy.URL() proxyURL = account.Proxy.URL()
} }
resp, err := s.httpUpstream.Do(req, proxyURL) resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil { if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
} }
...@@ -449,7 +447,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account ...@@ -449,7 +447,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
proxyURL = account.Proxy.URL() proxyURL = account.Proxy.URL()
} }
resp, err := s.httpUpstream.Do(req, proxyURL) resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil { if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
} }
......
...@@ -52,6 +52,9 @@ type UsageLogRepository interface { ...@@ -52,6 +52,9 @@ type UsageLogRepository interface {
// Aggregated stats (optimized) // Aggregated stats (optimized)
GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)
} }
// apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at) // apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at)
...@@ -90,10 +93,12 @@ type UsageProgress struct { ...@@ -90,10 +93,12 @@ type UsageProgress struct {
// UsageInfo 账号使用量信息 // UsageInfo 账号使用量信息
type UsageInfo struct { type UsageInfo struct {
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间 UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口 FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口 SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口 SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额
GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额
} }
// ClaudeUsageResponse Anthropic API返回的usage结构 // ClaudeUsageResponse Anthropic API返回的usage结构
...@@ -119,17 +124,19 @@ type ClaudeUsageFetcher interface { ...@@ -119,17 +124,19 @@ type ClaudeUsageFetcher interface {
// AccountUsageService 账号使用量查询服务 // AccountUsageService 账号使用量查询服务
type AccountUsageService struct { type AccountUsageService struct {
accountRepo AccountRepository accountRepo AccountRepository
usageLogRepo UsageLogRepository usageLogRepo UsageLogRepository
usageFetcher ClaudeUsageFetcher usageFetcher ClaudeUsageFetcher
geminiQuotaService *GeminiQuotaService
} }
// NewAccountUsageService 创建AccountUsageService实例 // NewAccountUsageService 创建AccountUsageService实例
func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher) *AccountUsageService { func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher, geminiQuotaService *GeminiQuotaService) *AccountUsageService {
return &AccountUsageService{ return &AccountUsageService{
accountRepo: accountRepo, accountRepo: accountRepo,
usageLogRepo: usageLogRepo, usageLogRepo: usageLogRepo,
usageFetcher: usageFetcher, usageFetcher: usageFetcher,
geminiQuotaService: geminiQuotaService,
} }
} }
...@@ -143,6 +150,10 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U ...@@ -143,6 +150,10 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return nil, fmt.Errorf("get account failed: %w", err) return nil, fmt.Errorf("get account failed: %w", err)
} }
if account.Platform == PlatformGemini {
return s.getGeminiUsage(ctx, account)
}
// 只有oauth类型账号可以通过API获取usage(有profile scope) // 只有oauth类型账号可以通过API获取usage(有profile scope)
if account.CanGetUsage() { if account.CanGetUsage() {
var apiResp *ClaudeUsageResponse var apiResp *ClaudeUsageResponse
...@@ -189,6 +200,36 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U ...@@ -189,6 +200,36 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return nil, fmt.Errorf("account type %s does not support usage query", account.Type) return nil, fmt.Errorf("account type %s does not support usage query", account.Type)
} }
func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
now := time.Now()
usage := &UsageInfo{
UpdatedAt: &now,
}
if s.geminiQuotaService == nil || s.usageLogRepo == nil {
return usage, nil
}
quota, ok := s.geminiQuotaService.QuotaForAccount(ctx, account)
if !ok {
return usage, nil
}
start := geminiDailyWindowStart(now)
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
if err != nil {
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
}
totals := geminiAggregateUsage(stats)
resetAt := geminiDailyResetTime(now)
usage.GeminiProDaily = buildGeminiUsageProgress(totals.ProRequests, quota.ProRPD, resetAt, totals.ProTokens, totals.ProCost, now)
usage.GeminiFlashDaily = buildGeminiUsageProgress(totals.FlashRequests, quota.FlashRPD, resetAt, totals.FlashTokens, totals.FlashCost, now)
return usage, nil
}
// addWindowStats 为 usage 数据添加窗口期统计 // addWindowStats 为 usage 数据添加窗口期统计
// 使用独立缓存(1 分钟),与 API 缓存分离 // 使用独立缓存(1 分钟),与 API 缓存分离
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) { func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) {
...@@ -385,3 +426,25 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn ...@@ -385,3 +426,25 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
// Setup Token无法获取7d数据 // Setup Token无法获取7d数据
return info return info
} }
func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64, cost float64, now time.Time) *UsageProgress {
if limit <= 0 {
return nil
}
utilization := (float64(used) / float64(limit)) * 100
remainingSeconds := int(resetAt.Sub(now).Seconds())
if remainingSeconds < 0 {
remainingSeconds = 0
}
resetCopy := resetAt
return &UsageProgress{
Utilization: utilization,
ResetsAt: &resetCopy,
RemainingSeconds: remainingSeconds,
WindowStats: &WindowStats{
Requests: used,
Tokens: tokens,
Cost: cost,
},
}
}
...@@ -488,6 +488,11 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ...@@ -488,6 +488,11 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
subscriptionType = SubscriptionTypeStandard subscriptionType = SubscriptionTypeStandard
} }
// 限额字段:0 和 nil 都表示"无限制"
dailyLimit := normalizeLimit(input.DailyLimitUSD)
weeklyLimit := normalizeLimit(input.WeeklyLimitUSD)
monthlyLimit := normalizeLimit(input.MonthlyLimitUSD)
group := &Group{ group := &Group{
Name: input.Name, Name: input.Name,
Description: input.Description, Description: input.Description,
...@@ -496,9 +501,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ...@@ -496,9 +501,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
IsExclusive: input.IsExclusive, IsExclusive: input.IsExclusive,
Status: StatusActive, Status: StatusActive,
SubscriptionType: subscriptionType, SubscriptionType: subscriptionType,
DailyLimitUSD: input.DailyLimitUSD, DailyLimitUSD: dailyLimit,
WeeklyLimitUSD: input.WeeklyLimitUSD, WeeklyLimitUSD: weeklyLimit,
MonthlyLimitUSD: input.MonthlyLimitUSD, MonthlyLimitUSD: monthlyLimit,
} }
if err := s.groupRepo.Create(ctx, group); err != nil { if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err return nil, err
...@@ -506,6 +511,14 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ...@@ -506,6 +511,14 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
return group, nil return group, nil
} }
// normalizeLimit 将 0 或负数转换为 nil(表示无限制)
func normalizeLimit(limit *float64) *float64 {
if limit == nil || *limit <= 0 {
return nil
}
return limit
}
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) { func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
group, err := s.groupRepo.GetByID(ctx, id) group, err := s.groupRepo.GetByID(ctx, id)
if err != nil { if err != nil {
...@@ -535,15 +548,15 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd ...@@ -535,15 +548,15 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.SubscriptionType != "" { if input.SubscriptionType != "" {
group.SubscriptionType = input.SubscriptionType group.SubscriptionType = input.SubscriptionType
} }
// 限额字段支持设置为nil(清除限额)或具体值 // 限额字段:0 和 nil 都表示"无限制",正数表示具体限额
if input.DailyLimitUSD != nil { if input.DailyLimitUSD != nil {
group.DailyLimitUSD = input.DailyLimitUSD group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD)
} }
if input.WeeklyLimitUSD != nil { if input.WeeklyLimitUSD != nil {
group.WeeklyLimitUSD = input.WeeklyLimitUSD group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
} }
if input.MonthlyLimitUSD != nil { if input.MonthlyLimitUSD != nil {
group.MonthlyLimitUSD = input.MonthlyLimitUSD group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
} }
if err := s.groupRepo.Update(ctx, group); err != nil { if err := s.groupRepo.Update(ctx, group); err != nil {
......
...@@ -25,7 +25,7 @@ const ( ...@@ -25,7 +25,7 @@ const (
antigravityRetryMaxDelay = 16 * time.Second antigravityRetryMaxDelay = 16 * time.Second
) )
// Antigravity 直接支持的模型 // Antigravity 直接支持的模型(精确匹配透传)
var antigravitySupportedModels = map[string]bool{ var antigravitySupportedModels = map[string]bool{
"claude-opus-4-5-thinking": true, "claude-opus-4-5-thinking": true,
"claude-sonnet-4-5": true, "claude-sonnet-4-5": true,
...@@ -36,23 +36,26 @@ var antigravitySupportedModels = map[string]bool{ ...@@ -36,23 +36,26 @@ var antigravitySupportedModels = map[string]bool{
"gemini-3-flash": true, "gemini-3-flash": true,
"gemini-3-pro-low": true, "gemini-3-pro-low": true,
"gemini-3-pro-high": true, "gemini-3-pro-high": true,
"gemini-3-pro-preview": true,
"gemini-3-pro-image": true, "gemini-3-pro-image": true,
} }
// Antigravity 系统默认模型映射表(不支持 → 支持) // Antigravity 前缀映射表(按前缀长度降序排列,确保最长匹配优先)
var antigravityModelMapping = map[string]string{ // 用于处理模型版本号变化(如 -20251111, -thinking, -preview 等后缀)
"claude-3-5-sonnet-20241022": "claude-sonnet-4-5", var antigravityPrefixMapping = []struct {
"claude-3-5-sonnet-20240620": "claude-sonnet-4-5", prefix string
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5-thinking", target string
"claude-opus-4": "claude-opus-4-5-thinking", }{
"claude-opus-4-5-20251101": "claude-opus-4-5-thinking", // 长前缀优先
"claude-haiku-4": "gemini-3-flash", {"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等
"claude-haiku-4-5": "gemini-3-flash", {"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx
"claude-3-haiku-20240307": "gemini-3-flash", {"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
"claude-haiku-4-5-20251001": "gemini-3-flash", {"claude-haiku-4-5", "gemini-3-flash"}, // claude-haiku-4-5-xxx
// 生图模型:官方名 → Antigravity 内部名 {"claude-opus-4-5", "claude-opus-4-5-thinking"},
"gemini-3-pro-image-preview": "gemini-3-pro-image", {"claude-3-haiku", "gemini-3-flash"}, // 旧版 claude-3-haiku-xxx
{"claude-sonnet-4", "claude-sonnet-4-5"},
{"claude-haiku-4", "gemini-3-flash"},
{"claude-opus-4", "claude-opus-4-5-thinking"},
{"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview 等
} }
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发 // AntigravityGatewayService 处理 Antigravity 平台的 API 转发
...@@ -84,24 +87,27 @@ func (s *AntigravityGatewayService) GetTokenProvider() *AntigravityTokenProvider ...@@ -84,24 +87,27 @@ func (s *AntigravityGatewayService) GetTokenProvider() *AntigravityTokenProvider
} }
// getMappedModel 获取映射后的模型名 // getMappedModel 获取映射后的模型名
// 逻辑:账户映射 → 直接支持透传 → 前缀映射 → gemini透传 → 默认值
func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedModel string) string { func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedModel string) string {
// 1. 优先使用账户级映射(复用现有方法 // 1. 账户级映射(用户自定义优先
if mapped := account.GetMappedModel(requestedModel); mapped != requestedModel { if mapped := account.GetMappedModel(requestedModel); mapped != requestedModel {
return mapped return mapped
} }
// 2. 系统默认映射 // 2. 直接支持的模型透传
if mapped, ok := antigravityModelMapping[requestedModel]; ok { if antigravitySupportedModels[requestedModel] {
return mapped return requestedModel
} }
// 3. Gemini 模型透传 // 3. 前缀映射(处理版本号变化,如 -20251111, -thinking, -preview)
if strings.HasPrefix(requestedModel, "gemini-") { for _, pm := range antigravityPrefixMapping {
return requestedModel if strings.HasPrefix(requestedModel, pm.prefix) {
return pm.target
}
} }
// 4. Claude 前缀透传直接支持的模型 // 4. Gemini 模型透传(未匹配到前缀的 gemini 模型
if antigravitySupportedModels[requestedModel] { if strings.HasPrefix(requestedModel, "gemini-") {
return requestedModel return requestedModel
} }
...@@ -110,24 +116,10 @@ func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedMo ...@@ -110,24 +116,10 @@ func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedMo
} }
// IsModelSupported 检查模型是否被支持 // IsModelSupported 检查模型是否被支持
// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持
func (s *AntigravityGatewayService) IsModelSupported(requestedModel string) bool { func (s *AntigravityGatewayService) IsModelSupported(requestedModel string) bool {
// 直接支持的模型 return strings.HasPrefix(requestedModel, "claude-") ||
if antigravitySupportedModels[requestedModel] { strings.HasPrefix(requestedModel, "gemini-")
return true
}
// 可映射的模型
if _, ok := antigravityModelMapping[requestedModel]; ok {
return true
}
// Gemini 前缀透传
if strings.HasPrefix(requestedModel, "gemini-") {
return true
}
// Claude 模型支持(通过默认映射)
if strings.HasPrefix(requestedModel, "claude-") {
return true
}
return false
} }
// TestConnectionResult 测试连接结果 // TestConnectionResult 测试连接结果
...@@ -180,7 +172,7 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account ...@@ -180,7 +172,7 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
} }
// 发送请求 // 发送请求
resp, err := s.httpUpstream.Do(req, proxyURL) resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil { if err != nil {
return nil, fmt.Errorf("请求失败: %w", err) return nil, fmt.Errorf("请求失败: %w", err)
} }
...@@ -358,6 +350,15 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -358,6 +350,15 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return nil, fmt.Errorf("transform request: %w", err) return nil, fmt.Errorf("transform request: %w", err)
} }
// 调试:记录转换后的请求体(仅记录前 2000 字符)
if bodyJSON, err := json.Marshal(geminiBody); err == nil {
truncated := string(bodyJSON)
if len(truncated) > 2000 {
truncated = truncated[:2000] + "..."
}
log.Printf("[Debug] Transformed Gemini request: %s", truncated)
}
// 构建上游 action // 构建上游 action
action := "generateContent" action := "generateContent"
if claudeReq.Stream { if claudeReq.Stream {
...@@ -372,7 +373,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -372,7 +373,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return nil, err return nil, err
} }
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL) resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil { if err != nil {
if attempt < antigravityMaxRetries { if attempt < antigravityMaxRetries {
log.Printf("Antigravity account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, antigravityMaxRetries, err) log.Printf("Antigravity account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, antigravityMaxRetries, err)
...@@ -515,7 +516,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ...@@ -515,7 +516,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
return nil, err return nil, err
} }
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL) resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil { if err != nil {
if attempt < antigravityMaxRetries { if attempt < antigravityMaxRetries {
log.Printf("Antigravity account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, antigravityMaxRetries, err) log.Printf("Antigravity account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, antigravityMaxRetries, err)
......
...@@ -131,7 +131,7 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) { ...@@ -131,7 +131,7 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
name: "系统映射 - claude-sonnet-4-5-20250929", name: "系统映射 - claude-sonnet-4-5-20250929",
requestedModel: "claude-sonnet-4-5-20250929", requestedModel: "claude-sonnet-4-5-20250929",
accountMapping: nil, accountMapping: nil,
expected: "claude-sonnet-4-5-thinking", expected: "claude-sonnet-4-5",
}, },
// 3. Gemini 透传 // 3. Gemini 透传
......
...@@ -191,7 +191,7 @@ func (r *AntigravityQuotaRefresher) refreshAccountQuota(ctx context.Context, acc ...@@ -191,7 +191,7 @@ func (r *AntigravityQuotaRefresher) refreshAccountQuota(ctx context.Context, acc
// isTokenExpired 检查 token 是否过期 // isTokenExpired 检查 token 是否过期
func (r *AntigravityQuotaRefresher) isTokenExpired(account *Account) bool { func (r *AntigravityQuotaRefresher) isTokenExpired(account *Account) bool {
expiresAt := parseAntigravityExpiresAt(account) expiresAt := account.GetCredentialAsTime("expires_at")
if expiresAt == nil { if expiresAt == nil {
return false return false
} }
......
...@@ -55,7 +55,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * ...@@ -55,7 +55,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
} }
// 2. 如果即将过期则刷新 // 2. 如果即将过期则刷新
expiresAt := parseAntigravityExpiresAt(account) expiresAt := account.GetCredentialAsTime("expires_at")
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew
if needsRefresh && p.tokenCache != nil { if needsRefresh && p.tokenCache != nil {
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
...@@ -72,7 +72,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * ...@@ -72,7 +72,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
if err == nil && fresh != nil { if err == nil && fresh != nil {
account = fresh account = fresh
} }
expiresAt = parseAntigravityExpiresAt(account) expiresAt = account.GetCredentialAsTime("expires_at")
if expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew { if expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew {
if p.antigravityOAuthService == nil { if p.antigravityOAuthService == nil {
return "", errors.New("antigravity oauth service not configured") return "", errors.New("antigravity oauth service not configured")
...@@ -91,7 +91,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * ...@@ -91,7 +91,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr) log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr)
} }
expiresAt = parseAntigravityExpiresAt(account) expiresAt = account.GetCredentialAsTime("expires_at")
} }
} }
} }
...@@ -128,18 +128,3 @@ func antigravityTokenCacheKey(account *Account) string { ...@@ -128,18 +128,3 @@ func antigravityTokenCacheKey(account *Account) string {
} }
return "ag:account:" + strconv.FormatInt(account.ID, 10) return "ag:account:" + strconv.FormatInt(account.ID, 10)
} }
func parseAntigravityExpiresAt(account *Account) *time.Time {
raw := strings.TrimSpace(account.GetCredential("expires_at"))
if raw == "" {
return nil
}
if unixSec, err := strconv.ParseInt(raw, 10, 64); err == nil && unixSec > 0 {
t := time.Unix(unixSec, 0)
return &t
}
if t, err := time.Parse(time.RFC3339, raw); err == nil {
return &t
}
return nil
}
...@@ -2,7 +2,7 @@ package service ...@@ -2,7 +2,7 @@ package service
import ( import (
"context" "context"
"strconv" "fmt"
"time" "time"
) )
...@@ -29,21 +29,22 @@ func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool { ...@@ -29,21 +29,22 @@ func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool {
} }
// NeedsRefresh 检查账户是否需要刷新 // NeedsRefresh 检查账户是否需要刷新
// Antigravity 使用固定的10分钟刷新窗口,忽略全局配置 // Antigravity 使用固定的15分钟刷新窗口,忽略全局配置
func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, _ time.Duration) bool { func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, _ time.Duration) bool {
if !r.CanRefresh(account) { if !r.CanRefresh(account) {
return false return false
} }
expiresAtStr := account.GetCredential("expires_at") expiresAt := account.GetCredentialAsTime("expires_at")
if expiresAtStr == "" { if expiresAt == nil {
return false return false
} }
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64) timeUntilExpiry := time.Until(*expiresAt)
if err != nil { needsRefresh := timeUntilExpiry < antigravityRefreshWindow
return false if needsRefresh {
fmt.Printf("[AntigravityTokenRefresher] Account %d needs refresh: expires_at=%s, time_until_expiry=%v, window=%v\n",
account.ID, expiresAt.Format("2006-01-02 15:04:05"), timeUntilExpiry, antigravityRefreshWindow)
} }
expiryTime := time.Unix(expiresAt, 0) return needsRefresh
return time.Until(expiryTime) < antigravityRefreshWindow
} }
// Refresh 执行 token 刷新 // Refresh 执行 token 刷新
......
...@@ -8,7 +8,7 @@ import ( ...@@ -8,7 +8,7 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
) )
......
...@@ -8,7 +8,7 @@ import ( ...@@ -8,7 +8,7 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
......
...@@ -4,10 +4,12 @@ import ( ...@@ -4,10 +4,12 @@ import (
"context" "context"
"fmt" "fmt"
"log" "log"
"sync"
"sync/atomic"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
) )
// 错误定义 // 错误定义
...@@ -27,6 +29,46 @@ type subscriptionCacheData struct { ...@@ -27,6 +29,46 @@ type subscriptionCacheData struct {
Version int64 Version int64
} }
// 缓存写入任务类型
type cacheWriteKind int
const (
cacheWriteSetBalance cacheWriteKind = iota
cacheWriteSetSubscription
cacheWriteUpdateSubscriptionUsage
cacheWriteDeductBalance
)
// 异步缓存写入工作池配置
//
// 性能优化说明:
// 原实现在请求热路径中使用 goroutine 异步更新缓存,存在以下问题:
// 1. 每次请求创建新 goroutine,高并发下产生大量短生命周期 goroutine
// 2. 无法控制并发数量,可能导致 Redis 连接耗尽
// 3. goroutine 创建/销毁带来额外开销
//
// 新实现使用固定大小的工作池:
// 1. 预创建 10 个 worker goroutine,避免频繁创建销毁
// 2. 使用带缓冲的 channel(1000)作为任务队列,平滑写入峰值
// 3. 非阻塞写入,队列满时关键任务同步回退,非关键任务丢弃并告警
// 4. 统一超时控制,避免慢操作阻塞工作池
const (
cacheWriteWorkerCount = 10 // 工作协程数量
cacheWriteBufferSize = 1000 // 任务队列缓冲大小
cacheWriteTimeout = 2 * time.Second // 单个写入操作超时
cacheWriteDropLogInterval = 5 * time.Second // 丢弃日志节流间隔
)
// cacheWriteTask 缓存写入任务
type cacheWriteTask struct {
kind cacheWriteKind
userID int64
groupID int64
balance float64
amount float64
subscriptionData *subscriptionCacheData
}
// BillingCacheService 计费缓存服务 // BillingCacheService 计费缓存服务
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查 // 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
type BillingCacheService struct { type BillingCacheService struct {
...@@ -34,16 +76,151 @@ type BillingCacheService struct { ...@@ -34,16 +76,151 @@ type BillingCacheService struct {
userRepo UserRepository userRepo UserRepository
subRepo UserSubscriptionRepository subRepo UserSubscriptionRepository
cfg *config.Config cfg *config.Config
cacheWriteChan chan cacheWriteTask
cacheWriteWg sync.WaitGroup
cacheWriteStopOnce sync.Once
// 丢弃日志节流计数器(减少高负载下日志噪音)
cacheWriteDropFullCount uint64
cacheWriteDropFullLastLog int64
cacheWriteDropClosedCount uint64
cacheWriteDropClosedLastLog int64
} }
// NewBillingCacheService 创建计费缓存服务 // NewBillingCacheService 创建计费缓存服务
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, cfg *config.Config) *BillingCacheService { func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, cfg *config.Config) *BillingCacheService {
return &BillingCacheService{ svc := &BillingCacheService{
cache: cache, cache: cache,
userRepo: userRepo, userRepo: userRepo,
subRepo: subRepo, subRepo: subRepo,
cfg: cfg, cfg: cfg,
} }
svc.startCacheWriteWorkers()
return svc
}
// Stop 关闭缓存写入工作池
func (s *BillingCacheService) Stop() {
s.cacheWriteStopOnce.Do(func() {
if s.cacheWriteChan == nil {
return
}
close(s.cacheWriteChan)
s.cacheWriteWg.Wait()
s.cacheWriteChan = nil
})
}
func (s *BillingCacheService) startCacheWriteWorkers() {
s.cacheWriteChan = make(chan cacheWriteTask, cacheWriteBufferSize)
for i := 0; i < cacheWriteWorkerCount; i++ {
s.cacheWriteWg.Add(1)
go s.cacheWriteWorker()
}
}
// enqueueCacheWrite 尝试将任务入队,队列满时返回 false(并记录告警)。
func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued bool) {
if s.cacheWriteChan == nil {
return false
}
defer func() {
if recovered := recover(); recovered != nil {
// 队列已关闭时可能触发 panic,记录后静默失败。
s.logCacheWriteDrop(task, "closed")
enqueued = false
}
}()
select {
case s.cacheWriteChan <- task:
return true
default:
// 队列满时不阻塞主流程,交由调用方决定是否同步回退。
s.logCacheWriteDrop(task, "full")
return false
}
}
func (s *BillingCacheService) cacheWriteWorker() {
defer s.cacheWriteWg.Done()
for task := range s.cacheWriteChan {
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
switch task.kind {
case cacheWriteSetBalance:
s.setBalanceCache(ctx, task.userID, task.balance)
case cacheWriteSetSubscription:
s.setSubscriptionCache(ctx, task.userID, task.groupID, task.subscriptionData)
case cacheWriteUpdateSubscriptionUsage:
if s.cache != nil {
if err := s.cache.UpdateSubscriptionUsage(ctx, task.userID, task.groupID, task.amount); err != nil {
log.Printf("Warning: update subscription cache failed for user %d group %d: %v", task.userID, task.groupID, err)
}
}
case cacheWriteDeductBalance:
if s.cache != nil {
if err := s.cache.DeductUserBalance(ctx, task.userID, task.amount); err != nil {
log.Printf("Warning: deduct balance cache failed for user %d: %v", task.userID, err)
}
}
}
cancel()
}
}
// cacheWriteKindName 用于日志中的任务类型标识,便于排查丢弃原因。
func cacheWriteKindName(kind cacheWriteKind) string {
switch kind {
case cacheWriteSetBalance:
return "set_balance"
case cacheWriteSetSubscription:
return "set_subscription"
case cacheWriteUpdateSubscriptionUsage:
return "update_subscription_usage"
case cacheWriteDeductBalance:
return "deduct_balance"
default:
return "unknown"
}
}
// logCacheWriteDrop 使用节流方式记录丢弃情况,并汇总丢弃数量。
func (s *BillingCacheService) logCacheWriteDrop(task cacheWriteTask, reason string) {
var (
countPtr *uint64
lastPtr *int64
)
switch reason {
case "full":
countPtr = &s.cacheWriteDropFullCount
lastPtr = &s.cacheWriteDropFullLastLog
case "closed":
countPtr = &s.cacheWriteDropClosedCount
lastPtr = &s.cacheWriteDropClosedLastLog
default:
return
}
atomic.AddUint64(countPtr, 1)
now := time.Now().UnixNano()
last := atomic.LoadInt64(lastPtr)
if now-last < int64(cacheWriteDropLogInterval) {
return
}
if !atomic.CompareAndSwapInt64(lastPtr, last, now) {
return
}
dropped := atomic.SwapUint64(countPtr, 0)
if dropped == 0 {
return
}
log.Printf("Warning: cache write queue %s, dropped %d tasks in last %s (latest kind=%s user %d group %d)",
reason,
dropped,
cacheWriteDropLogInterval,
cacheWriteKindName(task.kind),
task.userID,
task.groupID,
)
} }
// ============================================ // ============================================
...@@ -70,11 +247,11 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64) ...@@ -70,11 +247,11 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64)
} }
// 异步建立缓存 // 异步建立缓存
go func() { _ = s.enqueueCacheWrite(cacheWriteTask{
cacheCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) kind: cacheWriteSetBalance,
defer cancel() userID: userID,
s.setBalanceCache(cacheCtx, userID, balance) balance: balance,
}() })
return balance, nil return balance, nil
} }
...@@ -98,7 +275,7 @@ func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64, ...@@ -98,7 +275,7 @@ func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64,
} }
} }
// DeductBalanceCache 扣减余额缓存(步调用,用于扣费后更新缓存 // DeductBalanceCache 扣减余额缓存(步调用)
func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int64, amount float64) error { func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int64, amount float64) error {
if s.cache == nil { if s.cache == nil {
return nil return nil
...@@ -106,6 +283,26 @@ func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int ...@@ -106,6 +283,26 @@ func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int
return s.cache.DeductUserBalance(ctx, userID, amount) return s.cache.DeductUserBalance(ctx, userID, amount)
} }
// QueueDeductBalance 异步扣减余额缓存
func (s *BillingCacheService) QueueDeductBalance(userID int64, amount float64) {
if s.cache == nil {
return
}
// 队列满时同步回退,避免关键扣减被静默丢弃。
if s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteDeductBalance,
userID: userID,
amount: amount,
}) {
return
}
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
defer cancel()
if err := s.DeductBalanceCache(ctx, userID, amount); err != nil {
log.Printf("Warning: deduct balance cache fallback failed for user %d: %v", userID, err)
}
}
// InvalidateUserBalance 失效用户余额缓存 // InvalidateUserBalance 失效用户余额缓存
func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID int64) error { func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID int64) error {
if s.cache == nil { if s.cache == nil {
...@@ -141,11 +338,12 @@ func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID, ...@@ -141,11 +338,12 @@ func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID,
} }
// 异步建立缓存 // 异步建立缓存
go func() { _ = s.enqueueCacheWrite(cacheWriteTask{
cacheCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) kind: cacheWriteSetSubscription,
defer cancel() userID: userID,
s.setSubscriptionCache(cacheCtx, userID, groupID, data) groupID: groupID,
}() subscriptionData: data,
})
return data, nil return data, nil
} }
...@@ -199,7 +397,7 @@ func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID, ...@@ -199,7 +397,7 @@ func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID,
} }
} }
// UpdateSubscriptionUsage 更新订阅用量缓存(步调用,用于扣费后更新缓存 // UpdateSubscriptionUsage 更新订阅用量缓存(步调用)
func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, costUSD float64) error { func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, costUSD float64) error {
if s.cache == nil { if s.cache == nil {
return nil return nil
...@@ -207,6 +405,27 @@ func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userI ...@@ -207,6 +405,27 @@ func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userI
return s.cache.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD) return s.cache.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD)
} }
// QueueUpdateSubscriptionUsage 异步更新订阅用量缓存
func (s *BillingCacheService) QueueUpdateSubscriptionUsage(userID, groupID int64, costUSD float64) {
if s.cache == nil {
return
}
// 队列满时同步回退,确保订阅用量及时更新。
if s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteUpdateSubscriptionUsage,
userID: userID,
groupID: groupID,
amount: costUSD,
}) {
return
}
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
defer cancel()
if err := s.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD); err != nil {
log.Printf("Warning: update subscription cache fallback failed for user %d group %d: %v", userID, groupID, err)
}
}
// InvalidateSubscription 失效指定订阅缓存 // InvalidateSubscription 失效指定订阅缓存
func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID, groupID int64) error { func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID, groupID int64) error {
if s.cache == nil { if s.cache == nil {
......
package service
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type billingCacheWorkerStub struct {
balanceUpdates int64
subscriptionUpdates int64
}
func (b *billingCacheWorkerStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
return 0, errors.New("not implemented")
}
func (b *billingCacheWorkerStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
atomic.AddInt64(&b.balanceUpdates, 1)
return nil
}
func (b *billingCacheWorkerStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
atomic.AddInt64(&b.balanceUpdates, 1)
return nil
}
func (b *billingCacheWorkerStub) InvalidateUserBalance(ctx context.Context, userID int64) error {
return nil
}
func (b *billingCacheWorkerStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) {
return nil, errors.New("not implemented")
}
func (b *billingCacheWorkerStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error {
atomic.AddInt64(&b.subscriptionUpdates, 1)
return nil
}
func (b *billingCacheWorkerStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
atomic.AddInt64(&b.subscriptionUpdates, 1)
return nil
}
func (b *billingCacheWorkerStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
return nil
}
func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
cache := &billingCacheWorkerStub{}
svc := NewBillingCacheService(cache, nil, nil, &config.Config{})
t.Cleanup(svc.Stop)
start := time.Now()
for i := 0; i < cacheWriteBufferSize*2; i++ {
svc.QueueDeductBalance(1, 1)
}
require.Less(t, time.Since(start), 2*time.Second)
svc.QueueUpdateSubscriptionUsage(1, 2, 1.5)
require.Eventually(t, func() bool {
return atomic.LoadInt64(&cache.balanceUpdates) > 0
}, 2*time.Second, 10*time.Millisecond)
require.Eventually(t, func() bool {
return atomic.LoadInt64(&cache.subscriptionUpdates) > 0
}, 2*time.Second, 10*time.Millisecond)
}
...@@ -9,24 +9,35 @@ import ( ...@@ -9,24 +9,35 @@ import (
"time" "time"
) )
// ConcurrencyCache defines cache operations for concurrency service // ConcurrencyCache 定义并发控制的缓存接口
// Uses independent keys per request slot with native Redis TTL for automatic cleanup // 使用有序集合存储槽位,按时间戳清理过期条目
type ConcurrencyCache interface { type ConcurrencyCache interface {
// Account slot management - each slot is a separate key with independent TTL // 账号槽位管理
// Key format: concurrency:account:{accountID}:{requestID} // 键格式: concurrency:account:{accountID}(有序集合,成员为 requestID
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
// User slot management - each slot is a separate key with independent TTL // 账号等待队列(账号级)
// Key format: concurrency:user:{userID}:{requestID} IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error)
DecrementAccountWaitCount(ctx context.Context, accountID int64) error
GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error)
// 用户槽位管理
// 键格式: concurrency:user:{userID}(有序集合,成员为 requestID)
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error
GetUserConcurrency(ctx context.Context, userID int64) (int, error) GetUserConcurrency(ctx context.Context, userID int64) (int, error)
// Wait queue - uses counter with TTL set only on creation // 等待队列计数(只在首次创建时设置 TTL)
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
DecrementWaitCount(ctx context.Context, userID int64) error DecrementWaitCount(ctx context.Context, userID int64) error
// 批量负载查询(只读)
GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error)
// 清理过期槽位(后台任务)
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
} }
// generateRequestID generates a unique request ID for concurrency slot tracking // generateRequestID generates a unique request ID for concurrency slot tracking
...@@ -61,6 +72,18 @@ type AcquireResult struct { ...@@ -61,6 +72,18 @@ type AcquireResult struct {
ReleaseFunc func() // Must be called when done (typically via defer) ReleaseFunc func() // Must be called when done (typically via defer)
} }
type AccountWithConcurrency struct {
ID int64
MaxConcurrency int
}
type AccountLoadInfo struct {
AccountID int64
CurrentConcurrency int
WaitingCount int
LoadRate int // 0-100+ (percent)
}
// AcquireAccountSlot attempts to acquire a concurrency slot for an account. // AcquireAccountSlot attempts to acquire a concurrency slot for an account.
// If the account is at max concurrency, it waits until a slot is available or timeout. // If the account is at max concurrency, it waits until a slot is available or timeout.
// Returns a release function that MUST be called when the request completes. // Returns a release function that MUST be called when the request completes.
...@@ -177,6 +200,42 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6 ...@@ -177,6 +200,42 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6
} }
} }
// IncrementAccountWaitCount increments the wait queue counter for an account.
func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
if s.cache == nil {
return true, nil
}
result, err := s.cache.IncrementAccountWaitCount(ctx, accountID, maxWait)
if err != nil {
log.Printf("Warning: increment wait count failed for account %d: %v", accountID, err)
return true, nil
}
return result, nil
}
// DecrementAccountWaitCount decrements the wait queue counter for an account.
func (s *ConcurrencyService) DecrementAccountWaitCount(ctx context.Context, accountID int64) {
if s.cache == nil {
return
}
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.cache.DecrementAccountWaitCount(bgCtx, accountID); err != nil {
log.Printf("Warning: decrement wait count failed for account %d: %v", accountID, err)
}
}
// GetAccountWaitingCount gets current wait queue count for an account.
func (s *ConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
if s.cache == nil {
return 0, nil
}
return s.cache.GetAccountWaitingCount(ctx, accountID)
}
// CalculateMaxWait calculates the maximum wait queue size for a user // CalculateMaxWait calculates the maximum wait queue size for a user
// maxWait = userConcurrency + defaultExtraWaitSlots // maxWait = userConcurrency + defaultExtraWaitSlots
func CalculateMaxWait(userConcurrency int) int { func CalculateMaxWait(userConcurrency int) int {
...@@ -186,6 +245,57 @@ func CalculateMaxWait(userConcurrency int) int { ...@@ -186,6 +245,57 @@ func CalculateMaxWait(userConcurrency int) int {
return userConcurrency + defaultExtraWaitSlots return userConcurrency + defaultExtraWaitSlots
} }
// GetAccountsLoadBatch returns load info for multiple accounts.
func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
if s.cache == nil {
return map[int64]*AccountLoadInfo{}, nil
}
return s.cache.GetAccountsLoadBatch(ctx, accounts)
}
// CleanupExpiredAccountSlots removes expired slots for one account (background task).
func (s *ConcurrencyService) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
if s.cache == nil {
return nil
}
return s.cache.CleanupExpiredAccountSlots(ctx, accountID)
}
// StartSlotCleanupWorker starts a background cleanup worker for expired account slots.
func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepository, interval time.Duration) {
if s == nil || s.cache == nil || accountRepo == nil || interval <= 0 {
return
}
runCleanup := func() {
listCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
accounts, err := accountRepo.ListSchedulable(listCtx)
cancel()
if err != nil {
log.Printf("Warning: list schedulable accounts failed: %v", err)
return
}
for _, account := range accounts {
accountCtx, accountCancel := context.WithTimeout(context.Background(), 2*time.Second)
err := s.cache.CleanupExpiredAccountSlots(accountCtx, account.ID)
accountCancel()
if err != nil {
log.Printf("Warning: cleanup expired slots failed for account %d: %v", account.ID, err)
}
}
}
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
runCleanup()
for range ticker.C {
runCleanup()
}
}()
}
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts // GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
// Returns a map of accountID -> current concurrency count // Returns a map of accountID -> current concurrency count
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
......
...@@ -12,6 +12,8 @@ import ( ...@@ -12,6 +12,8 @@ import (
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
) )
type CRSSyncService struct { type CRSSyncService struct {
...@@ -193,7 +195,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ...@@ -193,7 +195,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
return nil, errors.New("username and password are required") return nil, errors.New("username and password are required")
} }
client := &http.Client{Timeout: 20 * time.Second} client, err := httpclient.GetClient(httpclient.Options{
Timeout: 20 * time.Second,
})
if err != nil {
client = &http.Client{Timeout: 20 * time.Second}
}
adminToken, err := crsLogin(ctx, client, baseURL, input.Username, input.Password) adminToken, err := crsLogin(ctx, client, baseURL, input.Username, input.Password)
if err != nil { if err != nil {
......
...@@ -91,6 +91,9 @@ const ( ...@@ -91,6 +91,9 @@ const (
// 管理员 API Key // 管理员 API Key
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成) SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
// Gemini 配额策略(JSON)
SettingKeyGeminiQuotaPolicy = "gemini_quota_policy"
) )
// Admin API Key prefix (distinct from user "sk-" keys) // Admin API Key prefix (distinct from user "sk-" keys)
......
...@@ -10,7 +10,7 @@ import ( ...@@ -10,7 +10,7 @@ import (
"strconv" "strconv"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
) )
var ( var (
......
...@@ -261,6 +261,34 @@ func TestGatewayService_SelectAccountForModelWithPlatform_PriorityAndLastUsed(t ...@@ -261,6 +261,34 @@ func TestGatewayService_SelectAccountForModelWithPlatform_PriorityAndLastUsed(t
require.Equal(t, int64(2), acc.ID, "同优先级应选择最久未用的账户") require.Equal(t, int64(2), acc.ID, "同优先级应选择最久未用的账户")
} }
func TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference(t *testing.T) {
ctx := context.Background()
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "同优先级且未使用时应优先选择OAuth账户")
}
// TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts 测试无可用账户 // TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts 测试无可用账户
func TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts(t *testing.T) { func TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts(t *testing.T) {
ctx := context.Background() ctx := context.Background()
...@@ -576,6 +604,32 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) { ...@@ -576,6 +604,32 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
ctx := context.Background() ctx := context.Background()
t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "同优先级且未使用时应优先选择OAuth账户")
})
t.Run("混合调度-包含启用mixed_scheduling的antigravity账户", func(t *testing.T) { t.Run("混合调度-包含启用mixed_scheduling的antigravity账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{ repo := &mockAccountRepoForPlatform{
accounts: []Account{ accounts: []Account{
...@@ -783,3 +837,160 @@ func TestAccount_IsMixedSchedulingEnabled(t *testing.T) { ...@@ -783,3 +837,160 @@ func TestAccount_IsMixedSchedulingEnabled(t *testing.T) {
}) })
} }
} }
// mockConcurrencyService for testing
type mockConcurrencyService struct {
accountLoads map[int64]*AccountLoadInfo
accountWaitCounts map[int64]int
acquireResults map[int64]bool
}
func (m *mockConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
if m.accountLoads == nil {
return map[int64]*AccountLoadInfo{}, nil
}
result := make(map[int64]*AccountLoadInfo)
for _, acc := range accounts {
if load, ok := m.accountLoads[acc.ID]; ok {
result[acc.ID] = load
} else {
result[acc.ID] = &AccountLoadInfo{
AccountID: acc.ID,
CurrentConcurrency: 0,
WaitingCount: 0,
LoadRate: 0,
}
}
}
return result, nil
}
func (m *mockConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
if m.accountWaitCounts == nil {
return 0, nil
}
return m.accountWaitCounts[accountID], nil
}
// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection
func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
ctx := context.Background()
t.Run("禁用负载批量查询-降级到传统选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil, // No concurrency service
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(1), result.Account.ID, "应选择优先级最高的账号")
})
t.Run("无ConcurrencyService-降级到传统选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "应选择优先级最高的账号")
})
t.Run("排除账号-不选择被排除的账号", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
excludedIDs := map[int64]struct{}{1: {}}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "不应选择被排除的账号")
})
t.Run("无可用账号-返回错误", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{},
accountsByID: map[int64]*Account{},
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
require.Error(t, err)
require.Nil(t, result)
require.Contains(t, err.Error(), "no available accounts")
})
}
package service
import (
"encoding/json"
"fmt"
)
// ParsedRequest 保存网关请求的预解析结果
//
// 性能优化说明:
// 原实现在多个位置重复解析请求体(Handler、Service 各解析一次):
// 1. gateway_handler.go 解析获取 model 和 stream
// 2. gateway_service.go 再次解析获取 system、messages、metadata
// 3. GenerateSessionHash 又一次解析获取会话哈希所需字段
//
// 新实现一次解析,多处复用:
// 1. 在 Handler 层统一调用 ParseGatewayRequest 一次性解析
// 2. 将解析结果 ParsedRequest 传递给 Service 层
// 3. 避免重复 json.Unmarshal,减少 CPU 和内存开销
type ParsedRequest struct {
Body []byte // 原始请求体(保留用于转发)
Model string // 请求的模型名称
Stream bool // 是否为流式请求
MetadataUserID string // metadata.user_id(用于会话亲和)
System any // system 字段内容
Messages []any // messages 数组
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
}
// ParseGatewayRequest 解析网关请求体并返回结构化结果
// 性能优化:一次解析提取所有需要的字段,避免重复 Unmarshal
func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return nil, err
}
parsed := &ParsedRequest{
Body: body,
}
if rawModel, exists := req["model"]; exists {
model, ok := rawModel.(string)
if !ok {
return nil, fmt.Errorf("invalid model field type")
}
parsed.Model = model
}
if rawStream, exists := req["stream"]; exists {
stream, ok := rawStream.(bool)
if !ok {
return nil, fmt.Errorf("invalid stream field type")
}
parsed.Stream = stream
}
if metadata, ok := req["metadata"].(map[string]any); ok {
if userID, ok := metadata["user_id"].(string); ok {
parsed.MetadataUserID = userID
}
}
// system 字段只要存在就视为显式提供(即使为 null),
// 以避免客户端传 null 时被默认 system 误注入。
if system, ok := req["system"]; ok {
parsed.HasSystem = true
parsed.System = system
}
if messages, ok := req["messages"].([]any); ok {
parsed.Messages = messages
}
return parsed, nil
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment