Commit 6901b64f authored by cyhhao's avatar cyhhao
Browse files

merge: sync upstream changes

parents 32c47b15 dae0d532
...@@ -81,6 +81,9 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -81,6 +81,9 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
ops.PUT("/alert-rules/:id", h.Admin.Ops.UpdateAlertRule) ops.PUT("/alert-rules/:id", h.Admin.Ops.UpdateAlertRule)
ops.DELETE("/alert-rules/:id", h.Admin.Ops.DeleteAlertRule) ops.DELETE("/alert-rules/:id", h.Admin.Ops.DeleteAlertRule)
ops.GET("/alert-events", h.Admin.Ops.ListAlertEvents) ops.GET("/alert-events", h.Admin.Ops.ListAlertEvents)
ops.GET("/alert-events/:id", h.Admin.Ops.GetAlertEvent)
ops.PUT("/alert-events/:id/status", h.Admin.Ops.UpdateAlertEventStatus)
ops.POST("/alert-silences", h.Admin.Ops.CreateAlertSilence)
// Email notification config (DB-backed) // Email notification config (DB-backed)
ops.GET("/email-notification/config", h.Admin.Ops.GetEmailNotificationConfig) ops.GET("/email-notification/config", h.Admin.Ops.GetEmailNotificationConfig)
...@@ -110,10 +113,26 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -110,10 +113,26 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
ws.GET("/qps", h.Admin.Ops.QPSWSHandler) ws.GET("/qps", h.Admin.Ops.QPSWSHandler)
} }
// Error logs (MVP-1) // Error logs (legacy)
ops.GET("/errors", h.Admin.Ops.GetErrorLogs) ops.GET("/errors", h.Admin.Ops.GetErrorLogs)
ops.GET("/errors/:id", h.Admin.Ops.GetErrorLogByID) ops.GET("/errors/:id", h.Admin.Ops.GetErrorLogByID)
ops.GET("/errors/:id/retries", h.Admin.Ops.ListRetryAttempts)
ops.POST("/errors/:id/retry", h.Admin.Ops.RetryErrorRequest) ops.POST("/errors/:id/retry", h.Admin.Ops.RetryErrorRequest)
ops.PUT("/errors/:id/resolve", h.Admin.Ops.UpdateErrorResolution)
// Request errors (client-visible failures)
ops.GET("/request-errors", h.Admin.Ops.ListRequestErrors)
ops.GET("/request-errors/:id", h.Admin.Ops.GetRequestError)
ops.GET("/request-errors/:id/upstream-errors", h.Admin.Ops.ListRequestErrorUpstreamErrors)
ops.POST("/request-errors/:id/retry-client", h.Admin.Ops.RetryRequestErrorClient)
ops.POST("/request-errors/:id/upstream-errors/:idx/retry", h.Admin.Ops.RetryRequestErrorUpstreamEvent)
ops.PUT("/request-errors/:id/resolve", h.Admin.Ops.ResolveRequestError)
// Upstream errors (independent upstream failures)
ops.GET("/upstream-errors", h.Admin.Ops.ListUpstreamErrors)
ops.GET("/upstream-errors/:id", h.Admin.Ops.GetUpstreamError)
ops.POST("/upstream-errors/:id/retry", h.Admin.Ops.RetryUpstreamError)
ops.PUT("/upstream-errors/:id/resolve", h.Admin.Ops.ResolveUpstreamError)
// Request drilldown (success + error) // Request drilldown (success + error)
ops.GET("/requests", h.Admin.Ops.ListRequestDetails) ops.GET("/requests", h.Admin.Ops.ListRequestDetails)
...@@ -250,6 +269,7 @@ func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ...@@ -250,6 +269,7 @@ func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
proxies.POST("/:id/test", h.Admin.Proxy.Test) proxies.POST("/:id/test", h.Admin.Proxy.Test)
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats) proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts) proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
proxies.POST("/batch-delete", h.Admin.Proxy.BatchDelete)
proxies.POST("/batch", h.Admin.Proxy.BatchCreate) proxies.POST("/batch", h.Admin.Proxy.BatchCreate)
} }
} }
......
...@@ -9,16 +9,19 @@ import ( ...@@ -9,16 +9,19 @@ import (
) )
type Account struct { type Account struct {
ID int64 ID int64
Name string Name string
Notes *string Notes *string
Platform string Platform string
Type string Type string
Credentials map[string]any Credentials map[string]any
Extra map[string]any Extra map[string]any
ProxyID *int64 ProxyID *int64
Concurrency int Concurrency int
Priority int Priority int
// RateMultiplier 账号计费倍率(>=0,允许 0 表示该账号计费为 0)。
// 使用指针用于兼容旧版本调度缓存(Redis)中缺字段的情况:nil 表示按 1.0 处理。
RateMultiplier *float64
Status string Status string
ErrorMessage string ErrorMessage string
LastUsedAt *time.Time LastUsedAt *time.Time
...@@ -57,6 +60,20 @@ func (a *Account) IsActive() bool { ...@@ -57,6 +60,20 @@ func (a *Account) IsActive() bool {
return a.Status == StatusActive return a.Status == StatusActive
} }
// BillingRateMultiplier 返回账号计费倍率。
// - nil 表示未配置/旧缓存缺字段,按 1.0 处理
// - 允许 0,表示该账号计费为 0
// - 负数属于非法数据,出于安全考虑按 1.0 处理
func (a *Account) BillingRateMultiplier() float64 {
if a == nil || a.RateMultiplier == nil {
return 1.0
}
if *a.RateMultiplier < 0 {
return 1.0
}
return *a.RateMultiplier
}
func (a *Account) IsSchedulable() bool { func (a *Account) IsSchedulable() bool {
if !a.IsActive() || !a.Schedulable { if !a.IsActive() || !a.Schedulable {
return false return false
...@@ -556,3 +573,141 @@ func (a *Account) IsMixedSchedulingEnabled() bool { ...@@ -556,3 +573,141 @@ func (a *Account) IsMixedSchedulingEnabled() bool {
} }
return false return false
} }
// WindowCostSchedulability 窗口费用调度状态
type WindowCostSchedulability int
const (
// WindowCostSchedulable 可正常调度
WindowCostSchedulable WindowCostSchedulability = iota
// WindowCostStickyOnly 仅允许粘性会话
WindowCostStickyOnly
// WindowCostNotSchedulable 完全不可调度
WindowCostNotSchedulable
)
// IsAnthropicOAuthOrSetupToken 判断是否为 Anthropic OAuth 或 SetupToken 类型账号
// 仅这两类账号支持 5h 窗口额度控制和会话数量控制
func (a *Account) IsAnthropicOAuthOrSetupToken() bool {
return a.Platform == PlatformAnthropic && (a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken)
}
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
// 返回 0 表示未启用
func (a *Account) GetWindowCostLimit() float64 {
if a.Extra == nil {
return 0
}
if v, ok := a.Extra["window_cost_limit"]; ok {
return parseExtraFloat64(v)
}
return 0
}
// GetWindowCostStickyReserve 获取粘性会话预留额度(美元)
// 默认值为 10
func (a *Account) GetWindowCostStickyReserve() float64 {
if a.Extra == nil {
return 10.0
}
if v, ok := a.Extra["window_cost_sticky_reserve"]; ok {
val := parseExtraFloat64(v)
if val > 0 {
return val
}
}
return 10.0
}
// GetMaxSessions 获取最大并发会话数
// 返回 0 表示未启用
func (a *Account) GetMaxSessions() int {
if a.Extra == nil {
return 0
}
if v, ok := a.Extra["max_sessions"]; ok {
return parseExtraInt(v)
}
return 0
}
// GetSessionIdleTimeoutMinutes 获取会话空闲超时分钟数
// 默认值为 5 分钟
func (a *Account) GetSessionIdleTimeoutMinutes() int {
if a.Extra == nil {
return 5
}
if v, ok := a.Extra["session_idle_timeout_minutes"]; ok {
val := parseExtraInt(v)
if val > 0 {
return val
}
}
return 5
}
// CheckWindowCostSchedulability 根据当前窗口费用检查调度状态
// - 费用 < 阈值: WindowCostSchedulable(可正常调度)
// - 费用 >= 阈值 且 < 阈值+预留: WindowCostStickyOnly(仅粘性会话)
// - 费用 >= 阈值+预留: WindowCostNotSchedulable(不可调度)
func (a *Account) CheckWindowCostSchedulability(currentWindowCost float64) WindowCostSchedulability {
limit := a.GetWindowCostLimit()
if limit <= 0 {
return WindowCostSchedulable
}
if currentWindowCost < limit {
return WindowCostSchedulable
}
stickyReserve := a.GetWindowCostStickyReserve()
if currentWindowCost < limit+stickyReserve {
return WindowCostStickyOnly
}
return WindowCostNotSchedulable
}
// parseExtraFloat64 从 extra 字段解析 float64 值
func parseExtraFloat64(value any) float64 {
switch v := value.(type) {
case float64:
return v
case float32:
return float64(v)
case int:
return float64(v)
case int64:
return float64(v)
case json.Number:
if f, err := v.Float64(); err == nil {
return f
}
case string:
if f, err := strconv.ParseFloat(strings.TrimSpace(v), 64); err == nil {
return f
}
}
return 0
}
// parseExtraInt 从 extra 字段解析 int 值
func parseExtraInt(value any) int {
switch v := value.(type) {
case int:
return v
case int64:
return int(v)
case float64:
return int(v)
case json.Number:
if i, err := v.Int64(); err == nil {
return int(i)
}
case string:
if i, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
return i
}
}
return 0
}
package service
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func TestAccount_BillingRateMultiplier_DefaultsToOneWhenNil(t *testing.T) {
var a Account
require.NoError(t, json.Unmarshal([]byte(`{"id":1,"name":"acc","status":"active"}`), &a))
require.Nil(t, a.RateMultiplier)
require.Equal(t, 1.0, a.BillingRateMultiplier())
}
func TestAccount_BillingRateMultiplier_AllowsZero(t *testing.T) {
v := 0.0
a := Account{RateMultiplier: &v}
require.Equal(t, 0.0, a.BillingRateMultiplier())
}
func TestAccount_BillingRateMultiplier_NegativeFallsBackToOne(t *testing.T) {
v := -1.0
a := Account{RateMultiplier: &v}
require.Equal(t, 1.0, a.BillingRateMultiplier())
}
...@@ -50,11 +50,13 @@ type AccountRepository interface { ...@@ -50,11 +50,13 @@ type AccountRepository interface {
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error
SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error SetOverloaded(ctx context.Context, id int64, until time.Time) error
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
ClearTempUnschedulable(ctx context.Context, id int64) error ClearTempUnschedulable(ctx context.Context, id int64) error
ClearRateLimit(ctx context.Context, id int64) error ClearRateLimit(ctx context.Context, id int64) error
ClearAntigravityQuotaScopes(ctx context.Context, id int64) error ClearAntigravityQuotaScopes(ctx context.Context, id int64) error
ClearModelRateLimits(ctx context.Context, id int64) error
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
...@@ -63,14 +65,15 @@ type AccountRepository interface { ...@@ -63,14 +65,15 @@ type AccountRepository interface {
// AccountBulkUpdate describes the fields that can be updated in a bulk operation. // AccountBulkUpdate describes the fields that can be updated in a bulk operation.
// Nil pointers mean "do not change". // Nil pointers mean "do not change".
type AccountBulkUpdate struct { type AccountBulkUpdate struct {
Name *string Name *string
ProxyID *int64 ProxyID *int64
Concurrency *int Concurrency *int
Priority *int Priority *int
Status *string RateMultiplier *float64
Schedulable *bool Status *string
Credentials map[string]any Schedulable *bool
Extra map[string]any Credentials map[string]any
Extra map[string]any
} }
// CreateAccountRequest 创建账号请求 // CreateAccountRequest 创建账号请求
......
...@@ -143,6 +143,10 @@ func (s *accountRepoStub) SetAntigravityQuotaScopeLimit(ctx context.Context, id ...@@ -143,6 +143,10 @@ func (s *accountRepoStub) SetAntigravityQuotaScopeLimit(ctx context.Context, id
panic("unexpected SetAntigravityQuotaScopeLimit call") panic("unexpected SetAntigravityQuotaScopeLimit call")
} }
func (s *accountRepoStub) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
panic("unexpected SetModelRateLimit call")
}
func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until time.Time) error { func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
panic("unexpected SetOverloaded call") panic("unexpected SetOverloaded call")
} }
...@@ -163,6 +167,10 @@ func (s *accountRepoStub) ClearAntigravityQuotaScopes(ctx context.Context, id in ...@@ -163,6 +167,10 @@ func (s *accountRepoStub) ClearAntigravityQuotaScopes(ctx context.Context, id in
panic("unexpected ClearAntigravityQuotaScopes call") panic("unexpected ClearAntigravityQuotaScopes call")
} }
func (s *accountRepoStub) ClearModelRateLimits(ctx context.Context, id int64) error {
panic("unexpected ClearModelRateLimits call")
}
func (s *accountRepoStub) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { func (s *accountRepoStub) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
panic("unexpected UpdateSessionWindow call") panic("unexpected UpdateSessionWindow call")
} }
......
...@@ -32,8 +32,8 @@ type UsageLogRepository interface { ...@@ -32,8 +32,8 @@ type UsageLogRepository interface {
// Admin dashboard stats // Admin dashboard stats
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
...@@ -96,10 +96,16 @@ func NewUsageCache() *UsageCache { ...@@ -96,10 +96,16 @@ func NewUsageCache() *UsageCache {
} }
// WindowStats 窗口期统计 // WindowStats 窗口期统计
//
// cost: 账号口径费用(total_cost * account_rate_multiplier)
// standard_cost: 标准费用(total_cost,不含倍率)
// user_cost: 用户/API Key 口径费用(actual_cost,受分组倍率影响)
type WindowStats struct { type WindowStats struct {
Requests int64 `json:"requests"` Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"` Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"` Cost float64 `json:"cost"`
StandardCost float64 `json:"standard_cost"`
UserCost float64 `json:"user_cost"`
} }
// UsageProgress 使用量进度 // UsageProgress 使用量进度
...@@ -266,7 +272,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou ...@@ -266,7 +272,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
} }
dayStart := geminiDailyWindowStart(now) dayStart := geminiDailyWindowStart(now)
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID) stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("get gemini usage stats failed: %w", err) return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
} }
...@@ -288,7 +294,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou ...@@ -288,7 +294,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m) // Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
minuteStart := now.Truncate(time.Minute) minuteStart := now.Truncate(time.Minute)
minuteResetAt := minuteStart.Add(time.Minute) minuteResetAt := minuteStart.Add(time.Minute)
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID) minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err) return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
} }
...@@ -377,9 +383,11 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou ...@@ -377,9 +383,11 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
} }
windowStats = &WindowStats{ windowStats = &WindowStats{
Requests: stats.Requests, Requests: stats.Requests,
Tokens: stats.Tokens, Tokens: stats.Tokens,
Cost: stats.Cost, Cost: stats.Cost,
StandardCost: stats.StandardCost,
UserCost: stats.UserCost,
} }
// 缓存窗口统计(1 分钟) // 缓存窗口统计(1 分钟)
...@@ -403,9 +411,11 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64 ...@@ -403,9 +411,11 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64
} }
return &WindowStats{ return &WindowStats{
Requests: stats.Requests, Requests: stats.Requests,
Tokens: stats.Tokens, Tokens: stats.Tokens,
Cost: stats.Cost, Cost: stats.Cost,
StandardCost: stats.StandardCost,
UserCost: stats.UserCost,
}, nil }, nil
} }
...@@ -565,3 +575,9 @@ func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64 ...@@ -565,3 +575,9 @@ func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64
}, },
} }
} }
// GetAccountWindowStats 获取账号在指定时间窗口内的使用统计
// 用于账号列表页面显示当前窗口费用
func (s *AccountUsageService) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
return s.usageLogRepo.GetAccountWindowStats(ctx, accountID, startTime)
}
...@@ -54,7 +54,8 @@ type AdminService interface { ...@@ -54,7 +54,8 @@ type AdminService interface {
CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error) CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error)
UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error) UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error)
DeleteProxy(ctx context.Context, id int64) error DeleteProxy(ctx context.Context, id int64) error
GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error) BatchDeleteProxies(ctx context.Context, ids []int64) (*ProxyBatchDeleteResult, error)
GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error)
CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error)
TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error) TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error)
...@@ -105,6 +106,9 @@ type CreateGroupInput struct { ...@@ -105,6 +106,9 @@ type CreateGroupInput struct {
ImagePrice4K *float64 ImagePrice4K *float64
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID FallbackGroupID *int64 // 降级分组 ID
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64
ModelRoutingEnabled bool // 是否启用模型路由
} }
type UpdateGroupInput struct { type UpdateGroupInput struct {
...@@ -124,6 +128,9 @@ type UpdateGroupInput struct { ...@@ -124,6 +128,9 @@ type UpdateGroupInput struct {
ImagePrice4K *float64 ImagePrice4K *float64
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID FallbackGroupID *int64 // 降级分组 ID
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting map[string][]int64
ModelRoutingEnabled *bool // 是否启用模型路由
} }
type CreateAccountInput struct { type CreateAccountInput struct {
...@@ -136,6 +143,7 @@ type CreateAccountInput struct { ...@@ -136,6 +143,7 @@ type CreateAccountInput struct {
ProxyID *int64 ProxyID *int64
Concurrency int Concurrency int
Priority int Priority int
RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
GroupIDs []int64 GroupIDs []int64
ExpiresAt *int64 ExpiresAt *int64
AutoPauseOnExpired *bool AutoPauseOnExpired *bool
...@@ -151,8 +159,9 @@ type UpdateAccountInput struct { ...@@ -151,8 +159,9 @@ type UpdateAccountInput struct {
Credentials map[string]any Credentials map[string]any
Extra map[string]any Extra map[string]any
ProxyID *int64 ProxyID *int64
Concurrency *int // 使用指针区分"未提供"和"设置为0" Concurrency *int // 使用指针区分"未提供"和"设置为0"
Priority *int // 使用指针区分"未提供"和"设置为0" Priority *int // 使用指针区分"未提供"和"设置为0"
RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
Status string Status string
GroupIDs *[]int64 GroupIDs *[]int64
ExpiresAt *int64 ExpiresAt *int64
...@@ -162,16 +171,17 @@ type UpdateAccountInput struct { ...@@ -162,16 +171,17 @@ type UpdateAccountInput struct {
// BulkUpdateAccountsInput describes the payload for bulk updating accounts. // BulkUpdateAccountsInput describes the payload for bulk updating accounts.
type BulkUpdateAccountsInput struct { type BulkUpdateAccountsInput struct {
AccountIDs []int64 AccountIDs []int64
Name string Name string
ProxyID *int64 ProxyID *int64
Concurrency *int Concurrency *int
Priority *int Priority *int
Status string RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
Schedulable *bool Status string
GroupIDs *[]int64 Schedulable *bool
Credentials map[string]any GroupIDs *[]int64
Extra map[string]any Credentials map[string]any
Extra map[string]any
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups. // SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
// This should only be set when the caller has explicitly confirmed the risk. // This should only be set when the caller has explicitly confirmed the risk.
SkipMixedChannelCheck bool SkipMixedChannelCheck bool
...@@ -220,23 +230,35 @@ type GenerateRedeemCodesInput struct { ...@@ -220,23 +230,35 @@ type GenerateRedeemCodesInput struct {
ValidityDays int // 订阅类型专用:有效天数 ValidityDays int // 订阅类型专用:有效天数
} }
type ProxyBatchDeleteResult struct {
DeletedIDs []int64 `json:"deleted_ids"`
Skipped []ProxyBatchDeleteSkipped `json:"skipped"`
}
type ProxyBatchDeleteSkipped struct {
ID int64 `json:"id"`
Reason string `json:"reason"`
}
// ProxyTestResult represents the result of testing a proxy // ProxyTestResult represents the result of testing a proxy
type ProxyTestResult struct { type ProxyTestResult struct {
Success bool `json:"success"` Success bool `json:"success"`
Message string `json:"message"` Message string `json:"message"`
LatencyMs int64 `json:"latency_ms,omitempty"` LatencyMs int64 `json:"latency_ms,omitempty"`
IPAddress string `json:"ip_address,omitempty"` IPAddress string `json:"ip_address,omitempty"`
City string `json:"city,omitempty"` City string `json:"city,omitempty"`
Region string `json:"region,omitempty"` Region string `json:"region,omitempty"`
Country string `json:"country,omitempty"` Country string `json:"country,omitempty"`
CountryCode string `json:"country_code,omitempty"`
} }
// ProxyExitInfo represents proxy exit information from ipinfo.io // ProxyExitInfo represents proxy exit information from ip-api.com
type ProxyExitInfo struct { type ProxyExitInfo struct {
IP string IP string
City string City string
Region string Region string
Country string Country string
CountryCode string
} }
// ProxyExitInfoProber tests proxy connectivity and retrieves exit information // ProxyExitInfoProber tests proxy connectivity and retrieves exit information
...@@ -254,6 +276,7 @@ type adminServiceImpl struct { ...@@ -254,6 +276,7 @@ type adminServiceImpl struct {
redeemCodeRepo RedeemCodeRepository redeemCodeRepo RedeemCodeRepository
billingCacheService *BillingCacheService billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber proxyProber ProxyExitInfoProber
proxyLatencyCache ProxyLatencyCache
authCacheInvalidator APIKeyAuthCacheInvalidator authCacheInvalidator APIKeyAuthCacheInvalidator
} }
...@@ -267,6 +290,7 @@ func NewAdminService( ...@@ -267,6 +290,7 @@ func NewAdminService(
redeemCodeRepo RedeemCodeRepository, redeemCodeRepo RedeemCodeRepository,
billingCacheService *BillingCacheService, billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber, proxyProber ProxyExitInfoProber,
proxyLatencyCache ProxyLatencyCache,
authCacheInvalidator APIKeyAuthCacheInvalidator, authCacheInvalidator APIKeyAuthCacheInvalidator,
) AdminService { ) AdminService {
return &adminServiceImpl{ return &adminServiceImpl{
...@@ -278,6 +302,7 @@ func NewAdminService( ...@@ -278,6 +302,7 @@ func NewAdminService(
redeemCodeRepo: redeemCodeRepo, redeemCodeRepo: redeemCodeRepo,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
proxyProber: proxyProber, proxyProber: proxyProber,
proxyLatencyCache: proxyLatencyCache,
authCacheInvalidator: authCacheInvalidator, authCacheInvalidator: authCacheInvalidator,
} }
} }
...@@ -562,6 +587,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ...@@ -562,6 +587,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
ImagePrice4K: imagePrice4K, ImagePrice4K: imagePrice4K,
ClaudeCodeOnly: input.ClaudeCodeOnly, ClaudeCodeOnly: input.ClaudeCodeOnly,
FallbackGroupID: input.FallbackGroupID, FallbackGroupID: input.FallbackGroupID,
ModelRouting: input.ModelRouting,
} }
if err := s.groupRepo.Create(ctx, group); err != nil { if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err return nil, err
...@@ -690,6 +716,14 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd ...@@ -690,6 +716,14 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
} }
} }
// 模型路由配置
if input.ModelRouting != nil {
group.ModelRouting = input.ModelRouting
}
if input.ModelRoutingEnabled != nil {
group.ModelRoutingEnabled = *input.ModelRoutingEnabled
}
if err := s.groupRepo.Update(ctx, group); err != nil { if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err return nil, err
} }
...@@ -817,6 +851,12 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou ...@@ -817,6 +851,12 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
} else { } else {
account.AutoPauseOnExpired = true account.AutoPauseOnExpired = true
} }
if input.RateMultiplier != nil {
if *input.RateMultiplier < 0 {
return nil, errors.New("rate_multiplier must be >= 0")
}
account.RateMultiplier = input.RateMultiplier
}
if err := s.accountRepo.Create(ctx, account); err != nil { if err := s.accountRepo.Create(ctx, account); err != nil {
return nil, err return nil, err
} }
...@@ -869,6 +909,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U ...@@ -869,6 +909,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
if input.Priority != nil { if input.Priority != nil {
account.Priority = *input.Priority account.Priority = *input.Priority
} }
if input.RateMultiplier != nil {
if *input.RateMultiplier < 0 {
return nil, errors.New("rate_multiplier must be >= 0")
}
account.RateMultiplier = input.RateMultiplier
}
if input.Status != "" { if input.Status != "" {
account.Status = input.Status account.Status = input.Status
} }
...@@ -942,6 +988,12 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp ...@@ -942,6 +988,12 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
} }
} }
if input.RateMultiplier != nil {
if *input.RateMultiplier < 0 {
return nil, errors.New("rate_multiplier must be >= 0")
}
}
// Prepare bulk updates for columns and JSONB fields. // Prepare bulk updates for columns and JSONB fields.
repoUpdates := AccountBulkUpdate{ repoUpdates := AccountBulkUpdate{
Credentials: input.Credentials, Credentials: input.Credentials,
...@@ -959,6 +1011,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp ...@@ -959,6 +1011,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
if input.Priority != nil { if input.Priority != nil {
repoUpdates.Priority = input.Priority repoUpdates.Priority = input.Priority
} }
if input.RateMultiplier != nil {
repoUpdates.RateMultiplier = input.RateMultiplier
}
if input.Status != "" { if input.Status != "" {
repoUpdates.Status = &input.Status repoUpdates.Status = &input.Status
} }
...@@ -1069,6 +1124,7 @@ func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page ...@@ -1069,6 +1124,7 @@ func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
s.attachProxyLatency(ctx, proxies)
return proxies, result.Total, nil return proxies, result.Total, nil
} }
...@@ -1077,7 +1133,12 @@ func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) { ...@@ -1077,7 +1133,12 @@ func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) {
} }
func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) { func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) {
return s.proxyRepo.ListActiveWithAccountCount(ctx) proxies, err := s.proxyRepo.ListActiveWithAccountCount(ctx)
if err != nil {
return nil, err
}
s.attachProxyLatency(ctx, proxies)
return proxies, nil
} }
func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*Proxy, error) { func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*Proxy, error) {
...@@ -1097,6 +1158,8 @@ func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyIn ...@@ -1097,6 +1158,8 @@ func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyIn
if err := s.proxyRepo.Create(ctx, proxy); err != nil { if err := s.proxyRepo.Create(ctx, proxy); err != nil {
return nil, err return nil, err
} }
// Probe latency asynchronously so creation isn't blocked by network timeout.
go s.probeProxyLatency(context.Background(), proxy)
return proxy, nil return proxy, nil
} }
...@@ -1135,12 +1198,53 @@ func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *Upd ...@@ -1135,12 +1198,53 @@ func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *Upd
} }
func (s *adminServiceImpl) DeleteProxy(ctx context.Context, id int64) error { func (s *adminServiceImpl) DeleteProxy(ctx context.Context, id int64) error {
count, err := s.proxyRepo.CountAccountsByProxyID(ctx, id)
if err != nil {
return err
}
if count > 0 {
return ErrProxyInUse
}
return s.proxyRepo.Delete(ctx, id) return s.proxyRepo.Delete(ctx, id)
} }
func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error) { func (s *adminServiceImpl) BatchDeleteProxies(ctx context.Context, ids []int64) (*ProxyBatchDeleteResult, error) {
// Return mock data for now - would need a dedicated repository method result := &ProxyBatchDeleteResult{}
return []Account{}, 0, nil if len(ids) == 0 {
return result, nil
}
for _, id := range ids {
count, err := s.proxyRepo.CountAccountsByProxyID(ctx, id)
if err != nil {
result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{
ID: id,
Reason: err.Error(),
})
continue
}
if count > 0 {
result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{
ID: id,
Reason: ErrProxyInUse.Error(),
})
continue
}
if err := s.proxyRepo.Delete(ctx, id); err != nil {
result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{
ID: id,
Reason: err.Error(),
})
continue
}
result.DeletedIDs = append(result.DeletedIDs, id)
}
return result, nil
}
func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) {
return s.proxyRepo.ListAccountSummariesByProxyID(ctx, proxyID)
} }
func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) { func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) {
...@@ -1240,23 +1344,69 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR ...@@ -1240,23 +1344,69 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
proxyURL := proxy.URL() proxyURL := proxy.URL()
exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL) exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL)
if err != nil { if err != nil {
s.saveProxyLatency(ctx, id, &ProxyLatencyInfo{
Success: false,
Message: err.Error(),
UpdatedAt: time.Now(),
})
return &ProxyTestResult{ return &ProxyTestResult{
Success: false, Success: false,
Message: err.Error(), Message: err.Error(),
}, nil }, nil
} }
latency := latencyMs
s.saveProxyLatency(ctx, id, &ProxyLatencyInfo{
Success: true,
LatencyMs: &latency,
Message: "Proxy is accessible",
IPAddress: exitInfo.IP,
Country: exitInfo.Country,
CountryCode: exitInfo.CountryCode,
Region: exitInfo.Region,
City: exitInfo.City,
UpdatedAt: time.Now(),
})
return &ProxyTestResult{ return &ProxyTestResult{
Success: true, Success: true,
Message: "Proxy is accessible", Message: "Proxy is accessible",
LatencyMs: latencyMs, LatencyMs: latencyMs,
IPAddress: exitInfo.IP, IPAddress: exitInfo.IP,
City: exitInfo.City, City: exitInfo.City,
Region: exitInfo.Region, Region: exitInfo.Region,
Country: exitInfo.Country, Country: exitInfo.Country,
CountryCode: exitInfo.CountryCode,
}, nil }, nil
} }
func (s *adminServiceImpl) probeProxyLatency(ctx context.Context, proxy *Proxy) {
if s.proxyProber == nil || proxy == nil {
return
}
exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxy.URL())
if err != nil {
s.saveProxyLatency(ctx, proxy.ID, &ProxyLatencyInfo{
Success: false,
Message: err.Error(),
UpdatedAt: time.Now(),
})
return
}
latency := latencyMs
s.saveProxyLatency(ctx, proxy.ID, &ProxyLatencyInfo{
Success: true,
LatencyMs: &latency,
Message: "Proxy is accessible",
IPAddress: exitInfo.IP,
Country: exitInfo.Country,
CountryCode: exitInfo.CountryCode,
Region: exitInfo.Region,
City: exitInfo.City,
UpdatedAt: time.Now(),
})
}
// checkMixedChannelRisk 检查分组中是否存在混合渠道(Antigravity + Anthropic) // checkMixedChannelRisk 检查分组中是否存在混合渠道(Antigravity + Anthropic)
// 如果存在混合,返回错误提示用户确认 // 如果存在混合,返回错误提示用户确认
func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error { func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
...@@ -1306,6 +1456,51 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc ...@@ -1306,6 +1456,51 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc
return nil return nil
} }
func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []ProxyWithAccountCount) {
if s.proxyLatencyCache == nil || len(proxies) == 0 {
return
}
ids := make([]int64, 0, len(proxies))
for i := range proxies {
ids = append(ids, proxies[i].ID)
}
latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, ids)
if err != nil {
log.Printf("Warning: load proxy latency cache failed: %v", err)
return
}
for i := range proxies {
info := latencies[proxies[i].ID]
if info == nil {
continue
}
if info.Success {
proxies[i].LatencyStatus = "success"
proxies[i].LatencyMs = info.LatencyMs
} else {
proxies[i].LatencyStatus = "failed"
}
proxies[i].LatencyMessage = info.Message
proxies[i].IPAddress = info.IPAddress
proxies[i].Country = info.Country
proxies[i].CountryCode = info.CountryCode
proxies[i].Region = info.Region
proxies[i].City = info.City
}
}
func (s *adminServiceImpl) saveProxyLatency(ctx context.Context, proxyID int64, info *ProxyLatencyInfo) {
if s.proxyLatencyCache == nil || info == nil {
return
}
if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, info); err != nil {
log.Printf("Warning: store proxy latency cache failed: %v", err)
}
}
// getAccountPlatform 根据账号 platform 判断混合渠道检查用的平台标识 // getAccountPlatform 根据账号 platform 判断混合渠道检查用的平台标识
func getAccountPlatform(accountPlatform string) string { func getAccountPlatform(accountPlatform string) string {
switch strings.ToLower(strings.TrimSpace(accountPlatform)) { switch strings.ToLower(strings.TrimSpace(accountPlatform)) {
......
...@@ -12,9 +12,9 @@ import ( ...@@ -12,9 +12,9 @@ import (
type accountRepoStubForBulkUpdate struct { type accountRepoStubForBulkUpdate struct {
accountRepoStub accountRepoStub
bulkUpdateErr error bulkUpdateErr error
bulkUpdateIDs []int64 bulkUpdateIDs []int64
bindGroupErrByID map[int64]error bindGroupErrByID map[int64]error
} }
func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) { func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
......
...@@ -153,8 +153,10 @@ func (s *groupRepoStub) DeleteAccountGroupsByGroupID(ctx context.Context, groupI ...@@ -153,8 +153,10 @@ func (s *groupRepoStub) DeleteAccountGroupsByGroupID(ctx context.Context, groupI
} }
type proxyRepoStub struct { type proxyRepoStub struct {
deleteErr error deleteErr error
deletedIDs []int64 countErr error
accountCount int64
deletedIDs []int64
} }
func (s *proxyRepoStub) Create(ctx context.Context, proxy *Proxy) error { func (s *proxyRepoStub) Create(ctx context.Context, proxy *Proxy) error {
...@@ -199,7 +201,14 @@ func (s *proxyRepoStub) ExistsByHostPortAuth(ctx context.Context, host string, p ...@@ -199,7 +201,14 @@ func (s *proxyRepoStub) ExistsByHostPortAuth(ctx context.Context, host string, p
} }
func (s *proxyRepoStub) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) { func (s *proxyRepoStub) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
panic("unexpected CountAccountsByProxyID call") if s.countErr != nil {
return 0, s.countErr
}
return s.accountCount, nil
}
func (s *proxyRepoStub) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) {
panic("unexpected ListAccountSummariesByProxyID call")
} }
type redeemRepoStub struct { type redeemRepoStub struct {
...@@ -409,6 +418,15 @@ func TestAdminService_DeleteProxy_Idempotent(t *testing.T) { ...@@ -409,6 +418,15 @@ func TestAdminService_DeleteProxy_Idempotent(t *testing.T) {
require.Equal(t, []int64{404}, repo.deletedIDs) require.Equal(t, []int64{404}, repo.deletedIDs)
} }
func TestAdminService_DeleteProxy_InUse(t *testing.T) {
repo := &proxyRepoStub{accountCount: 2}
svc := &adminServiceImpl{proxyRepo: repo}
err := svc.DeleteProxy(context.Background(), 77)
require.ErrorIs(t, err, ErrProxyInUse)
require.Empty(t, repo.deletedIDs)
}
func TestAdminService_DeleteProxy_Error(t *testing.T) { func TestAdminService_DeleteProxy_Error(t *testing.T) {
deleteErr := errors.New("delete failed") deleteErr := errors.New("delete failed")
repo := &proxyRepoStub{deleteErr: deleteErr} repo := &proxyRepoStub{deleteErr: deleteErr}
......
...@@ -564,6 +564,10 @@ urlFallbackLoop: ...@@ -564,6 +564,10 @@ urlFallbackLoop:
} }
upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, action, accessToken, geminiBody) upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, action, accessToken, geminiBody)
// Capture upstream request body for ops retry of this attempt.
if c != nil {
c.Set(OpsUpstreamRequestBodyKey, string(geminiBody))
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -574,6 +578,7 @@ urlFallbackLoop: ...@@ -574,6 +578,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0, UpstreamStatusCode: 0,
Kind: "request_error", Kind: "request_error",
Message: safeErr, Message: safeErr,
...@@ -615,6 +620,7 @@ urlFallbackLoop: ...@@ -615,6 +620,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode, UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"), UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry", Kind: "retry",
...@@ -645,6 +651,7 @@ urlFallbackLoop: ...@@ -645,6 +651,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode, UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"), UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry", Kind: "retry",
...@@ -697,6 +704,7 @@ urlFallbackLoop: ...@@ -697,6 +704,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode, UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"), UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "signature_error", Kind: "signature_error",
...@@ -740,6 +748,7 @@ urlFallbackLoop: ...@@ -740,6 +748,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0, UpstreamStatusCode: 0,
Kind: "signature_retry_request_error", Kind: "signature_retry_request_error",
Message: sanitizeUpstreamErrorMessage(retryErr.Error()), Message: sanitizeUpstreamErrorMessage(retryErr.Error()),
...@@ -770,6 +779,7 @@ urlFallbackLoop: ...@@ -770,6 +779,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: retryResp.StatusCode, UpstreamStatusCode: retryResp.StatusCode,
UpstreamRequestID: retryResp.Header.Get("x-request-id"), UpstreamRequestID: retryResp.Header.Get("x-request-id"),
Kind: kind, Kind: kind,
...@@ -817,6 +827,7 @@ urlFallbackLoop: ...@@ -817,6 +827,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode, UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"), UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover", Kind: "failover",
...@@ -1371,6 +1382,7 @@ urlFallbackLoop: ...@@ -1371,6 +1382,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0, UpstreamStatusCode: 0,
Kind: "request_error", Kind: "request_error",
Message: safeErr, Message: safeErr,
...@@ -1412,6 +1424,7 @@ urlFallbackLoop: ...@@ -1412,6 +1424,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode, UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"), UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry", Kind: "retry",
...@@ -1442,6 +1455,7 @@ urlFallbackLoop: ...@@ -1442,6 +1455,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode, UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"), UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry", Kind: "retry",
...@@ -1543,6 +1557,7 @@ urlFallbackLoop: ...@@ -1543,6 +1557,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode, UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID, UpstreamRequestID: requestID,
Kind: "failover", Kind: "failover",
...@@ -1559,6 +1574,7 @@ urlFallbackLoop: ...@@ -1559,6 +1574,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode, UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID, UpstreamRequestID: requestID,
Kind: "http_error", Kind: "http_error",
...@@ -2039,6 +2055,7 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou ...@@ -2039,6 +2055,7 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: upstreamStatus, UpstreamStatusCode: upstreamStatus,
UpstreamRequestID: upstreamRequestID, UpstreamRequestID: upstreamRequestID,
Kind: "http_error", Kind: "http_error",
......
...@@ -49,6 +49,9 @@ func (a *Account) IsSchedulableForModel(requestedModel string) bool { ...@@ -49,6 +49,9 @@ func (a *Account) IsSchedulableForModel(requestedModel string) bool {
if !a.IsSchedulable() { if !a.IsSchedulable() {
return false return false
} }
if a.isModelRateLimited(requestedModel) {
return false
}
if a.Platform != PlatformAntigravity { if a.Platform != PlatformAntigravity {
return true return true
} }
......
...@@ -45,7 +45,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * ...@@ -45,7 +45,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return "", errors.New("not an antigravity oauth account") return "", errors.New("not an antigravity oauth account")
} }
cacheKey := antigravityTokenCacheKey(account) cacheKey := AntigravityTokenCacheKey(account)
// 1. 先尝试缓存 // 1. 先尝试缓存
if p.tokenCache != nil { if p.tokenCache != nil {
...@@ -121,7 +121,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * ...@@ -121,7 +121,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return accessToken, nil return accessToken, nil
} }
func antigravityTokenCacheKey(account *Account) string { func AntigravityTokenCacheKey(account *Account) string {
projectID := strings.TrimSpace(account.GetCredential("project_id")) projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" { if projectID != "" {
return "ag:" + projectID return "ag:" + projectID
......
...@@ -37,6 +37,11 @@ type APIKeyAuthGroupSnapshot struct { ...@@ -37,6 +37,11 @@ type APIKeyAuthGroupSnapshot struct {
ImagePrice4K *float64 `json:"image_price_4k,omitempty"` ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
ClaudeCodeOnly bool `json:"claude_code_only"` ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
// Only anthropic groups use these fields; others may leave them empty.
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
ModelRoutingEnabled bool `json:"model_routing_enabled"`
} }
// APIKeyAuthCacheEntry 缓存条目,支持负缓存 // APIKeyAuthCacheEntry 缓存条目,支持负缓存
......
...@@ -207,20 +207,22 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { ...@@ -207,20 +207,22 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
} }
if apiKey.Group != nil { if apiKey.Group != nil {
snapshot.Group = &APIKeyAuthGroupSnapshot{ snapshot.Group = &APIKeyAuthGroupSnapshot{
ID: apiKey.Group.ID, ID: apiKey.Group.ID,
Name: apiKey.Group.Name, Name: apiKey.Group.Name,
Platform: apiKey.Group.Platform, Platform: apiKey.Group.Platform,
Status: apiKey.Group.Status, Status: apiKey.Group.Status,
SubscriptionType: apiKey.Group.SubscriptionType, SubscriptionType: apiKey.Group.SubscriptionType,
RateMultiplier: apiKey.Group.RateMultiplier, RateMultiplier: apiKey.Group.RateMultiplier,
DailyLimitUSD: apiKey.Group.DailyLimitUSD, DailyLimitUSD: apiKey.Group.DailyLimitUSD,
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD, WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD, MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
ImagePrice1K: apiKey.Group.ImagePrice1K, ImagePrice1K: apiKey.Group.ImagePrice1K,
ImagePrice2K: apiKey.Group.ImagePrice2K, ImagePrice2K: apiKey.Group.ImagePrice2K,
ImagePrice4K: apiKey.Group.ImagePrice4K, ImagePrice4K: apiKey.Group.ImagePrice4K,
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
FallbackGroupID: apiKey.Group.FallbackGroupID, FallbackGroupID: apiKey.Group.FallbackGroupID,
ModelRouting: apiKey.Group.ModelRouting,
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
} }
} }
return snapshot return snapshot
...@@ -248,21 +250,23 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho ...@@ -248,21 +250,23 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
} }
if snapshot.Group != nil { if snapshot.Group != nil {
apiKey.Group = &Group{ apiKey.Group = &Group{
ID: snapshot.Group.ID, ID: snapshot.Group.ID,
Name: snapshot.Group.Name, Name: snapshot.Group.Name,
Platform: snapshot.Group.Platform, Platform: snapshot.Group.Platform,
Status: snapshot.Group.Status, Status: snapshot.Group.Status,
Hydrated: true, Hydrated: true,
SubscriptionType: snapshot.Group.SubscriptionType, SubscriptionType: snapshot.Group.SubscriptionType,
RateMultiplier: snapshot.Group.RateMultiplier, RateMultiplier: snapshot.Group.RateMultiplier,
DailyLimitUSD: snapshot.Group.DailyLimitUSD, DailyLimitUSD: snapshot.Group.DailyLimitUSD,
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD, WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD, MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
ImagePrice1K: snapshot.Group.ImagePrice1K, ImagePrice1K: snapshot.Group.ImagePrice1K,
ImagePrice2K: snapshot.Group.ImagePrice2K, ImagePrice2K: snapshot.Group.ImagePrice2K,
ImagePrice4K: snapshot.Group.ImagePrice4K, ImagePrice4K: snapshot.Group.ImagePrice4K,
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
FallbackGroupID: snapshot.Group.FallbackGroupID, FallbackGroupID: snapshot.Group.FallbackGroupID,
ModelRouting: snapshot.Group.ModelRouting,
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
} }
} }
return apiKey return apiKey
......
...@@ -172,12 +172,16 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) { ...@@ -172,12 +172,16 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
Concurrency: 3, Concurrency: 3,
}, },
Group: &APIKeyAuthGroupSnapshot{ Group: &APIKeyAuthGroupSnapshot{
ID: groupID, ID: groupID,
Name: "g", Name: "g",
Platform: PlatformAnthropic, Platform: PlatformAnthropic,
Status: StatusActive, Status: StatusActive,
SubscriptionType: SubscriptionTypeStandard, SubscriptionType: SubscriptionTypeStandard,
RateMultiplier: 1, RateMultiplier: 1,
ModelRoutingEnabled: true,
ModelRouting: map[string][]int64{
"claude-opus-*": {1, 2},
},
}, },
}, },
} }
...@@ -190,6 +194,8 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) { ...@@ -190,6 +194,8 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
require.Equal(t, int64(1), apiKey.ID) require.Equal(t, int64(1), apiKey.ID)
require.Equal(t, int64(2), apiKey.User.ID) require.Equal(t, int64(2), apiKey.User.ID)
require.Equal(t, groupID, apiKey.Group.ID) require.Equal(t, groupID, apiKey.Group.ID)
require.True(t, apiKey.Group.ModelRoutingEnabled)
require.Equal(t, map[string][]int64{"claude-opus-*": {1, 2}}, apiKey.Group.ModelRouting)
} }
func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) { func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
......
package service
import (
"context"
"errors"
"log/slog"
"strconv"
"strings"
"time"
)
const (
claudeTokenRefreshSkew = 3 * time.Minute
claudeTokenCacheSkew = 5 * time.Minute
claudeLockWaitTime = 200 * time.Millisecond
)
// ClaudeTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
type ClaudeTokenCache = GeminiTokenCache
// ClaudeTokenProvider 管理 Claude (Anthropic) OAuth 账户的 access_token
type ClaudeTokenProvider struct {
accountRepo AccountRepository
tokenCache ClaudeTokenCache
oauthService *OAuthService
}
func NewClaudeTokenProvider(
accountRepo AccountRepository,
tokenCache ClaudeTokenCache,
oauthService *OAuthService,
) *ClaudeTokenProvider {
return &ClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: tokenCache,
oauthService: oauthService,
}
}
// GetAccessToken 获取有效的 access_token
func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
return "", errors.New("not an anthropic oauth account")
}
cacheKey := ClaudeTokenCacheKey(account)
// 1. 先尝试缓存
if p.tokenCache != nil {
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
slog.Debug("claude_token_cache_hit", "account_id", account.ID)
return token, nil
} else if err != nil {
slog.Warn("claude_token_cache_get_failed", "account_id", account.ID, "error", err)
}
}
slog.Debug("claude_token_cache_miss", "account_id", account.ID)
// 2. 如果即将过期则刷新
expiresAt := account.GetCredentialAsTime("expires_at")
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
refreshFailed := false
if needsRefresh && p.tokenCache != nil {
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if lockErr == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
// 从数据库获取最新账户信息
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
if err == nil && fresh != nil {
account = fresh
}
expiresAt = account.GetCredentialAsTime("expires_at")
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
if p.oauthService == nil {
slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID)
refreshFailed = true // 无法刷新,标记失败
} else {
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
if err != nil {
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
slog.Warn("claude_token_refresh_failed", "account_id", account.ID, "error", err)
refreshFailed = true // 刷新失败,标记以使用短 TTL
} else {
// 构建新 credentials,保留原有字段
newCredentials := make(map[string]any)
for k, v := range account.Credentials {
newCredentials[k] = v
}
newCredentials["access_token"] = tokenInfo.AccessToken
newCredentials["token_type"] = tokenInfo.TokenType
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
if tokenInfo.RefreshToken != "" {
newCredentials["refresh_token"] = tokenInfo.RefreshToken
}
if tokenInfo.Scope != "" {
newCredentials["scope"] = tokenInfo.Scope
}
account.Credentials = newCredentials
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
}
expiresAt = account.GetCredentialAsTime("expires_at")
}
}
}
} else if lockErr != nil {
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
slog.Warn("claude_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
// 检查 ctx 是否已取消
if ctx.Err() != nil {
return "", ctx.Err()
}
// 从数据库获取最新账户信息
if p.accountRepo != nil {
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
if err == nil && fresh != nil {
account = fresh
}
}
expiresAt = account.GetCredentialAsTime("expires_at")
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
if p.oauthService == nil {
slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID)
refreshFailed = true
} else {
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
if err != nil {
slog.Warn("claude_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
refreshFailed = true
} else {
// 构建新 credentials,保留原有字段
newCredentials := make(map[string]any)
for k, v := range account.Credentials {
newCredentials[k] = v
}
newCredentials["access_token"] = tokenInfo.AccessToken
newCredentials["token_type"] = tokenInfo.TokenType
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
if tokenInfo.RefreshToken != "" {
newCredentials["refresh_token"] = tokenInfo.RefreshToken
}
if tokenInfo.Scope != "" {
newCredentials["scope"] = tokenInfo.Scope
}
account.Credentials = newCredentials
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
}
expiresAt = account.GetCredentialAsTime("expires_at")
}
}
}
} else {
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
time.Sleep(claudeLockWaitTime)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
return token, nil
}
}
}
accessToken := account.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found in credentials")
}
// 3. 存入缓存
if p.tokenCache != nil {
ttl := 30 * time.Minute
if refreshFailed {
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
ttl = time.Minute
slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
} else if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > claudeTokenCacheSkew:
ttl = until - claudeTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
}
}
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
slog.Warn("claude_token_cache_set_failed", "account_id", account.ID, "error", err)
}
}
return accessToken, nil
}
//go:build unit
package service
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// claudeTokenCacheStub implements ClaudeTokenCache for testing
type claudeTokenCacheStub struct {
mu sync.Mutex
tokens map[string]string
getErr error
setErr error
deleteErr error
lockAcquired bool
lockErr error
releaseLockErr error
getCalled int32
setCalled int32
lockCalled int32
unlockCalled int32
simulateLockRace bool
}
func newClaudeTokenCacheStub() *claudeTokenCacheStub {
return &claudeTokenCacheStub{
tokens: make(map[string]string),
lockAcquired: true,
}
}
func (s *claudeTokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
atomic.AddInt32(&s.getCalled, 1)
if s.getErr != nil {
return "", s.getErr
}
s.mu.Lock()
defer s.mu.Unlock()
return s.tokens[cacheKey], nil
}
func (s *claudeTokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
atomic.AddInt32(&s.setCalled, 1)
if s.setErr != nil {
return s.setErr
}
s.mu.Lock()
defer s.mu.Unlock()
s.tokens[cacheKey] = token
return nil
}
func (s *claudeTokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error {
if s.deleteErr != nil {
return s.deleteErr
}
s.mu.Lock()
defer s.mu.Unlock()
delete(s.tokens, cacheKey)
return nil
}
func (s *claudeTokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
atomic.AddInt32(&s.lockCalled, 1)
if s.lockErr != nil {
return false, s.lockErr
}
if s.simulateLockRace {
return false, nil
}
return s.lockAcquired, nil
}
func (s *claudeTokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
atomic.AddInt32(&s.unlockCalled, 1)
return s.releaseLockErr
}
// claudeAccountRepoStub is a minimal stub implementing only the methods used by ClaudeTokenProvider
type claudeAccountRepoStub struct {
account *Account
getErr error
updateErr error
getCalled int32
updateCalled int32
}
func (r *claudeAccountRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
atomic.AddInt32(&r.getCalled, 1)
if r.getErr != nil {
return nil, r.getErr
}
return r.account, nil
}
func (r *claudeAccountRepoStub) Update(ctx context.Context, account *Account) error {
atomic.AddInt32(&r.updateCalled, 1)
if r.updateErr != nil {
return r.updateErr
}
r.account = account
return nil
}
// claudeOAuthServiceStub implements OAuthService methods for testing
type claudeOAuthServiceStub struct {
tokenInfo *TokenInfo
refreshErr error
refreshCalled int32
}
func (s *claudeOAuthServiceStub) RefreshAccountToken(ctx context.Context, account *Account) (*TokenInfo, error) {
atomic.AddInt32(&s.refreshCalled, 1)
if s.refreshErr != nil {
return nil, s.refreshErr
}
return s.tokenInfo, nil
}
// testClaudeTokenProvider is a test version that uses the stub OAuth service
type testClaudeTokenProvider struct {
accountRepo *claudeAccountRepoStub
tokenCache *claudeTokenCacheStub
oauthService *claudeOAuthServiceStub
}
func (p *testClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
return "", errors.New("not an anthropic oauth account")
}
cacheKey := ClaudeTokenCacheKey(account)
// 1. Check cache
if p.tokenCache != nil {
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
return token, nil
}
}
// 2. Check if refresh needed
expiresAt := account.GetCredentialAsTime("expires_at")
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
refreshFailed := false
if needsRefresh && p.tokenCache != nil {
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if err == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
// Check cache again after acquiring lock
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
return token, nil
}
// Get fresh account from DB
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
if err == nil && fresh != nil {
account = fresh
}
expiresAt = account.GetCredentialAsTime("expires_at")
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
if p.oauthService == nil {
refreshFailed = true // 无法刷新,标记失败
} else {
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
if err != nil {
refreshFailed = true // 刷新失败,标记以使用短 TTL
} else {
// Build new credentials
newCredentials := make(map[string]any)
for k, v := range account.Credentials {
newCredentials[k] = v
}
newCredentials["access_token"] = tokenInfo.AccessToken
newCredentials["token_type"] = tokenInfo.TokenType
newCredentials["expires_at"] = time.Now().Add(time.Duration(tokenInfo.ExpiresIn) * time.Second).Format(time.RFC3339)
if tokenInfo.RefreshToken != "" {
newCredentials["refresh_token"] = tokenInfo.RefreshToken
}
account.Credentials = newCredentials
_ = p.accountRepo.Update(ctx, account)
expiresAt = account.GetCredentialAsTime("expires_at")
}
}
}
} else if p.tokenCache.simulateLockRace {
// Wait and retry cache
time.Sleep(10 * time.Millisecond)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
return token, nil
}
}
}
accessToken := account.GetCredential("access_token")
if accessToken == "" {
return "", errors.New("access_token not found in credentials")
}
// 3. Store in cache
if p.tokenCache != nil {
ttl := 30 * time.Minute
if refreshFailed {
ttl = time.Minute // 刷新失败时使用短 TTL
} else if expiresAt != nil {
until := time.Until(*expiresAt)
if until > claudeTokenCacheSkew {
ttl = until - claudeTokenCacheSkew
} else if until > 0 {
ttl = until
} else {
ttl = time.Minute
}
}
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
}
return accessToken, nil
}
func TestClaudeTokenProvider_CacheHit(t *testing.T) {
cache := newClaudeTokenCacheStub()
account := &Account{
ID: 100,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "db-token",
},
}
cacheKey := ClaudeTokenCacheKey(account)
cache.tokens[cacheKey] = "cached-token"
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "cached-token", token)
require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalled))
require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalled))
}
func TestClaudeTokenProvider_CacheMiss_FromCredentials(t *testing.T) {
cache := newClaudeTokenCacheStub()
// Token expires in far future, no refresh needed
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 101,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "credential-token",
"expires_at": expiresAt,
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "credential-token", token)
// Should have stored in cache
cacheKey := ClaudeTokenCacheKey(account)
require.Equal(t, "credential-token", cache.tokens[cacheKey])
}
func TestClaudeTokenProvider_TokenRefresh(t *testing.T) {
cache := newClaudeTokenCacheStub()
accountRepo := &claudeAccountRepoStub{}
oauthService := &claudeOAuthServiceStub{
tokenInfo: &TokenInfo{
AccessToken: "refreshed-token",
RefreshToken: "new-refresh-token",
TokenType: "Bearer",
ExpiresIn: 3600,
ExpiresAt: time.Now().Add(time.Hour).Unix(),
},
}
// Token expires soon (within refresh skew)
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 102,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-token",
"refresh_token": "old-refresh-token",
"expires_at": expiresAt,
},
}
accountRepo.account = account
provider := &testClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: oauthService,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "refreshed-token", token)
require.Equal(t, int32(1), atomic.LoadInt32(&oauthService.refreshCalled))
}
func TestClaudeTokenProvider_LockRaceCondition(t *testing.T) {
cache := newClaudeTokenCacheStub()
cache.simulateLockRace = true
accountRepo := &claudeAccountRepoStub{}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 103,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "race-token",
"expires_at": expiresAt,
},
}
accountRepo.account = account
// Simulate another worker already refreshed and cached
cacheKey := ClaudeTokenCacheKey(account)
go func() {
time.Sleep(5 * time.Millisecond)
cache.mu.Lock()
cache.tokens[cacheKey] = "winner-token"
cache.mu.Unlock()
}()
provider := &testClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.NotEmpty(t, token)
}
func TestClaudeTokenProvider_NilAccount(t *testing.T) {
provider := NewClaudeTokenProvider(nil, nil, nil)
token, err := provider.GetAccessToken(context.Background(), nil)
require.Error(t, err)
require.Contains(t, err.Error(), "account is nil")
require.Empty(t, token)
}
func TestClaudeTokenProvider_WrongPlatform(t *testing.T) {
provider := NewClaudeTokenProvider(nil, nil, nil)
account := &Account{
ID: 104,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an anthropic oauth account")
require.Empty(t, token)
}
func TestClaudeTokenProvider_WrongAccountType(t *testing.T) {
provider := NewClaudeTokenProvider(nil, nil, nil)
account := &Account{
ID: 105,
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an anthropic oauth account")
require.Empty(t, token)
}
func TestClaudeTokenProvider_SetupTokenType(t *testing.T) {
provider := NewClaudeTokenProvider(nil, nil, nil)
account := &Account{
ID: 106,
Platform: PlatformAnthropic,
Type: AccountTypeSetupToken,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an anthropic oauth account")
require.Empty(t, token)
}
func TestClaudeTokenProvider_NilCache(t *testing.T) {
// Token doesn't need refresh
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 107,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "nocache-token",
"expires_at": expiresAt,
},
}
provider := NewClaudeTokenProvider(nil, nil, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "nocache-token", token)
}
func TestClaudeTokenProvider_CacheGetError(t *testing.T) {
cache := newClaudeTokenCacheStub()
cache.getErr = errors.New("redis connection failed")
// Token doesn't need refresh
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 108,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "fallback-token",
"expires_at": expiresAt,
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
// Should gracefully degrade and return from credentials
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "fallback-token", token)
}
func TestClaudeTokenProvider_CacheSetError(t *testing.T) {
cache := newClaudeTokenCacheStub()
cache.setErr = errors.New("redis write failed")
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 109,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "still-works-token",
"expires_at": expiresAt,
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
// Should still work even if cache set fails
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "still-works-token", token)
}
func TestClaudeTokenProvider_MissingAccessToken(t *testing.T) {
cache := newClaudeTokenCacheStub()
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 110,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"expires_at": expiresAt,
// missing access_token
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "access_token not found")
require.Empty(t, token)
}
func TestClaudeTokenProvider_RefreshError(t *testing.T) {
cache := newClaudeTokenCacheStub()
accountRepo := &claudeAccountRepoStub{}
oauthService := &claudeOAuthServiceStub{
refreshErr: errors.New("oauth refresh failed"),
}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 111,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-token",
"refresh_token": "old-refresh-token",
"expires_at": expiresAt,
},
}
accountRepo.account = account
provider := &testClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: oauthService,
}
// Now with fallback behavior, should return existing token even if refresh fails
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "old-token", token) // Fallback to existing token
}
func TestClaudeTokenProvider_OAuthServiceNotConfigured(t *testing.T) {
cache := newClaudeTokenCacheStub()
accountRepo := &claudeAccountRepoStub{}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 112,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-token",
"expires_at": expiresAt,
},
}
accountRepo.account = account
provider := &testClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: nil, // not configured
}
// Now with fallback behavior, should return existing token even if oauth service not configured
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "old-token", token) // Fallback to existing token
}
func TestClaudeTokenProvider_TTLCalculation(t *testing.T) {
tests := []struct {
name string
expiresIn time.Duration
}{
{
name: "far_future_expiry",
expiresIn: 1 * time.Hour,
},
{
name: "medium_expiry",
expiresIn: 10 * time.Minute,
},
{
name: "near_expiry",
expiresIn: 6 * time.Minute,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cache := newClaudeTokenCacheStub()
expiresAt := time.Now().Add(tt.expiresIn).Format(time.RFC3339)
account := &Account{
ID: 200,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "test-token",
"expires_at": expiresAt,
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
_, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
// Verify token was cached
cacheKey := ClaudeTokenCacheKey(account)
require.Equal(t, "test-token", cache.tokens[cacheKey])
})
}
}
func TestClaudeTokenProvider_AccountRepoGetError(t *testing.T) {
cache := newClaudeTokenCacheStub()
accountRepo := &claudeAccountRepoStub{
getErr: errors.New("db connection failed"),
}
oauthService := &claudeOAuthServiceStub{
tokenInfo: &TokenInfo{
AccessToken: "refreshed-token",
RefreshToken: "new-refresh",
TokenType: "Bearer",
ExpiresIn: 3600,
},
}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 113,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-token",
"refresh_token": "old-refresh",
"expires_at": expiresAt,
},
}
provider := &testClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: oauthService,
}
// Should still work, just using the passed-in account
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "refreshed-token", token)
}
func TestClaudeTokenProvider_AccountUpdateError(t *testing.T) {
cache := newClaudeTokenCacheStub()
accountRepo := &claudeAccountRepoStub{
updateErr: errors.New("db write failed"),
}
oauthService := &claudeOAuthServiceStub{
tokenInfo: &TokenInfo{
AccessToken: "refreshed-token",
RefreshToken: "new-refresh",
TokenType: "Bearer",
ExpiresIn: 3600,
},
}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 114,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-token",
"refresh_token": "old-refresh",
"expires_at": expiresAt,
},
}
accountRepo.account = account
provider := &testClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: oauthService,
}
// Should still return token even if update fails
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "refreshed-token", token)
}
func TestClaudeTokenProvider_RefreshPreservesExistingCredentials(t *testing.T) {
cache := newClaudeTokenCacheStub()
accountRepo := &claudeAccountRepoStub{}
oauthService := &claudeOAuthServiceStub{
tokenInfo: &TokenInfo{
AccessToken: "new-access-token",
RefreshToken: "new-refresh-token",
TokenType: "Bearer",
ExpiresIn: 3600,
},
}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 115,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-access-token",
"refresh_token": "old-refresh-token",
"expires_at": expiresAt,
"custom_field": "should-be-preserved",
"organization": "test-org",
},
}
accountRepo.account = account
provider := &testClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: oauthService,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "new-access-token", token)
// Verify existing fields are preserved
require.Equal(t, "should-be-preserved", accountRepo.account.Credentials["custom_field"])
require.Equal(t, "test-org", accountRepo.account.Credentials["organization"])
// Verify new fields are updated
require.Equal(t, "new-access-token", accountRepo.account.Credentials["access_token"])
require.Equal(t, "new-refresh-token", accountRepo.account.Credentials["refresh_token"])
}
func TestClaudeTokenProvider_DoubleCheckCacheAfterLock(t *testing.T) {
cache := newClaudeTokenCacheStub()
accountRepo := &claudeAccountRepoStub{}
oauthService := &claudeOAuthServiceStub{
tokenInfo: &TokenInfo{
AccessToken: "refreshed-token",
RefreshToken: "new-refresh",
TokenType: "Bearer",
ExpiresIn: 3600,
},
}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 116,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-token",
"expires_at": expiresAt,
},
}
accountRepo.account = account
cacheKey := ClaudeTokenCacheKey(account)
// After lock is acquired, cache should have the token (simulating another worker)
go func() {
time.Sleep(5 * time.Millisecond)
cache.mu.Lock()
cache.tokens[cacheKey] = "cached-by-other-worker"
cache.mu.Unlock()
}()
provider := &testClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: oauthService,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.NotEmpty(t, token)
}
// Tests for real provider - to increase coverage
func TestClaudeTokenProvider_Real_LockFailedWait(t *testing.T) {
cache := newClaudeTokenCacheStub()
cache.lockAcquired = false // Lock acquisition fails
// Token expires soon (within refresh skew) to trigger lock attempt
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 300,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "fallback-token",
"expires_at": expiresAt,
},
}
// Set token in cache after lock wait period (simulate other worker refreshing)
cacheKey := ClaudeTokenCacheKey(account)
go func() {
time.Sleep(100 * time.Millisecond)
cache.mu.Lock()
cache.tokens[cacheKey] = "refreshed-by-other"
cache.mu.Unlock()
}()
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.NotEmpty(t, token)
}
func TestClaudeTokenProvider_Real_CacheHitAfterWait(t *testing.T) {
cache := newClaudeTokenCacheStub()
cache.lockAcquired = false // Lock acquisition fails
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 301,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "original-token",
"expires_at": expiresAt,
},
}
cacheKey := ClaudeTokenCacheKey(account)
// Set token in cache immediately after wait starts
go func() {
time.Sleep(50 * time.Millisecond)
cache.mu.Lock()
cache.tokens[cacheKey] = "winner-token"
cache.mu.Unlock()
}()
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.NotEmpty(t, token)
}
func TestClaudeTokenProvider_Real_NoExpiresAt(t *testing.T) {
cache := newClaudeTokenCacheStub()
cache.lockAcquired = false // Prevent entering refresh logic
// Token with nil expires_at (no expiry set)
account := &Account{
ID: 302,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "no-expiry-token",
},
}
// After lock wait, return token from credentials
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "no-expiry-token", token)
}
func TestClaudeTokenProvider_Real_WhitespaceToken(t *testing.T) {
cache := newClaudeTokenCacheStub()
cacheKey := "claude:account:303"
cache.tokens[cacheKey] = " " // Whitespace only - should be treated as empty
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 303,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "real-token",
"expires_at": expiresAt,
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "real-token", token)
}
func TestClaudeTokenProvider_Real_EmptyCredentialToken(t *testing.T) {
cache := newClaudeTokenCacheStub()
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 304,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": " ", // Whitespace only
"expires_at": expiresAt,
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "access_token not found")
require.Empty(t, token)
}
func TestClaudeTokenProvider_Real_LockError(t *testing.T) {
cache := newClaudeTokenCacheStub()
cache.lockErr = errors.New("redis lock failed")
// Token expires soon (within refresh skew)
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 305,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "fallback-on-lock-error",
"expires_at": expiresAt,
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "fallback-on-lock-error", token)
}
func TestClaudeTokenProvider_Real_NilCredentials(t *testing.T) {
cache := newClaudeTokenCacheStub()
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 306,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"expires_at": expiresAt,
// No access_token
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "access_token not found")
require.Empty(t, token)
}
...@@ -124,16 +124,16 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D ...@@ -124,16 +124,16 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D
return stats, nil return stats, nil
} }
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error) { func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) {
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID) trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream)
if err != nil { if err != nil {
return nil, fmt.Errorf("get usage trend with filters: %w", err) return nil, fmt.Errorf("get usage trend with filters: %w", err)
} }
return trend, nil return trend, nil
} }
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID int64) ([]usagestats.ModelStat, error) { func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, 0) stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream)
if err != nil { if err != nil {
return nil, fmt.Errorf("get model stats with filters: %w", err) return nil, fmt.Errorf("get model stats with filters: %w", err)
} }
......
...@@ -142,6 +142,9 @@ func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int6 ...@@ -142,6 +142,9 @@ func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int6
func (m *mockAccountRepoForPlatform) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error { func (m *mockAccountRepoForPlatform) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
return nil return nil
} }
func (m *mockAccountRepoForPlatform) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error { func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil return nil
} }
...@@ -157,6 +160,9 @@ func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int6 ...@@ -157,6 +160,9 @@ func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int6
func (m *mockAccountRepoForPlatform) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error { func (m *mockAccountRepoForPlatform) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
return nil return nil
} }
func (m *mockAccountRepoForPlatform) ClearModelRateLimits(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForPlatform) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { func (m *mockAccountRepoForPlatform) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return nil return nil
} }
...@@ -1046,13 +1052,67 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { ...@@ -1046,13 +1052,67 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil, // No concurrency service concurrencyService: nil, // No concurrency service
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil) result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
require.Equal(t, int64(1), result.Account.ID, "应选择优先级最高的账号") require.Equal(t, int64(1), result.Account.ID, "应选择优先级最高的账号")
}) })
t.Run("模型路由-无ConcurrencyService也生效", func(t *testing.T) {
groupID := int64(1)
sessionHash := "sticky"
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, AccountGroups: []AccountGroup{{GroupID: groupID}}},
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, AccountGroups: []AccountGroup{{GroupID: groupID}}},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{
sessionBindings: map[string]int64{sessionHash: 1},
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: {
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
ModelRoutingEnabled: true,
ModelRouting: map[string][]int64{
"claude-a": {1},
"claude-b": {2},
},
},
},
}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
svc := &GatewayService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
cfg: cfg,
concurrencyService: nil, // legacy path
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "切换到 claude-b 时应按模型路由切换账号")
require.Equal(t, int64(2), cache.sessionBindings[sessionHash], "粘性绑定应更新为路由选择的账号")
})
t.Run("无ConcurrencyService-降级到传统选择", func(t *testing.T) { t.Run("无ConcurrencyService-降级到传统选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{ repo := &mockAccountRepoForPlatform{
accounts: []Account{ accounts: []Account{
...@@ -1077,7 +1137,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { ...@@ -1077,7 +1137,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil, concurrencyService: nil,
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil) result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
...@@ -1109,7 +1169,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { ...@@ -1109,7 +1169,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
} }
excludedIDs := map[int64]struct{}{1: {}} excludedIDs := map[int64]struct{}{1: {}}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs) result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs, "")
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
...@@ -1143,7 +1203,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { ...@@ -1143,7 +1203,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache), concurrencyService: NewConcurrencyService(concurrencyCache),
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil) result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
...@@ -1179,7 +1239,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { ...@@ -1179,7 +1239,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache), concurrencyService: NewConcurrencyService(concurrencyCache),
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil) result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
...@@ -1206,7 +1266,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { ...@@ -1206,7 +1266,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil, concurrencyService: nil,
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil) result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.Error(t, err) require.Error(t, err)
require.Nil(t, result) require.Nil(t, result)
require.Contains(t, err.Error(), "no available accounts") require.Contains(t, err.Error(), "no available accounts")
...@@ -1238,7 +1298,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { ...@@ -1238,7 +1298,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil, concurrencyService: nil,
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil) result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
...@@ -1271,7 +1331,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { ...@@ -1271,7 +1331,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil, concurrencyService: nil,
} }
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil) result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.NotNil(t, result.Account) require.NotNil(t, result.Account)
...@@ -1341,6 +1401,7 @@ func TestGatewayService_GroupResolution_IgnoresInvalidContextGroup(t *testing.T) ...@@ -1341,6 +1401,7 @@ func TestGatewayService_GroupResolution_IgnoresInvalidContextGroup(t *testing.T)
ID: groupID, ID: groupID,
Platform: PlatformAnthropic, Platform: PlatformAnthropic,
Status: StatusActive, Status: StatusActive,
Hydrated: true,
} }
groupRepo := &mockGroupRepoForGateway{ groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{groupID: group}, groups: map[int64]*Group{groupID: group},
...@@ -1398,6 +1459,7 @@ func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) { ...@@ -1398,6 +1459,7 @@ func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) {
ID: fallbackID, ID: fallbackID,
Platform: PlatformAnthropic, Platform: PlatformAnthropic,
Status: StatusActive, Status: StatusActive,
Hydrated: true,
} }
ctx = context.WithValue(ctx, ctxkey.Group, group) ctx = context.WithValue(ctx, ctxkey.Group, group)
......
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
"io" "io"
"log" "log"
"net/http" "net/http"
"os"
"regexp" "regexp"
"sort" "sort"
"strings" "strings"
...@@ -40,6 +41,21 @@ const ( ...@@ -40,6 +41,21 @@ const (
maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量 maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量
) )
func (s *GatewayService) debugModelRoutingEnabled() bool {
v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING")))
return v == "1" || v == "true" || v == "yes" || v == "on"
}
func shortSessionHash(sessionHash string) string {
if sessionHash == "" {
return ""
}
if len(sessionHash) <= 8 {
return sessionHash
}
return sessionHash[:8]
}
// sseDataRe matches SSE data lines with optional whitespace after colon. // sseDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: "). // Some upstream APIs return non-standard "data:" without space (should be "data: ").
var ( var (
...@@ -196,6 +212,8 @@ type GatewayService struct { ...@@ -196,6 +212,8 @@ type GatewayService struct {
httpUpstream HTTPUpstream httpUpstream HTTPUpstream
deferredService *DeferredService deferredService *DeferredService
concurrencyService *ConcurrencyService concurrencyService *ConcurrencyService
claudeTokenProvider *ClaudeTokenProvider
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
} }
// NewGatewayService creates a new GatewayService // NewGatewayService creates a new GatewayService
...@@ -215,6 +233,8 @@ func NewGatewayService( ...@@ -215,6 +233,8 @@ func NewGatewayService(
identityService *IdentityService, identityService *IdentityService,
httpUpstream HTTPUpstream, httpUpstream HTTPUpstream,
deferredService *DeferredService, deferredService *DeferredService,
claudeTokenProvider *ClaudeTokenProvider,
sessionLimitCache SessionLimitCache,
) *GatewayService { ) *GatewayService {
return &GatewayService{ return &GatewayService{
accountRepo: accountRepo, accountRepo: accountRepo,
...@@ -232,6 +252,8 @@ func NewGatewayService( ...@@ -232,6 +252,8 @@ func NewGatewayService(
identityService: identityService, identityService: identityService,
httpUpstream: httpUpstream, httpUpstream: httpUpstream,
deferredService: deferredService, deferredService: deferredService,
claudeTokenProvider: claudeTokenProvider,
sessionLimitCache: sessionLimitCache,
} }
} }
...@@ -797,8 +819,12 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context ...@@ -797,8 +819,12 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
} }
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan. // SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) { // metadataUserID: 原始 metadata.user_id 字段(用于提取会话 UUID 进行会话数量限制)
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) {
cfg := s.schedulingConfig() cfg := s.schedulingConfig()
// 提取会话 UUID(用于会话数量限制)
sessionUUID := extractSessionUUID(metadataUserID)
var stickyAccountID int64 var stickyAccountID int64
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil { if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil {
...@@ -813,6 +839,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -813,6 +839,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
} }
ctx = s.withGroupContext(ctx, group) ctx = s.withGroupContext(ctx, group)
if s.debugModelRoutingEnabled() && requestedModel != "" {
groupPlatform := ""
if group != nil {
groupPlatform = group.Platform
}
log.Printf("[ModelRoutingDebug] select entry: group_id=%v group_platform=%s model=%s session=%s sticky_account=%d load_batch=%v concurrency=%v",
derefGroupID(groupID), groupPlatform, requestedModel, shortSessionHash(sessionHash), stickyAccountID, cfg.LoadBatchEnabled, s.concurrencyService != nil)
}
if s.concurrencyService == nil || !cfg.LoadBatchEnabled { if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs) account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
if err != nil { if err != nil {
...@@ -856,6 +891,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -856,6 +891,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return nil, err return nil, err
} }
preferOAuth := platform == PlatformGemini preferOAuth := platform == PlatformGemini
if s.debugModelRoutingEnabled() && platform == PlatformAnthropic && requestedModel != "" {
log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform)
}
accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err != nil { if err != nil {
...@@ -873,28 +911,242 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -873,28 +911,242 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return excluded return excluded
} }
// ============ Layer 1: 粘性会话优先 ============ // 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用)
if sessionHash != "" && s.cache != nil { accountByID := make(map[int64]*Account, len(accounts))
for i := range accounts {
accountByID[accounts[i].ID] = &accounts[i]
}
// 获取模型路由配置(仅 anthropic 平台)
var routingAccountIDs []int64
if group != nil && requestedModel != "" && group.Platform == PlatformAnthropic {
routingAccountIDs = group.GetRoutingAccountIDs(requestedModel)
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] context group routing: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v session=%s sticky_account=%d",
group.ID, requestedModel, group.ModelRoutingEnabled, len(group.ModelRouting), routingAccountIDs, shortSessionHash(sessionHash), stickyAccountID)
if len(routingAccountIDs) == 0 && group.ModelRoutingEnabled && len(group.ModelRouting) > 0 {
keys := make([]string, 0, len(group.ModelRouting))
for k := range group.ModelRouting {
keys = append(keys, k)
}
sort.Strings(keys)
const maxKeys = 20
if len(keys) > maxKeys {
keys = keys[:maxKeys]
}
log.Printf("[ModelRoutingDebug] context group routing miss: group_id=%d model=%s patterns(sample)=%v", group.ID, requestedModel, keys)
}
}
}
// ============ Layer 1: 模型路由优先选择(优先级高于粘性会话) ============
if len(routingAccountIDs) > 0 && s.concurrencyService != nil {
// 1. 过滤出路由列表中可调度的账号
var routingCandidates []*Account
var filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost int
for _, routingAccountID := range routingAccountIDs {
if isExcluded(routingAccountID) {
filteredExcluded++
continue
}
account, ok := accountByID[routingAccountID]
if !ok || !account.IsSchedulable() {
if !ok {
filteredMissing++
} else {
filteredUnsched++
}
continue
}
if !s.isAccountAllowedForPlatform(account, platform, useMixed) {
filteredPlatform++
continue
}
if !account.IsSchedulableForModel(requestedModel) {
filteredModelScope++
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(account, requestedModel) {
filteredModelMapping++
continue
}
// 窗口费用检查(非粘性会话路径)
if !s.isAccountSchedulableForWindowCost(ctx, account, false) {
filteredWindowCost++
continue
}
routingCandidates = append(routingCandidates, account)
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d window_cost=%d)",
derefGroupID(groupID), requestedModel, len(routingAccountIDs), len(routingCandidates),
filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost)
}
if len(routingCandidates) > 0 {
// 1.5. 在路由账号范围内检查粘性会话
if sessionHash != "" && s.cache != nil {
stickyAccountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err == nil && stickyAccountID > 0 && containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
// 粘性账号在路由列表中,优先使用
if stickyAccount, ok := accountByID[stickyAccountID]; ok {
if stickyAccount.IsSchedulable() &&
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
stickyAccount.IsSchedulableForModel(requestedModel) &&
(requestedModel == "" || s.isModelSupportedByAccount(stickyAccount, requestedModel)) &&
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) { // 粘性会话窗口费用检查
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionUUID) {
result.ReleaseFunc() // 释放槽位
// 继续到负载感知选择
} else {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
}
return &AccountSelectionResult{
Account: stickyAccount,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: stickyAccount,
WaitPlan: &AccountWaitPlan{
AccountID: stickyAccountID,
MaxConcurrency: stickyAccount.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
}
}
}
}
// 2. 批量获取负载信息
routingLoads := make([]AccountWithConcurrency, 0, len(routingCandidates))
for _, acc := range routingCandidates {
routingLoads = append(routingLoads, AccountWithConcurrency{
ID: acc.ID,
MaxConcurrency: acc.Concurrency,
})
}
routingLoadMap, _ := s.concurrencyService.GetAccountsLoadBatch(ctx, routingLoads)
// 3. 按负载感知排序
type accountWithLoad struct {
account *Account
loadInfo *AccountLoadInfo
}
var routingAvailable []accountWithLoad
for _, acc := range routingCandidates {
loadInfo := routingLoadMap[acc.ID]
if loadInfo == nil {
loadInfo = &AccountLoadInfo{AccountID: acc.ID}
}
if loadInfo.LoadRate < 100 {
routingAvailable = append(routingAvailable, accountWithLoad{account: acc, loadInfo: loadInfo})
}
}
if len(routingAvailable) > 0 {
// 排序:优先级 > 负载率 > 最后使用时间
sort.SliceStable(routingAvailable, func(i, j int) bool {
a, b := routingAvailable[i], routingAvailable[j]
if a.account.Priority != b.account.Priority {
return a.account.Priority < b.account.Priority
}
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
}
switch {
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
return true
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
return false
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
return false
default:
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
// 4. 尝试获取槽位
for _, item := range routingAvailable {
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
continue
}
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
}
return &AccountSelectionResult{
Account: item.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
}
// 5. 所有路由账号槽位满,返回等待计划(选择负载最低的)
acc := routingAvailable[0].account
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), acc.ID)
}
return &AccountSelectionResult{
Account: acc,
WaitPlan: &AccountWaitPlan{
AccountID: acc.ID,
MaxConcurrency: acc.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
// 路由列表中的账号都不可用(负载率 >= 100),继续到 Layer 2 回退
log.Printf("[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", requestedModel)
}
}
// ============ Layer 1.5: 粘性会话(仅在无模型路由配置时生效) ============
if len(routingAccountIDs) == 0 && sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) { if err == nil && accountID > 0 && !isExcluded(accountID) {
// 粘性命中仅在当前可调度候选集中生效。
accountByID := make(map[int64]*Account, len(accounts))
for i := range accounts {
accountByID[accounts[i].ID] = &accounts[i]
}
account, ok := accountByID[accountID] account, ok := accountByID[accountID]
if ok && s.isAccountInGroup(account, groupID) && if ok && s.isAccountInGroup(account, groupID) &&
s.isAccountAllowedForPlatform(account, platform, useMixed) && s.isAccountAllowedForPlatform(account, platform, useMixed) &&
account.IsSchedulableForModel(requestedModel) && account.IsSchedulableForModel(requestedModel) &&
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) &&
s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) // 会话数量限制检查
return &AccountSelectionResult{ if !s.checkAndRegisterSession(ctx, account, sessionUUID) {
Account: account, result.ReleaseFunc() // 释放槽位,继续到 Layer 2
Acquired: true, } else {
ReleaseFunc: result.ReleaseFunc, _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
}, nil return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
} }
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
...@@ -935,6 +1187,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -935,6 +1187,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue continue
} }
// 窗口费用检查(非粘性会话路径)
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
continue
}
candidates = append(candidates, acc) candidates = append(candidates, acc)
} }
...@@ -952,7 +1208,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -952,7 +1208,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil { if err != nil {
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok { if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth, sessionUUID); ok {
return result, nil return result, nil
} }
} else { } else {
...@@ -1001,6 +1257,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1001,6 +1257,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
for _, item := range available { for _, item := range available {
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
continue
}
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL) _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
} }
...@@ -1030,13 +1291,18 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro ...@@ -1030,13 +1291,18 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return nil, errors.New("no available accounts") return nil, errors.New("no available accounts")
} }
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) { func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool, sessionUUID string) (*AccountSelectionResult, bool) {
ordered := append([]*Account(nil), candidates...) ordered := append([]*Account(nil), candidates...)
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth) sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
for _, acc := range ordered { for _, acc := range ordered {
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, acc, sessionUUID) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
continue
}
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL) _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL)
} }
...@@ -1093,6 +1359,32 @@ func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (* ...@@ -1093,6 +1359,32 @@ func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (*
return group, nil return group, nil
} }
func (s *GatewayService) routingAccountIDsForRequest(ctx context.Context, groupID *int64, requestedModel string, platform string) []int64 {
if groupID == nil || requestedModel == "" || platform != PlatformAnthropic {
return nil
}
group, err := s.resolveGroupByID(ctx, *groupID)
if err != nil || group == nil {
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] resolve group failed: group_id=%v model=%s platform=%s err=%v", derefGroupID(groupID), requestedModel, platform, err)
}
return nil
}
// Preserve existing behavior: model routing only applies to anthropic groups.
if group.Platform != PlatformAnthropic {
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] skip: non-anthropic group platform: group_id=%d group_platform=%s model=%s", group.ID, group.Platform, requestedModel)
}
return nil
}
ids := group.GetRoutingAccountIDs(requestedModel)
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] routing lookup: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v",
group.ID, requestedModel, group.ModelRoutingEnabled, len(group.ModelRouting), ids)
}
return ids
}
func (s *GatewayService) resolveGatewayGroup(ctx context.Context, groupID *int64) (*Group, *int64, error) { func (s *GatewayService) resolveGatewayGroup(ctx context.Context, groupID *int64) (*Group, *int64, error) {
if groupID == nil { if groupID == nil {
return nil, nil, nil return nil, nil, nil
...@@ -1242,6 +1534,107 @@ func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID in ...@@ -1242,6 +1534,107 @@ func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID in
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
} }
// isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度
// 仅适用于 Anthropic OAuth/SetupToken 账号
// 返回 true 表示可调度,false 表示不可调度
func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context, account *Account, isSticky bool) bool {
// 只检查 Anthropic OAuth/SetupToken 账号
if !account.IsAnthropicOAuthOrSetupToken() {
return true
}
limit := account.GetWindowCostLimit()
if limit <= 0 {
return true // 未启用窗口费用限制
}
// 尝试从缓存获取窗口费用
var currentCost float64
if s.sessionLimitCache != nil {
if cost, hit, err := s.sessionLimitCache.GetWindowCost(ctx, account.ID); err == nil && hit {
currentCost = cost
goto checkSchedulability
}
}
// 缓存未命中,从数据库查询
{
var startTime time.Time
if account.SessionWindowStart != nil {
startTime = *account.SessionWindowStart
} else {
startTime = time.Now().Add(-5 * time.Hour)
}
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
if err != nil {
// 失败开放:查询失败时允许调度
return true
}
// 使用标准费用(不含账号倍率)
currentCost = stats.StandardCost
// 设置缓存(忽略错误)
if s.sessionLimitCache != nil {
_ = s.sessionLimitCache.SetWindowCost(ctx, account.ID, currentCost)
}
}
checkSchedulability:
schedulability := account.CheckWindowCostSchedulability(currentCost)
switch schedulability {
case WindowCostSchedulable:
return true
case WindowCostStickyOnly:
return isSticky
case WindowCostNotSchedulable:
return false
}
return true
}
// checkAndRegisterSession 检查并注册会话,用于会话数量限制
// 仅适用于 Anthropic OAuth/SetupToken 账号
// 返回 true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话)
func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionUUID string) bool {
// 只检查 Anthropic OAuth/SetupToken 账号
if !account.IsAnthropicOAuthOrSetupToken() {
return true
}
maxSessions := account.GetMaxSessions()
if maxSessions <= 0 || sessionUUID == "" {
return true // 未启用会话限制或无会话ID
}
if s.sessionLimitCache == nil {
return true // 缓存不可用时允许通过
}
idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute
allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionUUID, maxSessions, idleTimeout)
if err != nil {
// 失败开放:缓存错误时允许通过
return true
}
return allowed
}
// extractSessionUUID 从 metadata.user_id 中提取会话 UUID
// 格式: user_{64位hex}_account__session_{uuid}
func extractSessionUUID(metadataUserID string) string {
if metadataUserID == "" {
return ""
}
if match := sessionIDRegex.FindStringSubmatch(metadataUserID); len(match) > 1 {
return match[1]
}
return ""
}
func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) { func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
if s.schedulerSnapshot != nil { if s.schedulerSnapshot != nil {
return s.schedulerSnapshot.GetAccount(ctx, accountID) return s.schedulerSnapshot.GetAccount(ctx, accountID)
...@@ -1274,6 +1667,116 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { ...@@ -1274,6 +1667,116 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离) // selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) { func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
preferOAuth := platform == PlatformGemini preferOAuth := platform == PlatformGemini
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
var accounts []Account
accountsLoaded := false
// ============ Model Routing (legacy path): apply before sticky session ============
// When load-awareness is disabled (e.g. concurrency service not configured), we still honor model routing
// so switching model can switch upstream account within the same sticky session.
if len(routingAccountIDs) > 0 {
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] legacy routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v",
derefGroupID(groupID), requestedModel, platform, shortSessionHash(sessionHash), routingAccountIDs)
}
// 1) Sticky session only applies if the bound account is within the routing set.
if sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err == nil && accountID > 0 && containsInt64(routingAccountIDs, accountID) {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
}
return account, nil
}
}
}
}
// 2) Select an account from the routed candidates.
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform == "" {
hasForcePlatform = false
}
var err error
accounts, _, err = s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
accountsLoaded = true
routingSet := make(map[int64]struct{}, len(routingAccountIDs))
for _, id := range routingAccountIDs {
if id > 0 {
routingSet[id] = struct{}{}
}
}
var selected *Account
for i := range accounts {
acc := &accounts[i]
if _, ok := routingSet[acc.ID]; !ok {
continue
}
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
if !acc.IsSchedulable() {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
if selected == nil {
selected = acc
continue
}
if acc.Priority < selected.Priority {
selected = acc
} else if acc.Priority == selected.Priority {
switch {
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
selected = acc
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
if preferOAuth && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
selected = acc
}
default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
}
}
}
}
if selected != nil {
if sessionHash != "" && s.cache != nil {
if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
}
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] legacy routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID)
}
return selected, nil
}
log.Printf("[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel)
}
// 1. 查询粘性会话 // 1. 查询粘性会话
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
...@@ -1292,13 +1795,16 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ...@@ -1292,13 +1795,16 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
} }
// 2. 获取可调度账号列表(单平台) // 2. 获取可调度账号列表(单平台)
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) if !accountsLoaded {
if hasForcePlatform && forcePlatform == "" { forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
hasForcePlatform = false if hasForcePlatform && forcePlatform == "" {
} hasForcePlatform = false
accounts, _, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) }
if err != nil { var err error
return nil, fmt.Errorf("query accounts failed: %w", err) accounts, _, err = s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
} }
// 3. 按优先级+最久未用选择(考虑模型支持) // 3. 按优先级+最久未用选择(考虑模型支持)
...@@ -1364,6 +1870,115 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ...@@ -1364,6 +1870,115 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
// 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户 // 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户
func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) { func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) {
preferOAuth := nativePlatform == PlatformGemini preferOAuth := nativePlatform == PlatformGemini
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform)
var accounts []Account
accountsLoaded := false
// ============ Model Routing (legacy path): apply before sticky session ============
if len(routingAccountIDs) > 0 {
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] legacy mixed routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v",
derefGroupID(groupID), requestedModel, nativePlatform, shortSessionHash(sessionHash), routingAccountIDs)
}
// 1) Sticky session only applies if the bound account is within the routing set.
if sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err == nil && accountID > 0 && containsInt64(routingAccountIDs, accountID) {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
}
return account, nil
}
}
}
}
}
// 2) Select an account from the routed candidates.
var err error
accounts, _, err = s.listSchedulableAccounts(ctx, groupID, nativePlatform, false)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
accountsLoaded = true
routingSet := make(map[int64]struct{}, len(routingAccountIDs))
for _, id := range routingAccountIDs {
if id > 0 {
routingSet[id] = struct{}{}
}
}
var selected *Account
for i := range accounts {
acc := &accounts[i]
if _, ok := routingSet[acc.ID]; !ok {
continue
}
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
if !acc.IsSchedulable() {
continue
}
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
if selected == nil {
selected = acc
continue
}
if acc.Priority < selected.Priority {
selected = acc
} else if acc.Priority == selected.Priority {
switch {
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
selected = acc
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
if preferOAuth && acc.Platform == PlatformGemini && selected.Platform == PlatformGemini && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
selected = acc
}
default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
}
}
}
}
if selected != nil {
if sessionHash != "" && s.cache != nil {
if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
}
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] legacy mixed routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID)
}
return selected, nil
}
log.Printf("[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel)
}
// 1. 查询粘性会话 // 1. 查询粘性会话
if sessionHash != "" && s.cache != nil { if sessionHash != "" && s.cache != nil {
...@@ -1385,9 +2000,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ...@@ -1385,9 +2000,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
} }
// 2. 获取可调度账号列表 // 2. 获取可调度账号列表
accounts, _, err := s.listSchedulableAccounts(ctx, groupID, nativePlatform, false) if !accountsLoaded {
if err != nil { var err error
return nil, fmt.Errorf("query accounts failed: %w", err) accounts, _, err = s.listSchedulableAccounts(ctx, groupID, nativePlatform, false)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
} }
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度) // 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
...@@ -1488,6 +2106,16 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) ( ...@@ -1488,6 +2106,16 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
} }
func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (string, string, error) { func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (string, string, error) {
// 对于 Anthropic OAuth 账号,使用 ClaudeTokenProvider 获取缓存的 token
if account.Platform == PlatformAnthropic && account.Type == AccountTypeOAuth && s.claudeTokenProvider != nil {
accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account)
if err != nil {
return "", "", err
}
return accessToken, "oauth", nil
}
// 其他情况(Gemini 有自己的 TokenProvider,setup-token 类型等)直接从账号读取
accessToken := account.GetCredential("access_token") accessToken := account.GetCredential("access_token")
if accessToken == "" { if accessToken == "" {
return "", "", errors.New("access_token not found in credentials") return "", "", errors.New("access_token not found in credentials")
...@@ -1901,6 +2529,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -1901,6 +2529,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
retryStart := time.Now() retryStart := time.Now()
for attempt := 1; attempt <= maxRetryAttempts; attempt++ { for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取) // 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
// Capture upstream request body for ops retry of this attempt.
c.Set(OpsUpstreamRequestBodyKey, string(body))
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -1918,6 +2548,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -1918,6 +2548,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0, UpstreamStatusCode: 0,
Kind: "request_error", Kind: "request_error",
Message: safeErr, Message: safeErr,
...@@ -1942,6 +2573,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -1942,6 +2573,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode, UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"), UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "signature_error", Kind: "signature_error",
...@@ -1993,6 +2625,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -1993,6 +2625,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: retryResp.StatusCode, UpstreamStatusCode: retryResp.StatusCode,
UpstreamRequestID: retryResp.Header.Get("x-request-id"), UpstreamRequestID: retryResp.Header.Get("x-request-id"),
Kind: "signature_retry_thinking", Kind: "signature_retry_thinking",
...@@ -2021,6 +2654,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -2021,6 +2654,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0, UpstreamStatusCode: 0,
Kind: "signature_retry_tools_request_error", Kind: "signature_retry_tools_request_error",
Message: sanitizeUpstreamErrorMessage(retryErr2.Error()), Message: sanitizeUpstreamErrorMessage(retryErr2.Error()),
...@@ -2079,6 +2713,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -2079,6 +2713,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode, UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"), UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry", Kind: "retry",
...@@ -2127,6 +2762,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -2127,6 +2762,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode, UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"), UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry_exhausted_failover", Kind: "retry_exhausted_failover",
...@@ -2193,6 +2829,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A ...@@ -2193,6 +2829,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform, Platform: account.Platform,
AccountID: account.ID, AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode, UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"), UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover_on_400", Kind: "failover_on_400",
...@@ -3283,30 +3920,32 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ...@@ -3283,30 +3920,32 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
if result.ImageSize != "" { if result.ImageSize != "" {
imageSize = &result.ImageSize imageSize = &result.ImageSize
} }
accountRateMultiplier := account.BillingRateMultiplier()
usageLog := &UsageLog{ usageLog := &UsageLog{
UserID: user.ID, UserID: user.ID,
APIKeyID: apiKey.ID, APIKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
RequestID: result.RequestID, RequestID: result.RequestID,
Model: result.Model, Model: result.Model,
InputTokens: result.Usage.InputTokens, InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens, OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens,
InputCost: cost.InputCost, InputCost: cost.InputCost,
OutputCost: cost.OutputCost, OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost, CacheCreationCost: cost.CacheCreationCost,
CacheReadCost: cost.CacheReadCost, CacheReadCost: cost.CacheReadCost,
TotalCost: cost.TotalCost, TotalCost: cost.TotalCost,
ActualCost: cost.ActualCost, ActualCost: cost.ActualCost,
RateMultiplier: multiplier, RateMultiplier: multiplier,
BillingType: billingType, AccountRateMultiplier: &accountRateMultiplier,
Stream: result.Stream, BillingType: billingType,
DurationMs: &durationMs, Stream: result.Stream,
FirstTokenMs: result.FirstTokenMs, DurationMs: &durationMs,
ImageCount: result.ImageCount, FirstTokenMs: result.FirstTokenMs,
ImageSize: imageSize, ImageCount: result.ImageCount,
CreatedAt: time.Now(), ImageSize: imageSize,
CreatedAt: time.Now(),
} }
// 添加 UserAgent // 添加 UserAgent
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment