Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
6901b64f
Commit
6901b64f
authored
Jan 17, 2026
by
cyhhao
Browse files
merge: sync upstream changes
parents
32c47b15
dae0d532
Changes
189
Show whitespace changes
Inline
Side-by-side
backend/internal/server/routes/admin.go
View file @
6901b64f
...
...
@@ -81,6 +81,9 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
ops
.
PUT
(
"/alert-rules/:id"
,
h
.
Admin
.
Ops
.
UpdateAlertRule
)
ops
.
DELETE
(
"/alert-rules/:id"
,
h
.
Admin
.
Ops
.
DeleteAlertRule
)
ops
.
GET
(
"/alert-events"
,
h
.
Admin
.
Ops
.
ListAlertEvents
)
ops
.
GET
(
"/alert-events/:id"
,
h
.
Admin
.
Ops
.
GetAlertEvent
)
ops
.
PUT
(
"/alert-events/:id/status"
,
h
.
Admin
.
Ops
.
UpdateAlertEventStatus
)
ops
.
POST
(
"/alert-silences"
,
h
.
Admin
.
Ops
.
CreateAlertSilence
)
// Email notification config (DB-backed)
ops
.
GET
(
"/email-notification/config"
,
h
.
Admin
.
Ops
.
GetEmailNotificationConfig
)
...
...
@@ -110,10 +113,26 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
ws
.
GET
(
"/qps"
,
h
.
Admin
.
Ops
.
QPSWSHandler
)
}
// Error logs (
MVP-1
)
// Error logs (
legacy
)
ops
.
GET
(
"/errors"
,
h
.
Admin
.
Ops
.
GetErrorLogs
)
ops
.
GET
(
"/errors/:id"
,
h
.
Admin
.
Ops
.
GetErrorLogByID
)
ops
.
GET
(
"/errors/:id/retries"
,
h
.
Admin
.
Ops
.
ListRetryAttempts
)
ops
.
POST
(
"/errors/:id/retry"
,
h
.
Admin
.
Ops
.
RetryErrorRequest
)
ops
.
PUT
(
"/errors/:id/resolve"
,
h
.
Admin
.
Ops
.
UpdateErrorResolution
)
// Request errors (client-visible failures)
ops
.
GET
(
"/request-errors"
,
h
.
Admin
.
Ops
.
ListRequestErrors
)
ops
.
GET
(
"/request-errors/:id"
,
h
.
Admin
.
Ops
.
GetRequestError
)
ops
.
GET
(
"/request-errors/:id/upstream-errors"
,
h
.
Admin
.
Ops
.
ListRequestErrorUpstreamErrors
)
ops
.
POST
(
"/request-errors/:id/retry-client"
,
h
.
Admin
.
Ops
.
RetryRequestErrorClient
)
ops
.
POST
(
"/request-errors/:id/upstream-errors/:idx/retry"
,
h
.
Admin
.
Ops
.
RetryRequestErrorUpstreamEvent
)
ops
.
PUT
(
"/request-errors/:id/resolve"
,
h
.
Admin
.
Ops
.
ResolveRequestError
)
// Upstream errors (independent upstream failures)
ops
.
GET
(
"/upstream-errors"
,
h
.
Admin
.
Ops
.
ListUpstreamErrors
)
ops
.
GET
(
"/upstream-errors/:id"
,
h
.
Admin
.
Ops
.
GetUpstreamError
)
ops
.
POST
(
"/upstream-errors/:id/retry"
,
h
.
Admin
.
Ops
.
RetryUpstreamError
)
ops
.
PUT
(
"/upstream-errors/:id/resolve"
,
h
.
Admin
.
Ops
.
ResolveUpstreamError
)
// Request drilldown (success + error)
ops
.
GET
(
"/requests"
,
h
.
Admin
.
Ops
.
ListRequestDetails
)
...
...
@@ -250,6 +269,7 @@ func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
proxies
.
POST
(
"/:id/test"
,
h
.
Admin
.
Proxy
.
Test
)
proxies
.
GET
(
"/:id/stats"
,
h
.
Admin
.
Proxy
.
GetStats
)
proxies
.
GET
(
"/:id/accounts"
,
h
.
Admin
.
Proxy
.
GetProxyAccounts
)
proxies
.
POST
(
"/batch-delete"
,
h
.
Admin
.
Proxy
.
BatchDelete
)
proxies
.
POST
(
"/batch"
,
h
.
Admin
.
Proxy
.
BatchCreate
)
}
}
...
...
backend/internal/service/account.go
View file @
6901b64f
...
...
@@ -19,6 +19,9 @@ type Account struct {
ProxyID
*
int64
Concurrency
int
Priority
int
// RateMultiplier 账号计费倍率(>=0,允许 0 表示该账号计费为 0)。
// 使用指针用于兼容旧版本调度缓存(Redis)中缺字段的情况:nil 表示按 1.0 处理。
RateMultiplier
*
float64
Status
string
ErrorMessage
string
LastUsedAt
*
time
.
Time
...
...
@@ -57,6 +60,20 @@ func (a *Account) IsActive() bool {
return
a
.
Status
==
StatusActive
}
// BillingRateMultiplier 返回账号计费倍率。
// - nil 表示未配置/旧缓存缺字段,按 1.0 处理
// - 允许 0,表示该账号计费为 0
// - 负数属于非法数据,出于安全考虑按 1.0 处理
func
(
a
*
Account
)
BillingRateMultiplier
()
float64
{
if
a
==
nil
||
a
.
RateMultiplier
==
nil
{
return
1.0
}
if
*
a
.
RateMultiplier
<
0
{
return
1.0
}
return
*
a
.
RateMultiplier
}
func
(
a
*
Account
)
IsSchedulable
()
bool
{
if
!
a
.
IsActive
()
||
!
a
.
Schedulable
{
return
false
...
...
@@ -556,3 +573,141 @@ func (a *Account) IsMixedSchedulingEnabled() bool {
}
return
false
}
// WindowCostSchedulability 窗口费用调度状态
type
WindowCostSchedulability
int
const
(
// WindowCostSchedulable 可正常调度
WindowCostSchedulable
WindowCostSchedulability
=
iota
// WindowCostStickyOnly 仅允许粘性会话
WindowCostStickyOnly
// WindowCostNotSchedulable 完全不可调度
WindowCostNotSchedulable
)
// IsAnthropicOAuthOrSetupToken 判断是否为 Anthropic OAuth 或 SetupToken 类型账号
// 仅这两类账号支持 5h 窗口额度控制和会话数量控制
func
(
a
*
Account
)
IsAnthropicOAuthOrSetupToken
()
bool
{
return
a
.
Platform
==
PlatformAnthropic
&&
(
a
.
Type
==
AccountTypeOAuth
||
a
.
Type
==
AccountTypeSetupToken
)
}
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
// 返回 0 表示未启用
func
(
a
*
Account
)
GetWindowCostLimit
()
float64
{
if
a
.
Extra
==
nil
{
return
0
}
if
v
,
ok
:=
a
.
Extra
[
"window_cost_limit"
];
ok
{
return
parseExtraFloat64
(
v
)
}
return
0
}
// GetWindowCostStickyReserve 获取粘性会话预留额度(美元)
// 默认值为 10
func
(
a
*
Account
)
GetWindowCostStickyReserve
()
float64
{
if
a
.
Extra
==
nil
{
return
10.0
}
if
v
,
ok
:=
a
.
Extra
[
"window_cost_sticky_reserve"
];
ok
{
val
:=
parseExtraFloat64
(
v
)
if
val
>
0
{
return
val
}
}
return
10.0
}
// GetMaxSessions 获取最大并发会话数
// 返回 0 表示未启用
func
(
a
*
Account
)
GetMaxSessions
()
int
{
if
a
.
Extra
==
nil
{
return
0
}
if
v
,
ok
:=
a
.
Extra
[
"max_sessions"
];
ok
{
return
parseExtraInt
(
v
)
}
return
0
}
// GetSessionIdleTimeoutMinutes 获取会话空闲超时分钟数
// 默认值为 5 分钟
func
(
a
*
Account
)
GetSessionIdleTimeoutMinutes
()
int
{
if
a
.
Extra
==
nil
{
return
5
}
if
v
,
ok
:=
a
.
Extra
[
"session_idle_timeout_minutes"
];
ok
{
val
:=
parseExtraInt
(
v
)
if
val
>
0
{
return
val
}
}
return
5
}
// CheckWindowCostSchedulability 根据当前窗口费用检查调度状态
// - 费用 < 阈值: WindowCostSchedulable(可正常调度)
// - 费用 >= 阈值 且 < 阈值+预留: WindowCostStickyOnly(仅粘性会话)
// - 费用 >= 阈值+预留: WindowCostNotSchedulable(不可调度)
func
(
a
*
Account
)
CheckWindowCostSchedulability
(
currentWindowCost
float64
)
WindowCostSchedulability
{
limit
:=
a
.
GetWindowCostLimit
()
if
limit
<=
0
{
return
WindowCostSchedulable
}
if
currentWindowCost
<
limit
{
return
WindowCostSchedulable
}
stickyReserve
:=
a
.
GetWindowCostStickyReserve
()
if
currentWindowCost
<
limit
+
stickyReserve
{
return
WindowCostStickyOnly
}
return
WindowCostNotSchedulable
}
// parseExtraFloat64 从 extra 字段解析 float64 值
func
parseExtraFloat64
(
value
any
)
float64
{
switch
v
:=
value
.
(
type
)
{
case
float64
:
return
v
case
float32
:
return
float64
(
v
)
case
int
:
return
float64
(
v
)
case
int64
:
return
float64
(
v
)
case
json
.
Number
:
if
f
,
err
:=
v
.
Float64
();
err
==
nil
{
return
f
}
case
string
:
if
f
,
err
:=
strconv
.
ParseFloat
(
strings
.
TrimSpace
(
v
),
64
);
err
==
nil
{
return
f
}
}
return
0
}
// parseExtraInt 从 extra 字段解析 int 值
func
parseExtraInt
(
value
any
)
int
{
switch
v
:=
value
.
(
type
)
{
case
int
:
return
v
case
int64
:
return
int
(
v
)
case
float64
:
return
int
(
v
)
case
json
.
Number
:
if
i
,
err
:=
v
.
Int64
();
err
==
nil
{
return
int
(
i
)
}
case
string
:
if
i
,
err
:=
strconv
.
Atoi
(
strings
.
TrimSpace
(
v
));
err
==
nil
{
return
i
}
}
return
0
}
backend/internal/service/account_billing_rate_multiplier_test.go
0 → 100644
View file @
6901b64f
package
service
import
(
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func
TestAccount_BillingRateMultiplier_DefaultsToOneWhenNil
(
t
*
testing
.
T
)
{
var
a
Account
require
.
NoError
(
t
,
json
.
Unmarshal
([]
byte
(
`{"id":1,"name":"acc","status":"active"}`
),
&
a
))
require
.
Nil
(
t
,
a
.
RateMultiplier
)
require
.
Equal
(
t
,
1.0
,
a
.
BillingRateMultiplier
())
}
func
TestAccount_BillingRateMultiplier_AllowsZero
(
t
*
testing
.
T
)
{
v
:=
0.0
a
:=
Account
{
RateMultiplier
:
&
v
}
require
.
Equal
(
t
,
0.0
,
a
.
BillingRateMultiplier
())
}
func
TestAccount_BillingRateMultiplier_NegativeFallsBackToOne
(
t
*
testing
.
T
)
{
v
:=
-
1.0
a
:=
Account
{
RateMultiplier
:
&
v
}
require
.
Equal
(
t
,
1.0
,
a
.
BillingRateMultiplier
())
}
backend/internal/service/account_service.go
View file @
6901b64f
...
...
@@ -50,11 +50,13 @@ type AccountRepository interface {
SetRateLimited
(
ctx
context
.
Context
,
id
int64
,
resetAt
time
.
Time
)
error
SetAntigravityQuotaScopeLimit
(
ctx
context
.
Context
,
id
int64
,
scope
AntigravityQuotaScope
,
resetAt
time
.
Time
)
error
SetModelRateLimit
(
ctx
context
.
Context
,
id
int64
,
scope
string
,
resetAt
time
.
Time
)
error
SetOverloaded
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
)
error
SetTempUnschedulable
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
,
reason
string
)
error
ClearTempUnschedulable
(
ctx
context
.
Context
,
id
int64
)
error
ClearRateLimit
(
ctx
context
.
Context
,
id
int64
)
error
ClearAntigravityQuotaScopes
(
ctx
context
.
Context
,
id
int64
)
error
ClearModelRateLimits
(
ctx
context
.
Context
,
id
int64
)
error
UpdateSessionWindow
(
ctx
context
.
Context
,
id
int64
,
start
,
end
*
time
.
Time
,
status
string
)
error
UpdateExtra
(
ctx
context
.
Context
,
id
int64
,
updates
map
[
string
]
any
)
error
BulkUpdate
(
ctx
context
.
Context
,
ids
[]
int64
,
updates
AccountBulkUpdate
)
(
int64
,
error
)
...
...
@@ -67,6 +69,7 @@ type AccountBulkUpdate struct {
ProxyID
*
int64
Concurrency
*
int
Priority
*
int
RateMultiplier
*
float64
Status
*
string
Schedulable
*
bool
Credentials
map
[
string
]
any
...
...
backend/internal/service/account_service_delete_test.go
View file @
6901b64f
...
...
@@ -143,6 +143,10 @@ func (s *accountRepoStub) SetAntigravityQuotaScopeLimit(ctx context.Context, id
panic
(
"unexpected SetAntigravityQuotaScopeLimit call"
)
}
func
(
s
*
accountRepoStub
)
SetModelRateLimit
(
ctx
context
.
Context
,
id
int64
,
scope
string
,
resetAt
time
.
Time
)
error
{
panic
(
"unexpected SetModelRateLimit call"
)
}
func
(
s
*
accountRepoStub
)
SetOverloaded
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
)
error
{
panic
(
"unexpected SetOverloaded call"
)
}
...
...
@@ -163,6 +167,10 @@ func (s *accountRepoStub) ClearAntigravityQuotaScopes(ctx context.Context, id in
panic
(
"unexpected ClearAntigravityQuotaScopes call"
)
}
func
(
s
*
accountRepoStub
)
ClearModelRateLimits
(
ctx
context
.
Context
,
id
int64
)
error
{
panic
(
"unexpected ClearModelRateLimits call"
)
}
func
(
s
*
accountRepoStub
)
UpdateSessionWindow
(
ctx
context
.
Context
,
id
int64
,
start
,
end
*
time
.
Time
,
status
string
)
error
{
panic
(
"unexpected UpdateSessionWindow call"
)
}
...
...
backend/internal/service/account_usage_service.go
View file @
6901b64f
...
...
@@ -32,8 +32,8 @@ type UsageLogRepository interface {
// Admin dashboard stats
GetDashboardStats
(
ctx
context
.
Context
)
(
*
usagestats
.
DashboardStats
,
error
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
int64
)
([]
usagestats
.
TrendDataPoint
,
error
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
int64
)
([]
usagestats
.
ModelStat
,
error
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
stream
*
bool
)
([]
usagestats
.
TrendDataPoint
,
error
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
stream
*
bool
)
([]
usagestats
.
ModelStat
,
error
)
GetAPIKeyUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
([]
usagestats
.
APIKeyUsageTrendPoint
,
error
)
GetUserUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
([]
usagestats
.
UserUsageTrendPoint
,
error
)
GetBatchUserUsageStats
(
ctx
context
.
Context
,
userIDs
[]
int64
)
(
map
[
int64
]
*
usagestats
.
BatchUserUsageStats
,
error
)
...
...
@@ -96,10 +96,16 @@ func NewUsageCache() *UsageCache {
}
// WindowStats 窗口期统计
//
// cost: 账号口径费用(total_cost * account_rate_multiplier)
// standard_cost: 标准费用(total_cost,不含倍率)
// user_cost: 用户/API Key 口径费用(actual_cost,受分组倍率影响)
type
WindowStats
struct
{
Requests
int64
`json:"requests"`
Tokens
int64
`json:"tokens"`
Cost
float64
`json:"cost"`
StandardCost
float64
`json:"standard_cost"`
UserCost
float64
`json:"user_cost"`
}
// UsageProgress 使用量进度
...
...
@@ -266,7 +272,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
}
dayStart
:=
geminiDailyWindowStart
(
now
)
stats
,
err
:=
s
.
usageLogRepo
.
GetModelStatsWithFilters
(
ctx
,
dayStart
,
now
,
0
,
0
,
account
.
ID
)
stats
,
err
:=
s
.
usageLogRepo
.
GetModelStatsWithFilters
(
ctx
,
dayStart
,
now
,
0
,
0
,
account
.
ID
,
0
,
nil
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get gemini usage stats failed: %w"
,
err
)
}
...
...
@@ -288,7 +294,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
minuteStart
:=
now
.
Truncate
(
time
.
Minute
)
minuteResetAt
:=
minuteStart
.
Add
(
time
.
Minute
)
minuteStats
,
err
:=
s
.
usageLogRepo
.
GetModelStatsWithFilters
(
ctx
,
minuteStart
,
now
,
0
,
0
,
account
.
ID
)
minuteStats
,
err
:=
s
.
usageLogRepo
.
GetModelStatsWithFilters
(
ctx
,
minuteStart
,
now
,
0
,
0
,
account
.
ID
,
0
,
nil
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get gemini minute usage stats failed: %w"
,
err
)
}
...
...
@@ -380,6 +386,8 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
Requests
:
stats
.
Requests
,
Tokens
:
stats
.
Tokens
,
Cost
:
stats
.
Cost
,
StandardCost
:
stats
.
StandardCost
,
UserCost
:
stats
.
UserCost
,
}
// 缓存窗口统计(1 分钟)
...
...
@@ -406,6 +414,8 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64
Requests
:
stats
.
Requests
,
Tokens
:
stats
.
Tokens
,
Cost
:
stats
.
Cost
,
StandardCost
:
stats
.
StandardCost
,
UserCost
:
stats
.
UserCost
,
},
nil
}
...
...
@@ -565,3 +575,9 @@ func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64
},
}
}
// GetAccountWindowStats 获取账号在指定时间窗口内的使用统计
// 用于账号列表页面显示当前窗口费用
func
(
s
*
AccountUsageService
)
GetAccountWindowStats
(
ctx
context
.
Context
,
accountID
int64
,
startTime
time
.
Time
)
(
*
usagestats
.
AccountStats
,
error
)
{
return
s
.
usageLogRepo
.
GetAccountWindowStats
(
ctx
,
accountID
,
startTime
)
}
backend/internal/service/admin_service.go
View file @
6901b64f
...
...
@@ -54,7 +54,8 @@ type AdminService interface {
CreateProxy
(
ctx
context
.
Context
,
input
*
CreateProxyInput
)
(
*
Proxy
,
error
)
UpdateProxy
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateProxyInput
)
(
*
Proxy
,
error
)
DeleteProxy
(
ctx
context
.
Context
,
id
int64
)
error
GetProxyAccounts
(
ctx
context
.
Context
,
proxyID
int64
,
page
,
pageSize
int
)
([]
Account
,
int64
,
error
)
BatchDeleteProxies
(
ctx
context
.
Context
,
ids
[]
int64
)
(
*
ProxyBatchDeleteResult
,
error
)
GetProxyAccounts
(
ctx
context
.
Context
,
proxyID
int64
)
([]
ProxyAccountSummary
,
error
)
CheckProxyExists
(
ctx
context
.
Context
,
host
string
,
port
int
,
username
,
password
string
)
(
bool
,
error
)
TestProxy
(
ctx
context
.
Context
,
id
int64
)
(
*
ProxyTestResult
,
error
)
...
...
@@ -105,6 +106,9 @@ type CreateGroupInput struct {
ImagePrice4K
*
float64
ClaudeCodeOnly
bool
// 仅允许 Claude Code 客户端
FallbackGroupID
*
int64
// 降级分组 ID
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting
map
[
string
][]
int64
ModelRoutingEnabled
bool
// 是否启用模型路由
}
type
UpdateGroupInput
struct
{
...
...
@@ -124,6 +128,9 @@ type UpdateGroupInput struct {
ImagePrice4K
*
float64
ClaudeCodeOnly
*
bool
// 仅允许 Claude Code 客户端
FallbackGroupID
*
int64
// 降级分组 ID
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting
map
[
string
][]
int64
ModelRoutingEnabled
*
bool
// 是否启用模型路由
}
type
CreateAccountInput
struct
{
...
...
@@ -136,6 +143,7 @@ type CreateAccountInput struct {
ProxyID
*
int64
Concurrency
int
Priority
int
RateMultiplier
*
float64
// 账号计费倍率(>=0,允许 0)
GroupIDs
[]
int64
ExpiresAt
*
int64
AutoPauseOnExpired
*
bool
...
...
@@ -153,6 +161,7 @@ type UpdateAccountInput struct {
ProxyID
*
int64
Concurrency
*
int
// 使用指针区分"未提供"和"设置为0"
Priority
*
int
// 使用指针区分"未提供"和"设置为0"
RateMultiplier
*
float64
// 账号计费倍率(>=0,允许 0)
Status
string
GroupIDs
*
[]
int64
ExpiresAt
*
int64
...
...
@@ -167,6 +176,7 @@ type BulkUpdateAccountsInput struct {
ProxyID
*
int64
Concurrency
*
int
Priority
*
int
RateMultiplier
*
float64
// 账号计费倍率(>=0,允许 0)
Status
string
Schedulable
*
bool
GroupIDs
*
[]
int64
...
...
@@ -220,6 +230,16 @@ type GenerateRedeemCodesInput struct {
ValidityDays
int
// 订阅类型专用:有效天数
}
type
ProxyBatchDeleteResult
struct
{
DeletedIDs
[]
int64
`json:"deleted_ids"`
Skipped
[]
ProxyBatchDeleteSkipped
`json:"skipped"`
}
type
ProxyBatchDeleteSkipped
struct
{
ID
int64
`json:"id"`
Reason
string
`json:"reason"`
}
// ProxyTestResult represents the result of testing a proxy
type
ProxyTestResult
struct
{
Success
bool
`json:"success"`
...
...
@@ -229,14 +249,16 @@ type ProxyTestResult struct {
City
string
`json:"city,omitempty"`
Region
string
`json:"region,omitempty"`
Country
string
`json:"country,omitempty"`
CountryCode
string
`json:"country_code,omitempty"`
}
// ProxyExitInfo represents proxy exit information from ip
info.io
// ProxyExitInfo represents proxy exit information from ip
-api.com
type
ProxyExitInfo
struct
{
IP
string
City
string
Region
string
Country
string
CountryCode
string
}
// ProxyExitInfoProber tests proxy connectivity and retrieves exit information
...
...
@@ -254,6 +276,7 @@ type adminServiceImpl struct {
redeemCodeRepo
RedeemCodeRepository
billingCacheService
*
BillingCacheService
proxyProber
ProxyExitInfoProber
proxyLatencyCache
ProxyLatencyCache
authCacheInvalidator
APIKeyAuthCacheInvalidator
}
...
...
@@ -267,6 +290,7 @@ func NewAdminService(
redeemCodeRepo
RedeemCodeRepository
,
billingCacheService
*
BillingCacheService
,
proxyProber
ProxyExitInfoProber
,
proxyLatencyCache
ProxyLatencyCache
,
authCacheInvalidator
APIKeyAuthCacheInvalidator
,
)
AdminService
{
return
&
adminServiceImpl
{
...
...
@@ -278,6 +302,7 @@ func NewAdminService(
redeemCodeRepo
:
redeemCodeRepo
,
billingCacheService
:
billingCacheService
,
proxyProber
:
proxyProber
,
proxyLatencyCache
:
proxyLatencyCache
,
authCacheInvalidator
:
authCacheInvalidator
,
}
}
...
...
@@ -562,6 +587,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
ImagePrice4K
:
imagePrice4K
,
ClaudeCodeOnly
:
input
.
ClaudeCodeOnly
,
FallbackGroupID
:
input
.
FallbackGroupID
,
ModelRouting
:
input
.
ModelRouting
,
}
if
err
:=
s
.
groupRepo
.
Create
(
ctx
,
group
);
err
!=
nil
{
return
nil
,
err
...
...
@@ -690,6 +716,14 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
}
}
// 模型路由配置
if
input
.
ModelRouting
!=
nil
{
group
.
ModelRouting
=
input
.
ModelRouting
}
if
input
.
ModelRoutingEnabled
!=
nil
{
group
.
ModelRoutingEnabled
=
*
input
.
ModelRoutingEnabled
}
if
err
:=
s
.
groupRepo
.
Update
(
ctx
,
group
);
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -817,6 +851,12 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
}
else
{
account
.
AutoPauseOnExpired
=
true
}
if
input
.
RateMultiplier
!=
nil
{
if
*
input
.
RateMultiplier
<
0
{
return
nil
,
errors
.
New
(
"rate_multiplier must be >= 0"
)
}
account
.
RateMultiplier
=
input
.
RateMultiplier
}
if
err
:=
s
.
accountRepo
.
Create
(
ctx
,
account
);
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -869,6 +909,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
if
input
.
Priority
!=
nil
{
account
.
Priority
=
*
input
.
Priority
}
if
input
.
RateMultiplier
!=
nil
{
if
*
input
.
RateMultiplier
<
0
{
return
nil
,
errors
.
New
(
"rate_multiplier must be >= 0"
)
}
account
.
RateMultiplier
=
input
.
RateMultiplier
}
if
input
.
Status
!=
""
{
account
.
Status
=
input
.
Status
}
...
...
@@ -942,6 +988,12 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
}
}
if
input
.
RateMultiplier
!=
nil
{
if
*
input
.
RateMultiplier
<
0
{
return
nil
,
errors
.
New
(
"rate_multiplier must be >= 0"
)
}
}
// Prepare bulk updates for columns and JSONB fields.
repoUpdates
:=
AccountBulkUpdate
{
Credentials
:
input
.
Credentials
,
...
...
@@ -959,6 +1011,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
if
input
.
Priority
!=
nil
{
repoUpdates
.
Priority
=
input
.
Priority
}
if
input
.
RateMultiplier
!=
nil
{
repoUpdates
.
RateMultiplier
=
input
.
RateMultiplier
}
if
input
.
Status
!=
""
{
repoUpdates
.
Status
=
&
input
.
Status
}
...
...
@@ -1069,6 +1124,7 @@ func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page
if
err
!=
nil
{
return
nil
,
0
,
err
}
s
.
attachProxyLatency
(
ctx
,
proxies
)
return
proxies
,
result
.
Total
,
nil
}
...
...
@@ -1077,7 +1133,12 @@ func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) {
}
func
(
s
*
adminServiceImpl
)
GetAllProxiesWithAccountCount
(
ctx
context
.
Context
)
([]
ProxyWithAccountCount
,
error
)
{
return
s
.
proxyRepo
.
ListActiveWithAccountCount
(
ctx
)
proxies
,
err
:=
s
.
proxyRepo
.
ListActiveWithAccountCount
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
s
.
attachProxyLatency
(
ctx
,
proxies
)
return
proxies
,
nil
}
func
(
s
*
adminServiceImpl
)
GetProxy
(
ctx
context
.
Context
,
id
int64
)
(
*
Proxy
,
error
)
{
...
...
@@ -1097,6 +1158,8 @@ func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyIn
if
err
:=
s
.
proxyRepo
.
Create
(
ctx
,
proxy
);
err
!=
nil
{
return
nil
,
err
}
// Probe latency asynchronously so creation isn't blocked by network timeout.
go
s
.
probeProxyLatency
(
context
.
Background
(),
proxy
)
return
proxy
,
nil
}
...
...
@@ -1135,12 +1198,53 @@ func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *Upd
}
func
(
s
*
adminServiceImpl
)
DeleteProxy
(
ctx
context
.
Context
,
id
int64
)
error
{
count
,
err
:=
s
.
proxyRepo
.
CountAccountsByProxyID
(
ctx
,
id
)
if
err
!=
nil
{
return
err
}
if
count
>
0
{
return
ErrProxyInUse
}
return
s
.
proxyRepo
.
Delete
(
ctx
,
id
)
}
func
(
s
*
adminServiceImpl
)
GetProxyAccounts
(
ctx
context
.
Context
,
proxyID
int64
,
page
,
pageSize
int
)
([]
Account
,
int64
,
error
)
{
// Return mock data for now - would need a dedicated repository method
return
[]
Account
{},
0
,
nil
func
(
s
*
adminServiceImpl
)
BatchDeleteProxies
(
ctx
context
.
Context
,
ids
[]
int64
)
(
*
ProxyBatchDeleteResult
,
error
)
{
result
:=
&
ProxyBatchDeleteResult
{}
if
len
(
ids
)
==
0
{
return
result
,
nil
}
for
_
,
id
:=
range
ids
{
count
,
err
:=
s
.
proxyRepo
.
CountAccountsByProxyID
(
ctx
,
id
)
if
err
!=
nil
{
result
.
Skipped
=
append
(
result
.
Skipped
,
ProxyBatchDeleteSkipped
{
ID
:
id
,
Reason
:
err
.
Error
(),
})
continue
}
if
count
>
0
{
result
.
Skipped
=
append
(
result
.
Skipped
,
ProxyBatchDeleteSkipped
{
ID
:
id
,
Reason
:
ErrProxyInUse
.
Error
(),
})
continue
}
if
err
:=
s
.
proxyRepo
.
Delete
(
ctx
,
id
);
err
!=
nil
{
result
.
Skipped
=
append
(
result
.
Skipped
,
ProxyBatchDeleteSkipped
{
ID
:
id
,
Reason
:
err
.
Error
(),
})
continue
}
result
.
DeletedIDs
=
append
(
result
.
DeletedIDs
,
id
)
}
return
result
,
nil
}
func
(
s
*
adminServiceImpl
)
GetProxyAccounts
(
ctx
context
.
Context
,
proxyID
int64
)
([]
ProxyAccountSummary
,
error
)
{
return
s
.
proxyRepo
.
ListAccountSummariesByProxyID
(
ctx
,
proxyID
)
}
func
(
s
*
adminServiceImpl
)
CheckProxyExists
(
ctx
context
.
Context
,
host
string
,
port
int
,
username
,
password
string
)
(
bool
,
error
)
{
...
...
@@ -1240,12 +1344,29 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
proxyURL
:=
proxy
.
URL
()
exitInfo
,
latencyMs
,
err
:=
s
.
proxyProber
.
ProbeProxy
(
ctx
,
proxyURL
)
if
err
!=
nil
{
s
.
saveProxyLatency
(
ctx
,
id
,
&
ProxyLatencyInfo
{
Success
:
false
,
Message
:
err
.
Error
(),
UpdatedAt
:
time
.
Now
(),
})
return
&
ProxyTestResult
{
Success
:
false
,
Message
:
err
.
Error
(),
},
nil
}
latency
:=
latencyMs
s
.
saveProxyLatency
(
ctx
,
id
,
&
ProxyLatencyInfo
{
Success
:
true
,
LatencyMs
:
&
latency
,
Message
:
"Proxy is accessible"
,
IPAddress
:
exitInfo
.
IP
,
Country
:
exitInfo
.
Country
,
CountryCode
:
exitInfo
.
CountryCode
,
Region
:
exitInfo
.
Region
,
City
:
exitInfo
.
City
,
UpdatedAt
:
time
.
Now
(),
})
return
&
ProxyTestResult
{
Success
:
true
,
Message
:
"Proxy is accessible"
,
...
...
@@ -1254,9 +1375,38 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
City
:
exitInfo
.
City
,
Region
:
exitInfo
.
Region
,
Country
:
exitInfo
.
Country
,
CountryCode
:
exitInfo
.
CountryCode
,
},
nil
}
func
(
s
*
adminServiceImpl
)
probeProxyLatency
(
ctx
context
.
Context
,
proxy
*
Proxy
)
{
if
s
.
proxyProber
==
nil
||
proxy
==
nil
{
return
}
exitInfo
,
latencyMs
,
err
:=
s
.
proxyProber
.
ProbeProxy
(
ctx
,
proxy
.
URL
())
if
err
!=
nil
{
s
.
saveProxyLatency
(
ctx
,
proxy
.
ID
,
&
ProxyLatencyInfo
{
Success
:
false
,
Message
:
err
.
Error
(),
UpdatedAt
:
time
.
Now
(),
})
return
}
latency
:=
latencyMs
s
.
saveProxyLatency
(
ctx
,
proxy
.
ID
,
&
ProxyLatencyInfo
{
Success
:
true
,
LatencyMs
:
&
latency
,
Message
:
"Proxy is accessible"
,
IPAddress
:
exitInfo
.
IP
,
Country
:
exitInfo
.
Country
,
CountryCode
:
exitInfo
.
CountryCode
,
Region
:
exitInfo
.
Region
,
City
:
exitInfo
.
City
,
UpdatedAt
:
time
.
Now
(),
})
}
// checkMixedChannelRisk 检查分组中是否存在混合渠道(Antigravity + Anthropic)
// 如果存在混合,返回错误提示用户确认
func
(
s
*
adminServiceImpl
)
checkMixedChannelRisk
(
ctx
context
.
Context
,
currentAccountID
int64
,
currentAccountPlatform
string
,
groupIDs
[]
int64
)
error
{
...
...
@@ -1306,6 +1456,51 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc
return
nil
}
func
(
s
*
adminServiceImpl
)
attachProxyLatency
(
ctx
context
.
Context
,
proxies
[]
ProxyWithAccountCount
)
{
if
s
.
proxyLatencyCache
==
nil
||
len
(
proxies
)
==
0
{
return
}
ids
:=
make
([]
int64
,
0
,
len
(
proxies
))
for
i
:=
range
proxies
{
ids
=
append
(
ids
,
proxies
[
i
]
.
ID
)
}
latencies
,
err
:=
s
.
proxyLatencyCache
.
GetProxyLatencies
(
ctx
,
ids
)
if
err
!=
nil
{
log
.
Printf
(
"Warning: load proxy latency cache failed: %v"
,
err
)
return
}
for
i
:=
range
proxies
{
info
:=
latencies
[
proxies
[
i
]
.
ID
]
if
info
==
nil
{
continue
}
if
info
.
Success
{
proxies
[
i
]
.
LatencyStatus
=
"success"
proxies
[
i
]
.
LatencyMs
=
info
.
LatencyMs
}
else
{
proxies
[
i
]
.
LatencyStatus
=
"failed"
}
proxies
[
i
]
.
LatencyMessage
=
info
.
Message
proxies
[
i
]
.
IPAddress
=
info
.
IPAddress
proxies
[
i
]
.
Country
=
info
.
Country
proxies
[
i
]
.
CountryCode
=
info
.
CountryCode
proxies
[
i
]
.
Region
=
info
.
Region
proxies
[
i
]
.
City
=
info
.
City
}
}
func
(
s
*
adminServiceImpl
)
saveProxyLatency
(
ctx
context
.
Context
,
proxyID
int64
,
info
*
ProxyLatencyInfo
)
{
if
s
.
proxyLatencyCache
==
nil
||
info
==
nil
{
return
}
if
err
:=
s
.
proxyLatencyCache
.
SetProxyLatency
(
ctx
,
proxyID
,
info
);
err
!=
nil
{
log
.
Printf
(
"Warning: store proxy latency cache failed: %v"
,
err
)
}
}
// getAccountPlatform 根据账号 platform 判断混合渠道检查用的平台标识
func
getAccountPlatform
(
accountPlatform
string
)
string
{
switch
strings
.
ToLower
(
strings
.
TrimSpace
(
accountPlatform
))
{
...
...
backend/internal/service/admin_service_bulk_update_test.go
View file @
6901b64f
backend/internal/service/admin_service_delete_test.go
View file @
6901b64f
...
...
@@ -154,6 +154,8 @@ func (s *groupRepoStub) DeleteAccountGroupsByGroupID(ctx context.Context, groupI
type
proxyRepoStub
struct
{
deleteErr
error
countErr
error
accountCount
int64
deletedIDs
[]
int64
}
...
...
@@ -199,7 +201,14 @@ func (s *proxyRepoStub) ExistsByHostPortAuth(ctx context.Context, host string, p
}
func
(
s
*
proxyRepoStub
)
CountAccountsByProxyID
(
ctx
context
.
Context
,
proxyID
int64
)
(
int64
,
error
)
{
panic
(
"unexpected CountAccountsByProxyID call"
)
if
s
.
countErr
!=
nil
{
return
0
,
s
.
countErr
}
return
s
.
accountCount
,
nil
}
func
(
s
*
proxyRepoStub
)
ListAccountSummariesByProxyID
(
ctx
context
.
Context
,
proxyID
int64
)
([]
ProxyAccountSummary
,
error
)
{
panic
(
"unexpected ListAccountSummariesByProxyID call"
)
}
type
redeemRepoStub
struct
{
...
...
@@ -409,6 +418,15 @@ func TestAdminService_DeleteProxy_Idempotent(t *testing.T) {
require
.
Equal
(
t
,
[]
int64
{
404
},
repo
.
deletedIDs
)
}
func
TestAdminService_DeleteProxy_InUse
(
t
*
testing
.
T
)
{
repo
:=
&
proxyRepoStub
{
accountCount
:
2
}
svc
:=
&
adminServiceImpl
{
proxyRepo
:
repo
}
err
:=
svc
.
DeleteProxy
(
context
.
Background
(),
77
)
require
.
ErrorIs
(
t
,
err
,
ErrProxyInUse
)
require
.
Empty
(
t
,
repo
.
deletedIDs
)
}
func
TestAdminService_DeleteProxy_Error
(
t
*
testing
.
T
)
{
deleteErr
:=
errors
.
New
(
"delete failed"
)
repo
:=
&
proxyRepoStub
{
deleteErr
:
deleteErr
}
...
...
backend/internal/service/antigravity_gateway_service.go
View file @
6901b64f
...
...
@@ -564,6 +564,10 @@ urlFallbackLoop:
}
upstreamReq
,
err
:=
antigravity
.
NewAPIRequestWithURL
(
ctx
,
baseURL
,
action
,
accessToken
,
geminiBody
)
// Capture upstream request body for ops retry of this attempt.
if
c
!=
nil
{
c
.
Set
(
OpsUpstreamRequestBodyKey
,
string
(
geminiBody
))
}
if
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -574,6 +578,7 @@ urlFallbackLoop:
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
0
,
Kind
:
"request_error"
,
Message
:
safeErr
,
...
...
@@ -615,6 +620,7 @@ urlFallbackLoop:
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"retry"
,
...
...
@@ -645,6 +651,7 @@ urlFallbackLoop:
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"retry"
,
...
...
@@ -697,6 +704,7 @@ urlFallbackLoop:
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"signature_error"
,
...
...
@@ -740,6 +748,7 @@ urlFallbackLoop:
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
0
,
Kind
:
"signature_retry_request_error"
,
Message
:
sanitizeUpstreamErrorMessage
(
retryErr
.
Error
()),
...
...
@@ -770,6 +779,7 @@ urlFallbackLoop:
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
retryResp
.
StatusCode
,
UpstreamRequestID
:
retryResp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
kind
,
...
...
@@ -817,6 +827,7 @@ urlFallbackLoop:
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"failover"
,
...
...
@@ -1371,6 +1382,7 @@ urlFallbackLoop:
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
0
,
Kind
:
"request_error"
,
Message
:
safeErr
,
...
...
@@ -1412,6 +1424,7 @@ urlFallbackLoop:
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"retry"
,
...
...
@@ -1442,6 +1455,7 @@ urlFallbackLoop:
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"retry"
,
...
...
@@ -1543,6 +1557,7 @@ urlFallbackLoop:
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
requestID
,
Kind
:
"failover"
,
...
...
@@ -1559,6 +1574,7 @@ urlFallbackLoop:
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
requestID
,
Kind
:
"http_error"
,
...
...
@@ -2039,6 +2055,7 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
upstreamStatus
,
UpstreamRequestID
:
upstreamRequestID
,
Kind
:
"http_error"
,
...
...
backend/internal/service/antigravity_quota_scope.go
View file @
6901b64f
...
...
@@ -49,6 +49,9 @@ func (a *Account) IsSchedulableForModel(requestedModel string) bool {
if
!
a
.
IsSchedulable
()
{
return
false
}
if
a
.
isModelRateLimited
(
requestedModel
)
{
return
false
}
if
a
.
Platform
!=
PlatformAntigravity
{
return
true
}
...
...
backend/internal/service/antigravity_token_provider.go
View file @
6901b64f
...
...
@@ -45,7 +45,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return
""
,
errors
.
New
(
"not an antigravity oauth account"
)
}
cacheKey
:=
a
ntigravityTokenCacheKey
(
account
)
cacheKey
:=
A
ntigravityTokenCacheKey
(
account
)
// 1. 先尝试缓存
if
p
.
tokenCache
!=
nil
{
...
...
@@ -121,7 +121,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return
accessToken
,
nil
}
func
a
ntigravityTokenCacheKey
(
account
*
Account
)
string
{
func
A
ntigravityTokenCacheKey
(
account
*
Account
)
string
{
projectID
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"project_id"
))
if
projectID
!=
""
{
return
"ag:"
+
projectID
...
...
backend/internal/service/api_key_auth_cache.go
View file @
6901b64f
...
...
@@ -37,6 +37,11 @@ type APIKeyAuthGroupSnapshot struct {
ImagePrice4K
*
float64
`json:"image_price_4k,omitempty"`
ClaudeCodeOnly
bool
`json:"claude_code_only"`
FallbackGroupID
*
int64
`json:"fallback_group_id,omitempty"`
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
// Only anthropic groups use these fields; others may leave them empty.
ModelRouting
map
[
string
][]
int64
`json:"model_routing,omitempty"`
ModelRoutingEnabled
bool
`json:"model_routing_enabled"`
}
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
...
...
backend/internal/service/api_key_auth_cache_impl.go
View file @
6901b64f
...
...
@@ -221,6 +221,8 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
ImagePrice4K
:
apiKey
.
Group
.
ImagePrice4K
,
ClaudeCodeOnly
:
apiKey
.
Group
.
ClaudeCodeOnly
,
FallbackGroupID
:
apiKey
.
Group
.
FallbackGroupID
,
ModelRouting
:
apiKey
.
Group
.
ModelRouting
,
ModelRoutingEnabled
:
apiKey
.
Group
.
ModelRoutingEnabled
,
}
}
return
snapshot
...
...
@@ -263,6 +265,8 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
ImagePrice4K
:
snapshot
.
Group
.
ImagePrice4K
,
ClaudeCodeOnly
:
snapshot
.
Group
.
ClaudeCodeOnly
,
FallbackGroupID
:
snapshot
.
Group
.
FallbackGroupID
,
ModelRouting
:
snapshot
.
Group
.
ModelRouting
,
ModelRoutingEnabled
:
snapshot
.
Group
.
ModelRoutingEnabled
,
}
}
return
apiKey
...
...
backend/internal/service/api_key_service_cache_test.go
View file @
6901b64f
...
...
@@ -178,6 +178,10 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
Status
:
StatusActive
,
SubscriptionType
:
SubscriptionTypeStandard
,
RateMultiplier
:
1
,
ModelRoutingEnabled
:
true
,
ModelRouting
:
map
[
string
][]
int64
{
"claude-opus-*"
:
{
1
,
2
},
},
},
},
}
...
...
@@ -190,6 +194,8 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
require
.
Equal
(
t
,
int64
(
1
),
apiKey
.
ID
)
require
.
Equal
(
t
,
int64
(
2
),
apiKey
.
User
.
ID
)
require
.
Equal
(
t
,
groupID
,
apiKey
.
Group
.
ID
)
require
.
True
(
t
,
apiKey
.
Group
.
ModelRoutingEnabled
)
require
.
Equal
(
t
,
map
[
string
][]
int64
{
"claude-opus-*"
:
{
1
,
2
}},
apiKey
.
Group
.
ModelRouting
)
}
func
TestAPIKeyService_GetByKey_NegativeCache
(
t
*
testing
.
T
)
{
...
...
backend/internal/service/claude_token_provider.go
0 → 100644
View file @
6901b64f
package
service
import
(
"context"
"errors"
"log/slog"
"strconv"
"strings"
"time"
)
const
(
claudeTokenRefreshSkew
=
3
*
time
.
Minute
claudeTokenCacheSkew
=
5
*
time
.
Minute
claudeLockWaitTime
=
200
*
time
.
Millisecond
)
// ClaudeTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
type
ClaudeTokenCache
=
GeminiTokenCache
// ClaudeTokenProvider 管理 Claude (Anthropic) OAuth 账户的 access_token
type
ClaudeTokenProvider
struct
{
accountRepo
AccountRepository
tokenCache
ClaudeTokenCache
oauthService
*
OAuthService
}
func
NewClaudeTokenProvider
(
accountRepo
AccountRepository
,
tokenCache
ClaudeTokenCache
,
oauthService
*
OAuthService
,
)
*
ClaudeTokenProvider
{
return
&
ClaudeTokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
tokenCache
,
oauthService
:
oauthService
,
}
}
// GetAccessToken 获取有效的 access_token
func
(
p
*
ClaudeTokenProvider
)
GetAccessToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
error
)
{
if
account
==
nil
{
return
""
,
errors
.
New
(
"account is nil"
)
}
if
account
.
Platform
!=
PlatformAnthropic
||
account
.
Type
!=
AccountTypeOAuth
{
return
""
,
errors
.
New
(
"not an anthropic oauth account"
)
}
cacheKey
:=
ClaudeTokenCacheKey
(
account
)
// 1. 先尝试缓存
if
p
.
tokenCache
!=
nil
{
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
slog
.
Debug
(
"claude_token_cache_hit"
,
"account_id"
,
account
.
ID
)
return
token
,
nil
}
else
if
err
!=
nil
{
slog
.
Warn
(
"claude_token_cache_get_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
slog
.
Debug
(
"claude_token_cache_miss"
,
"account_id"
,
account
.
ID
)
// 2. 如果即将过期则刷新
expiresAt
:=
account
.
GetCredentialAsTime
(
"expires_at"
)
needsRefresh
:=
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
claudeTokenRefreshSkew
refreshFailed
:=
false
if
needsRefresh
&&
p
.
tokenCache
!=
nil
{
locked
,
lockErr
:=
p
.
tokenCache
.
AcquireRefreshLock
(
ctx
,
cacheKey
,
30
*
time
.
Second
)
if
lockErr
==
nil
&&
locked
{
defer
func
()
{
_
=
p
.
tokenCache
.
ReleaseRefreshLock
(
ctx
,
cacheKey
)
}()
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
return
token
,
nil
}
// 从数据库获取最新账户信息
fresh
,
err
:=
p
.
accountRepo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
==
nil
&&
fresh
!=
nil
{
account
=
fresh
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
if
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
claudeTokenRefreshSkew
{
if
p
.
oauthService
==
nil
{
slog
.
Warn
(
"claude_oauth_service_not_configured"
,
"account_id"
,
account
.
ID
)
refreshFailed
=
true
// 无法刷新,标记失败
}
else
{
tokenInfo
,
err
:=
p
.
oauthService
.
RefreshAccountToken
(
ctx
,
account
)
if
err
!=
nil
{
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
slog
.
Warn
(
"claude_token_refresh_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
refreshFailed
=
true
// 刷新失败,标记以使用短 TTL
}
else
{
// 构建新 credentials,保留原有字段
newCredentials
:=
make
(
map
[
string
]
any
)
for
k
,
v
:=
range
account
.
Credentials
{
newCredentials
[
k
]
=
v
}
newCredentials
[
"access_token"
]
=
tokenInfo
.
AccessToken
newCredentials
[
"token_type"
]
=
tokenInfo
.
TokenType
newCredentials
[
"expires_in"
]
=
strconv
.
FormatInt
(
tokenInfo
.
ExpiresIn
,
10
)
newCredentials
[
"expires_at"
]
=
strconv
.
FormatInt
(
tokenInfo
.
ExpiresAt
,
10
)
if
tokenInfo
.
RefreshToken
!=
""
{
newCredentials
[
"refresh_token"
]
=
tokenInfo
.
RefreshToken
}
if
tokenInfo
.
Scope
!=
""
{
newCredentials
[
"scope"
]
=
tokenInfo
.
Scope
}
account
.
Credentials
=
newCredentials
if
updateErr
:=
p
.
accountRepo
.
Update
(
ctx
,
account
);
updateErr
!=
nil
{
slog
.
Error
(
"claude_token_provider_update_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
updateErr
)
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
}
}
}
else
if
lockErr
!=
nil
{
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
slog
.
Warn
(
"claude_token_lock_failed_degraded_refresh"
,
"account_id"
,
account
.
ID
,
"error"
,
lockErr
)
// 检查 ctx 是否已取消
if
ctx
.
Err
()
!=
nil
{
return
""
,
ctx
.
Err
()
}
// 从数据库获取最新账户信息
if
p
.
accountRepo
!=
nil
{
fresh
,
err
:=
p
.
accountRepo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
==
nil
&&
fresh
!=
nil
{
account
=
fresh
}
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
if
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
claudeTokenRefreshSkew
{
if
p
.
oauthService
==
nil
{
slog
.
Warn
(
"claude_oauth_service_not_configured"
,
"account_id"
,
account
.
ID
)
refreshFailed
=
true
}
else
{
tokenInfo
,
err
:=
p
.
oauthService
.
RefreshAccountToken
(
ctx
,
account
)
if
err
!=
nil
{
slog
.
Warn
(
"claude_token_refresh_failed_degraded"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
refreshFailed
=
true
}
else
{
// 构建新 credentials,保留原有字段
newCredentials
:=
make
(
map
[
string
]
any
)
for
k
,
v
:=
range
account
.
Credentials
{
newCredentials
[
k
]
=
v
}
newCredentials
[
"access_token"
]
=
tokenInfo
.
AccessToken
newCredentials
[
"token_type"
]
=
tokenInfo
.
TokenType
newCredentials
[
"expires_in"
]
=
strconv
.
FormatInt
(
tokenInfo
.
ExpiresIn
,
10
)
newCredentials
[
"expires_at"
]
=
strconv
.
FormatInt
(
tokenInfo
.
ExpiresAt
,
10
)
if
tokenInfo
.
RefreshToken
!=
""
{
newCredentials
[
"refresh_token"
]
=
tokenInfo
.
RefreshToken
}
if
tokenInfo
.
Scope
!=
""
{
newCredentials
[
"scope"
]
=
tokenInfo
.
Scope
}
account
.
Credentials
=
newCredentials
if
updateErr
:=
p
.
accountRepo
.
Update
(
ctx
,
account
);
updateErr
!=
nil
{
slog
.
Error
(
"claude_token_provider_update_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
updateErr
)
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
}
}
}
else
{
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
time
.
Sleep
(
claudeLockWaitTime
)
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
strings
.
TrimSpace
(
token
)
!=
""
{
slog
.
Debug
(
"claude_token_cache_hit_after_wait"
,
"account_id"
,
account
.
ID
)
return
token
,
nil
}
}
}
accessToken
:=
account
.
GetCredential
(
"access_token"
)
if
strings
.
TrimSpace
(
accessToken
)
==
""
{
return
""
,
errors
.
New
(
"access_token not found in credentials"
)
}
// 3. 存入缓存
if
p
.
tokenCache
!=
nil
{
ttl
:=
30
*
time
.
Minute
if
refreshFailed
{
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
ttl
=
time
.
Minute
slog
.
Debug
(
"claude_token_cache_short_ttl"
,
"account_id"
,
account
.
ID
,
"reason"
,
"refresh_failed"
)
}
else
if
expiresAt
!=
nil
{
until
:=
time
.
Until
(
*
expiresAt
)
switch
{
case
until
>
claudeTokenCacheSkew
:
ttl
=
until
-
claudeTokenCacheSkew
case
until
>
0
:
ttl
=
until
default
:
ttl
=
time
.
Minute
}
}
if
err
:=
p
.
tokenCache
.
SetAccessToken
(
ctx
,
cacheKey
,
accessToken
,
ttl
);
err
!=
nil
{
slog
.
Warn
(
"claude_token_cache_set_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
return
accessToken
,
nil
}
backend/internal/service/claude_token_provider_test.go
0 → 100644
View file @
6901b64f
//go:build unit
package
service
import
(
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// claudeTokenCacheStub implements ClaudeTokenCache for testing
type
claudeTokenCacheStub
struct
{
mu
sync
.
Mutex
tokens
map
[
string
]
string
getErr
error
setErr
error
deleteErr
error
lockAcquired
bool
lockErr
error
releaseLockErr
error
getCalled
int32
setCalled
int32
lockCalled
int32
unlockCalled
int32
simulateLockRace
bool
}
func
newClaudeTokenCacheStub
()
*
claudeTokenCacheStub
{
return
&
claudeTokenCacheStub
{
tokens
:
make
(
map
[
string
]
string
),
lockAcquired
:
true
,
}
}
func
(
s
*
claudeTokenCacheStub
)
GetAccessToken
(
ctx
context
.
Context
,
cacheKey
string
)
(
string
,
error
)
{
atomic
.
AddInt32
(
&
s
.
getCalled
,
1
)
if
s
.
getErr
!=
nil
{
return
""
,
s
.
getErr
}
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
return
s
.
tokens
[
cacheKey
],
nil
}
func
(
s
*
claudeTokenCacheStub
)
SetAccessToken
(
ctx
context
.
Context
,
cacheKey
string
,
token
string
,
ttl
time
.
Duration
)
error
{
atomic
.
AddInt32
(
&
s
.
setCalled
,
1
)
if
s
.
setErr
!=
nil
{
return
s
.
setErr
}
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
s
.
tokens
[
cacheKey
]
=
token
return
nil
}
func
(
s
*
claudeTokenCacheStub
)
DeleteAccessToken
(
ctx
context
.
Context
,
cacheKey
string
)
error
{
if
s
.
deleteErr
!=
nil
{
return
s
.
deleteErr
}
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
delete
(
s
.
tokens
,
cacheKey
)
return
nil
}
func
(
s
*
claudeTokenCacheStub
)
AcquireRefreshLock
(
ctx
context
.
Context
,
cacheKey
string
,
ttl
time
.
Duration
)
(
bool
,
error
)
{
atomic
.
AddInt32
(
&
s
.
lockCalled
,
1
)
if
s
.
lockErr
!=
nil
{
return
false
,
s
.
lockErr
}
if
s
.
simulateLockRace
{
return
false
,
nil
}
return
s
.
lockAcquired
,
nil
}
func
(
s
*
claudeTokenCacheStub
)
ReleaseRefreshLock
(
ctx
context
.
Context
,
cacheKey
string
)
error
{
atomic
.
AddInt32
(
&
s
.
unlockCalled
,
1
)
return
s
.
releaseLockErr
}
// claudeAccountRepoStub is a minimal stub implementing only the methods used by ClaudeTokenProvider
type
claudeAccountRepoStub
struct
{
account
*
Account
getErr
error
updateErr
error
getCalled
int32
updateCalled
int32
}
func
(
r
*
claudeAccountRepoStub
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Account
,
error
)
{
atomic
.
AddInt32
(
&
r
.
getCalled
,
1
)
if
r
.
getErr
!=
nil
{
return
nil
,
r
.
getErr
}
return
r
.
account
,
nil
}
func
(
r
*
claudeAccountRepoStub
)
Update
(
ctx
context
.
Context
,
account
*
Account
)
error
{
atomic
.
AddInt32
(
&
r
.
updateCalled
,
1
)
if
r
.
updateErr
!=
nil
{
return
r
.
updateErr
}
r
.
account
=
account
return
nil
}
// claudeOAuthServiceStub implements OAuthService methods for testing
type
claudeOAuthServiceStub
struct
{
tokenInfo
*
TokenInfo
refreshErr
error
refreshCalled
int32
}
func
(
s
*
claudeOAuthServiceStub
)
RefreshAccountToken
(
ctx
context
.
Context
,
account
*
Account
)
(
*
TokenInfo
,
error
)
{
atomic
.
AddInt32
(
&
s
.
refreshCalled
,
1
)
if
s
.
refreshErr
!=
nil
{
return
nil
,
s
.
refreshErr
}
return
s
.
tokenInfo
,
nil
}
// testClaudeTokenProvider is a test version that uses the stub OAuth service
type
testClaudeTokenProvider
struct
{
accountRepo
*
claudeAccountRepoStub
tokenCache
*
claudeTokenCacheStub
oauthService
*
claudeOAuthServiceStub
}
func
(
p
*
testClaudeTokenProvider
)
GetAccessToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
error
)
{
if
account
==
nil
{
return
""
,
errors
.
New
(
"account is nil"
)
}
if
account
.
Platform
!=
PlatformAnthropic
||
account
.
Type
!=
AccountTypeOAuth
{
return
""
,
errors
.
New
(
"not an anthropic oauth account"
)
}
cacheKey
:=
ClaudeTokenCacheKey
(
account
)
// 1. Check cache
if
p
.
tokenCache
!=
nil
{
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
token
!=
""
{
return
token
,
nil
}
}
// 2. Check if refresh needed
expiresAt
:=
account
.
GetCredentialAsTime
(
"expires_at"
)
needsRefresh
:=
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
claudeTokenRefreshSkew
refreshFailed
:=
false
if
needsRefresh
&&
p
.
tokenCache
!=
nil
{
locked
,
err
:=
p
.
tokenCache
.
AcquireRefreshLock
(
ctx
,
cacheKey
,
30
*
time
.
Second
)
if
err
==
nil
&&
locked
{
defer
func
()
{
_
=
p
.
tokenCache
.
ReleaseRefreshLock
(
ctx
,
cacheKey
)
}()
// Check cache again after acquiring lock
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
token
!=
""
{
return
token
,
nil
}
// Get fresh account from DB
fresh
,
err
:=
p
.
accountRepo
.
GetByID
(
ctx
,
account
.
ID
)
if
err
==
nil
&&
fresh
!=
nil
{
account
=
fresh
}
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
if
expiresAt
==
nil
||
time
.
Until
(
*
expiresAt
)
<=
claudeTokenRefreshSkew
{
if
p
.
oauthService
==
nil
{
refreshFailed
=
true
// 无法刷新,标记失败
}
else
{
tokenInfo
,
err
:=
p
.
oauthService
.
RefreshAccountToken
(
ctx
,
account
)
if
err
!=
nil
{
refreshFailed
=
true
// 刷新失败,标记以使用短 TTL
}
else
{
// Build new credentials
newCredentials
:=
make
(
map
[
string
]
any
)
for
k
,
v
:=
range
account
.
Credentials
{
newCredentials
[
k
]
=
v
}
newCredentials
[
"access_token"
]
=
tokenInfo
.
AccessToken
newCredentials
[
"token_type"
]
=
tokenInfo
.
TokenType
newCredentials
[
"expires_at"
]
=
time
.
Now
()
.
Add
(
time
.
Duration
(
tokenInfo
.
ExpiresIn
)
*
time
.
Second
)
.
Format
(
time
.
RFC3339
)
if
tokenInfo
.
RefreshToken
!=
""
{
newCredentials
[
"refresh_token"
]
=
tokenInfo
.
RefreshToken
}
account
.
Credentials
=
newCredentials
_
=
p
.
accountRepo
.
Update
(
ctx
,
account
)
expiresAt
=
account
.
GetCredentialAsTime
(
"expires_at"
)
}
}
}
}
else
if
p
.
tokenCache
.
simulateLockRace
{
// Wait and retry cache
time
.
Sleep
(
10
*
time
.
Millisecond
)
if
token
,
err
:=
p
.
tokenCache
.
GetAccessToken
(
ctx
,
cacheKey
);
err
==
nil
&&
token
!=
""
{
return
token
,
nil
}
}
}
accessToken
:=
account
.
GetCredential
(
"access_token"
)
if
accessToken
==
""
{
return
""
,
errors
.
New
(
"access_token not found in credentials"
)
}
// 3. Store in cache
if
p
.
tokenCache
!=
nil
{
ttl
:=
30
*
time
.
Minute
if
refreshFailed
{
ttl
=
time
.
Minute
// 刷新失败时使用短 TTL
}
else
if
expiresAt
!=
nil
{
until
:=
time
.
Until
(
*
expiresAt
)
if
until
>
claudeTokenCacheSkew
{
ttl
=
until
-
claudeTokenCacheSkew
}
else
if
until
>
0
{
ttl
=
until
}
else
{
ttl
=
time
.
Minute
}
}
_
=
p
.
tokenCache
.
SetAccessToken
(
ctx
,
cacheKey
,
accessToken
,
ttl
)
}
return
accessToken
,
nil
}
func
TestClaudeTokenProvider_CacheHit
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
account
:=
&
Account
{
ID
:
100
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"db-token"
,
},
}
cacheKey
:=
ClaudeTokenCacheKey
(
account
)
cache
.
tokens
[
cacheKey
]
=
"cached-token"
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"cached-token"
,
token
)
require
.
Equal
(
t
,
int32
(
1
),
atomic
.
LoadInt32
(
&
cache
.
getCalled
))
require
.
Equal
(
t
,
int32
(
0
),
atomic
.
LoadInt32
(
&
cache
.
setCalled
))
}
func
TestClaudeTokenProvider_CacheMiss_FromCredentials
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
// Token expires in far future, no refresh needed
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
101
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"credential-token"
,
"expires_at"
:
expiresAt
,
},
}
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"credential-token"
,
token
)
// Should have stored in cache
cacheKey
:=
ClaudeTokenCacheKey
(
account
)
require
.
Equal
(
t
,
"credential-token"
,
cache
.
tokens
[
cacheKey
])
}
func
TestClaudeTokenProvider_TokenRefresh
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
accountRepo
:=
&
claudeAccountRepoStub
{}
oauthService
:=
&
claudeOAuthServiceStub
{
tokenInfo
:
&
TokenInfo
{
AccessToken
:
"refreshed-token"
,
RefreshToken
:
"new-refresh-token"
,
TokenType
:
"Bearer"
,
ExpiresIn
:
3600
,
ExpiresAt
:
time
.
Now
()
.
Add
(
time
.
Hour
)
.
Unix
(),
},
}
// Token expires soon (within refresh skew)
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
102
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"old-token"
,
"refresh_token"
:
"old-refresh-token"
,
"expires_at"
:
expiresAt
,
},
}
accountRepo
.
account
=
account
provider
:=
&
testClaudeTokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
cache
,
oauthService
:
oauthService
,
}
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"refreshed-token"
,
token
)
require
.
Equal
(
t
,
int32
(
1
),
atomic
.
LoadInt32
(
&
oauthService
.
refreshCalled
))
}
func
TestClaudeTokenProvider_LockRaceCondition
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
cache
.
simulateLockRace
=
true
accountRepo
:=
&
claudeAccountRepoStub
{}
// Token expires soon
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
103
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"race-token"
,
"expires_at"
:
expiresAt
,
},
}
accountRepo
.
account
=
account
// Simulate another worker already refreshed and cached
cacheKey
:=
ClaudeTokenCacheKey
(
account
)
go
func
()
{
time
.
Sleep
(
5
*
time
.
Millisecond
)
cache
.
mu
.
Lock
()
cache
.
tokens
[
cacheKey
]
=
"winner-token"
cache
.
mu
.
Unlock
()
}()
provider
:=
&
testClaudeTokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
cache
,
}
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
NotEmpty
(
t
,
token
)
}
func
TestClaudeTokenProvider_NilAccount
(
t
*
testing
.
T
)
{
provider
:=
NewClaudeTokenProvider
(
nil
,
nil
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
nil
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"account is nil"
)
require
.
Empty
(
t
,
token
)
}
func
TestClaudeTokenProvider_WrongPlatform
(
t
*
testing
.
T
)
{
provider
:=
NewClaudeTokenProvider
(
nil
,
nil
,
nil
)
account
:=
&
Account
{
ID
:
104
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeOAuth
,
}
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"not an anthropic oauth account"
)
require
.
Empty
(
t
,
token
)
}
func
TestClaudeTokenProvider_WrongAccountType
(
t
*
testing
.
T
)
{
provider
:=
NewClaudeTokenProvider
(
nil
,
nil
,
nil
)
account
:=
&
Account
{
ID
:
105
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
,
}
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"not an anthropic oauth account"
)
require
.
Empty
(
t
,
token
)
}
func
TestClaudeTokenProvider_SetupTokenType
(
t
*
testing
.
T
)
{
provider
:=
NewClaudeTokenProvider
(
nil
,
nil
,
nil
)
account
:=
&
Account
{
ID
:
106
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeSetupToken
,
}
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"not an anthropic oauth account"
)
require
.
Empty
(
t
,
token
)
}
func
TestClaudeTokenProvider_NilCache
(
t
*
testing
.
T
)
{
// Token doesn't need refresh
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
107
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"nocache-token"
,
"expires_at"
:
expiresAt
,
},
}
provider
:=
NewClaudeTokenProvider
(
nil
,
nil
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"nocache-token"
,
token
)
}
func
TestClaudeTokenProvider_CacheGetError
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
cache
.
getErr
=
errors
.
New
(
"redis connection failed"
)
// Token doesn't need refresh
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
108
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"fallback-token"
,
"expires_at"
:
expiresAt
,
},
}
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
// Should gracefully degrade and return from credentials
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"fallback-token"
,
token
)
}
func
TestClaudeTokenProvider_CacheSetError
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
cache
.
setErr
=
errors
.
New
(
"redis write failed"
)
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
109
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"still-works-token"
,
"expires_at"
:
expiresAt
,
},
}
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
// Should still work even if cache set fails
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"still-works-token"
,
token
)
}
func
TestClaudeTokenProvider_MissingAccessToken
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
110
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"expires_at"
:
expiresAt
,
// missing access_token
},
}
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"access_token not found"
)
require
.
Empty
(
t
,
token
)
}
func
TestClaudeTokenProvider_RefreshError
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
accountRepo
:=
&
claudeAccountRepoStub
{}
oauthService
:=
&
claudeOAuthServiceStub
{
refreshErr
:
errors
.
New
(
"oauth refresh failed"
),
}
// Token expires soon
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
111
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"old-token"
,
"refresh_token"
:
"old-refresh-token"
,
"expires_at"
:
expiresAt
,
},
}
accountRepo
.
account
=
account
provider
:=
&
testClaudeTokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
cache
,
oauthService
:
oauthService
,
}
// Now with fallback behavior, should return existing token even if refresh fails
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"old-token"
,
token
)
// Fallback to existing token
}
func
TestClaudeTokenProvider_OAuthServiceNotConfigured
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
accountRepo
:=
&
claudeAccountRepoStub
{}
// Token expires soon
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
112
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"old-token"
,
"expires_at"
:
expiresAt
,
},
}
accountRepo
.
account
=
account
provider
:=
&
testClaudeTokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
cache
,
oauthService
:
nil
,
// not configured
}
// Now with fallback behavior, should return existing token even if oauth service not configured
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"old-token"
,
token
)
// Fallback to existing token
}
func
TestClaudeTokenProvider_TTLCalculation
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
expiresIn
time
.
Duration
}{
{
name
:
"far_future_expiry"
,
expiresIn
:
1
*
time
.
Hour
,
},
{
name
:
"medium_expiry"
,
expiresIn
:
10
*
time
.
Minute
,
},
{
name
:
"near_expiry"
,
expiresIn
:
6
*
time
.
Minute
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
expiresAt
:=
time
.
Now
()
.
Add
(
tt
.
expiresIn
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
200
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"test-token"
,
"expires_at"
:
expiresAt
,
},
}
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
_
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
// Verify token was cached
cacheKey
:=
ClaudeTokenCacheKey
(
account
)
require
.
Equal
(
t
,
"test-token"
,
cache
.
tokens
[
cacheKey
])
})
}
}
func
TestClaudeTokenProvider_AccountRepoGetError
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
accountRepo
:=
&
claudeAccountRepoStub
{
getErr
:
errors
.
New
(
"db connection failed"
),
}
oauthService
:=
&
claudeOAuthServiceStub
{
tokenInfo
:
&
TokenInfo
{
AccessToken
:
"refreshed-token"
,
RefreshToken
:
"new-refresh"
,
TokenType
:
"Bearer"
,
ExpiresIn
:
3600
,
},
}
// Token expires soon
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
113
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"old-token"
,
"refresh_token"
:
"old-refresh"
,
"expires_at"
:
expiresAt
,
},
}
provider
:=
&
testClaudeTokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
cache
,
oauthService
:
oauthService
,
}
// Should still work, just using the passed-in account
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"refreshed-token"
,
token
)
}
func
TestClaudeTokenProvider_AccountUpdateError
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
accountRepo
:=
&
claudeAccountRepoStub
{
updateErr
:
errors
.
New
(
"db write failed"
),
}
oauthService
:=
&
claudeOAuthServiceStub
{
tokenInfo
:
&
TokenInfo
{
AccessToken
:
"refreshed-token"
,
RefreshToken
:
"new-refresh"
,
TokenType
:
"Bearer"
,
ExpiresIn
:
3600
,
},
}
// Token expires soon
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
114
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"old-token"
,
"refresh_token"
:
"old-refresh"
,
"expires_at"
:
expiresAt
,
},
}
accountRepo
.
account
=
account
provider
:=
&
testClaudeTokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
cache
,
oauthService
:
oauthService
,
}
// Should still return token even if update fails
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"refreshed-token"
,
token
)
}
func
TestClaudeTokenProvider_RefreshPreservesExistingCredentials
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
accountRepo
:=
&
claudeAccountRepoStub
{}
oauthService
:=
&
claudeOAuthServiceStub
{
tokenInfo
:
&
TokenInfo
{
AccessToken
:
"new-access-token"
,
RefreshToken
:
"new-refresh-token"
,
TokenType
:
"Bearer"
,
ExpiresIn
:
3600
,
},
}
// Token expires soon
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
115
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"old-access-token"
,
"refresh_token"
:
"old-refresh-token"
,
"expires_at"
:
expiresAt
,
"custom_field"
:
"should-be-preserved"
,
"organization"
:
"test-org"
,
},
}
accountRepo
.
account
=
account
provider
:=
&
testClaudeTokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
cache
,
oauthService
:
oauthService
,
}
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"new-access-token"
,
token
)
// Verify existing fields are preserved
require
.
Equal
(
t
,
"should-be-preserved"
,
accountRepo
.
account
.
Credentials
[
"custom_field"
])
require
.
Equal
(
t
,
"test-org"
,
accountRepo
.
account
.
Credentials
[
"organization"
])
// Verify new fields are updated
require
.
Equal
(
t
,
"new-access-token"
,
accountRepo
.
account
.
Credentials
[
"access_token"
])
require
.
Equal
(
t
,
"new-refresh-token"
,
accountRepo
.
account
.
Credentials
[
"refresh_token"
])
}
func
TestClaudeTokenProvider_DoubleCheckCacheAfterLock
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
accountRepo
:=
&
claudeAccountRepoStub
{}
oauthService
:=
&
claudeOAuthServiceStub
{
tokenInfo
:
&
TokenInfo
{
AccessToken
:
"refreshed-token"
,
RefreshToken
:
"new-refresh"
,
TokenType
:
"Bearer"
,
ExpiresIn
:
3600
,
},
}
// Token expires soon
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
116
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"old-token"
,
"expires_at"
:
expiresAt
,
},
}
accountRepo
.
account
=
account
cacheKey
:=
ClaudeTokenCacheKey
(
account
)
// After lock is acquired, cache should have the token (simulating another worker)
go
func
()
{
time
.
Sleep
(
5
*
time
.
Millisecond
)
cache
.
mu
.
Lock
()
cache
.
tokens
[
cacheKey
]
=
"cached-by-other-worker"
cache
.
mu
.
Unlock
()
}()
provider
:=
&
testClaudeTokenProvider
{
accountRepo
:
accountRepo
,
tokenCache
:
cache
,
oauthService
:
oauthService
,
}
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
NotEmpty
(
t
,
token
)
}
// Tests for real provider - to increase coverage
func
TestClaudeTokenProvider_Real_LockFailedWait
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
cache
.
lockAcquired
=
false
// Lock acquisition fails
// Token expires soon (within refresh skew) to trigger lock attempt
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
300
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"fallback-token"
,
"expires_at"
:
expiresAt
,
},
}
// Set token in cache after lock wait period (simulate other worker refreshing)
cacheKey
:=
ClaudeTokenCacheKey
(
account
)
go
func
()
{
time
.
Sleep
(
100
*
time
.
Millisecond
)
cache
.
mu
.
Lock
()
cache
.
tokens
[
cacheKey
]
=
"refreshed-by-other"
cache
.
mu
.
Unlock
()
}()
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
NotEmpty
(
t
,
token
)
}
func
TestClaudeTokenProvider_Real_CacheHitAfterWait
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
cache
.
lockAcquired
=
false
// Lock acquisition fails
// Token expires soon
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
301
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"original-token"
,
"expires_at"
:
expiresAt
,
},
}
cacheKey
:=
ClaudeTokenCacheKey
(
account
)
// Set token in cache immediately after wait starts
go
func
()
{
time
.
Sleep
(
50
*
time
.
Millisecond
)
cache
.
mu
.
Lock
()
cache
.
tokens
[
cacheKey
]
=
"winner-token"
cache
.
mu
.
Unlock
()
}()
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
NotEmpty
(
t
,
token
)
}
func
TestClaudeTokenProvider_Real_NoExpiresAt
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
cache
.
lockAcquired
=
false
// Prevent entering refresh logic
// Token with nil expires_at (no expiry set)
account
:=
&
Account
{
ID
:
302
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"no-expiry-token"
,
},
}
// After lock wait, return token from credentials
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"no-expiry-token"
,
token
)
}
func
TestClaudeTokenProvider_Real_WhitespaceToken
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
cacheKey
:=
"claude:account:303"
cache
.
tokens
[
cacheKey
]
=
" "
// Whitespace only - should be treated as empty
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
303
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"real-token"
,
"expires_at"
:
expiresAt
,
},
}
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"real-token"
,
token
)
}
func
TestClaudeTokenProvider_Real_EmptyCredentialToken
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
304
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
" "
,
// Whitespace only
"expires_at"
:
expiresAt
,
},
}
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"access_token not found"
)
require
.
Empty
(
t
,
token
)
}
func
TestClaudeTokenProvider_Real_LockError
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
cache
.
lockErr
=
errors
.
New
(
"redis lock failed"
)
// Token expires soon (within refresh skew)
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Minute
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
305
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"access_token"
:
"fallback-on-lock-error"
,
"expires_at"
:
expiresAt
,
},
}
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"fallback-on-lock-error"
,
token
)
}
func
TestClaudeTokenProvider_Real_NilCredentials
(
t
*
testing
.
T
)
{
cache
:=
newClaudeTokenCacheStub
()
expiresAt
:=
time
.
Now
()
.
Add
(
1
*
time
.
Hour
)
.
Format
(
time
.
RFC3339
)
account
:=
&
Account
{
ID
:
306
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeOAuth
,
Credentials
:
map
[
string
]
any
{
"expires_at"
:
expiresAt
,
// No access_token
},
}
provider
:=
NewClaudeTokenProvider
(
nil
,
cache
,
nil
)
token
,
err
:=
provider
.
GetAccessToken
(
context
.
Background
(),
account
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"access_token not found"
)
require
.
Empty
(
t
,
token
)
}
backend/internal/service/dashboard_service.go
View file @
6901b64f
...
...
@@ -124,16 +124,16 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D
return
stats
,
nil
}
func
(
s
*
DashboardService
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
int64
)
([]
usagestats
.
TrendDataPoint
,
error
)
{
trend
,
err
:=
s
.
usageRepo
.
GetUsageTrendWithFilters
(
ctx
,
startTime
,
endTime
,
granularity
,
userID
,
apiKeyID
)
func
(
s
*
DashboardService
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
stream
*
bool
)
([]
usagestats
.
TrendDataPoint
,
error
)
{
trend
,
err
:=
s
.
usageRepo
.
GetUsageTrendWithFilters
(
ctx
,
startTime
,
endTime
,
granularity
,
userID
,
apiKeyID
,
accountID
,
groupID
,
model
,
stream
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get usage trend with filters: %w"
,
err
)
}
return
trend
,
nil
}
func
(
s
*
DashboardService
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
int64
)
([]
usagestats
.
ModelStat
,
error
)
{
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
startTime
,
endTime
,
userID
,
apiKeyID
,
0
)
func
(
s
*
DashboardService
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
stream
*
bool
)
([]
usagestats
.
ModelStat
,
error
)
{
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
startTime
,
endTime
,
userID
,
apiKeyID
,
accountID
,
groupID
,
stream
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get model stats with filters: %w"
,
err
)
}
...
...
backend/internal/service/gateway_multiplatform_test.go
View file @
6901b64f
...
...
@@ -142,6 +142,9 @@ func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int6
func
(
m
*
mockAccountRepoForPlatform
)
SetAntigravityQuotaScopeLimit
(
ctx
context
.
Context
,
id
int64
,
scope
AntigravityQuotaScope
,
resetAt
time
.
Time
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
SetModelRateLimit
(
ctx
context
.
Context
,
id
int64
,
scope
string
,
resetAt
time
.
Time
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
SetOverloaded
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
)
error
{
return
nil
}
...
...
@@ -157,6 +160,9 @@ func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int6
func
(
m
*
mockAccountRepoForPlatform
)
ClearAntigravityQuotaScopes
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
ClearModelRateLimits
(
ctx
context
.
Context
,
id
int64
)
error
{
return
nil
}
func
(
m
*
mockAccountRepoForPlatform
)
UpdateSessionWindow
(
ctx
context
.
Context
,
id
int64
,
start
,
end
*
time
.
Time
,
status
string
)
error
{
return
nil
}
...
...
@@ -1046,13 +1052,67 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService
:
nil
,
// No concurrency service
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
)
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
1
),
result
.
Account
.
ID
,
"应选择优先级最高的账号"
)
})
t
.
Run
(
"模型路由-无ConcurrencyService也生效"
,
func
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
1
)
sessionHash
:=
"sticky"
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
,
AccountGroups
:
[]
AccountGroup
{{
GroupID
:
groupID
}}},
{
ID
:
2
,
Platform
:
PlatformAnthropic
,
Priority
:
1
,
Status
:
StatusActive
,
Schedulable
:
true
,
Concurrency
:
5
,
AccountGroups
:
[]
AccountGroup
{{
GroupID
:
groupID
}}},
},
accountsByID
:
map
[
int64
]
*
Account
{},
}
for
i
:=
range
repo
.
accounts
{
repo
.
accountsByID
[
repo
.
accounts
[
i
]
.
ID
]
=
&
repo
.
accounts
[
i
]
}
cache
:=
&
mockGatewayCacheForPlatform
{
sessionBindings
:
map
[
string
]
int64
{
sessionHash
:
1
},
}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Hydrated
:
true
,
ModelRoutingEnabled
:
true
,
ModelRouting
:
map
[
string
][]
int64
{
"claude-a"
:
{
1
},
"claude-b"
:
{
2
},
},
},
},
}
cfg
:=
testConfig
()
cfg
.
Gateway
.
Scheduling
.
LoadBatchEnabled
=
true
svc
:=
&
GatewayService
{
accountRepo
:
repo
,
groupRepo
:
groupRepo
,
cache
:
cache
,
cfg
:
cfg
,
concurrencyService
:
nil
,
// legacy path
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
&
groupID
,
sessionHash
,
"claude-b"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
require
.
Equal
(
t
,
int64
(
2
),
result
.
Account
.
ID
,
"切换到 claude-b 时应按模型路由切换账号"
)
require
.
Equal
(
t
,
int64
(
2
),
cache
.
sessionBindings
[
sessionHash
],
"粘性绑定应更新为路由选择的账号"
)
})
t
.
Run
(
"无ConcurrencyService-降级到传统选择"
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
mockAccountRepoForPlatform
{
accounts
:
[]
Account
{
...
...
@@ -1077,7 +1137,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService
:
nil
,
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
)
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
...
...
@@ -1109,7 +1169,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
}
excludedIDs
:=
map
[
int64
]
struct
{}{
1
:
{}}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
excludedIDs
)
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
excludedIDs
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
...
...
@@ -1143,7 +1203,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
"sticky"
,
"claude-3-5-sonnet-20241022"
,
nil
)
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
"sticky"
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
...
...
@@ -1179,7 +1239,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService
:
NewConcurrencyService
(
concurrencyCache
),
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
"sticky"
,
"claude-3-5-sonnet-20241022"
,
nil
)
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
"sticky"
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
...
...
@@ -1206,7 +1266,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService
:
nil
,
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
)
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
Error
(
t
,
err
)
require
.
Nil
(
t
,
result
)
require
.
Contains
(
t
,
err
.
Error
(),
"no available accounts"
)
...
...
@@ -1238,7 +1298,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService
:
nil
,
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
)
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
...
...
@@ -1271,7 +1331,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService
:
nil
,
}
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
)
result
,
err
:=
svc
.
SelectAccountWithLoadAwareness
(
ctx
,
nil
,
""
,
"claude-3-5-sonnet-20241022"
,
nil
,
""
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
NotNil
(
t
,
result
.
Account
)
...
...
@@ -1341,6 +1401,7 @@ func TestGatewayService_GroupResolution_IgnoresInvalidContextGroup(t *testing.T)
ID
:
groupID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Hydrated
:
true
,
}
groupRepo
:=
&
mockGroupRepoForGateway
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
group
},
...
...
@@ -1398,6 +1459,7 @@ func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) {
ID
:
fallbackID
,
Platform
:
PlatformAnthropic
,
Status
:
StatusActive
,
Hydrated
:
true
,
}
ctx
=
context
.
WithValue
(
ctx
,
ctxkey
.
Group
,
group
)
...
...
backend/internal/service/gateway_service.go
View file @
6901b64f
...
...
@@ -12,6 +12,7 @@ import (
"io"
"log"
"net/http"
"os"
"regexp"
"sort"
"strings"
...
...
@@ -40,6 +41,21 @@ const (
maxCacheControlBlocks
=
4
// Anthropic API 允许的最大 cache_control 块数量
)
func
(
s
*
GatewayService
)
debugModelRoutingEnabled
()
bool
{
v
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
os
.
Getenv
(
"SUB2API_DEBUG_MODEL_ROUTING"
)))
return
v
==
"1"
||
v
==
"true"
||
v
==
"yes"
||
v
==
"on"
}
func
shortSessionHash
(
sessionHash
string
)
string
{
if
sessionHash
==
""
{
return
""
}
if
len
(
sessionHash
)
<=
8
{
return
sessionHash
}
return
sessionHash
[
:
8
]
}
// sseDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var
(
...
...
@@ -196,6 +212,8 @@ type GatewayService struct {
httpUpstream
HTTPUpstream
deferredService
*
DeferredService
concurrencyService
*
ConcurrencyService
claudeTokenProvider
*
ClaudeTokenProvider
sessionLimitCache
SessionLimitCache
// 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
}
// NewGatewayService creates a new GatewayService
...
...
@@ -215,6 +233,8 @@ func NewGatewayService(
identityService
*
IdentityService
,
httpUpstream
HTTPUpstream
,
deferredService
*
DeferredService
,
claudeTokenProvider
*
ClaudeTokenProvider
,
sessionLimitCache
SessionLimitCache
,
)
*
GatewayService
{
return
&
GatewayService
{
accountRepo
:
accountRepo
,
...
...
@@ -232,6 +252,8 @@ func NewGatewayService(
identityService
:
identityService
,
httpUpstream
:
httpUpstream
,
deferredService
:
deferredService
,
claudeTokenProvider
:
claudeTokenProvider
,
sessionLimitCache
:
sessionLimitCache
,
}
}
...
...
@@ -797,8 +819,12 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
}
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
func
(
s
*
GatewayService
)
SelectAccountWithLoadAwareness
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{})
(
*
AccountSelectionResult
,
error
)
{
// metadataUserID: 原始 metadata.user_id 字段(用于提取会话 UUID 进行会话数量限制)
func
(
s
*
GatewayService
)
SelectAccountWithLoadAwareness
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{},
metadataUserID
string
)
(
*
AccountSelectionResult
,
error
)
{
cfg
:=
s
.
schedulingConfig
()
// 提取会话 UUID(用于会话数量限制)
sessionUUID
:=
extractSessionUUID
(
metadataUserID
)
var
stickyAccountID
int64
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
if
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
);
err
==
nil
{
...
...
@@ -813,6 +839,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
ctx
=
s
.
withGroupContext
(
ctx
,
group
)
if
s
.
debugModelRoutingEnabled
()
&&
requestedModel
!=
""
{
groupPlatform
:=
""
if
group
!=
nil
{
groupPlatform
=
group
.
Platform
}
log
.
Printf
(
"[ModelRoutingDebug] select entry: group_id=%v group_platform=%s model=%s session=%s sticky_account=%d load_batch=%v concurrency=%v"
,
derefGroupID
(
groupID
),
groupPlatform
,
requestedModel
,
shortSessionHash
(
sessionHash
),
stickyAccountID
,
cfg
.
LoadBatchEnabled
,
s
.
concurrencyService
!=
nil
)
}
if
s
.
concurrencyService
==
nil
||
!
cfg
.
LoadBatchEnabled
{
account
,
err
:=
s
.
SelectAccountForModelWithExclusions
(
ctx
,
groupID
,
sessionHash
,
requestedModel
,
excludedIDs
)
if
err
!=
nil
{
...
...
@@ -856,6 +891,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return
nil
,
err
}
preferOAuth
:=
platform
==
PlatformGemini
if
s
.
debugModelRoutingEnabled
()
&&
platform
==
PlatformAnthropic
&&
requestedModel
!=
""
{
log
.
Printf
(
"[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
platform
)
}
accounts
,
useMixed
,
err
:=
s
.
listSchedulableAccounts
(
ctx
,
groupID
,
platform
,
hasForcePlatform
)
if
err
!=
nil
{
...
...
@@ -873,22 +911,235 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return
excluded
}
// ============ Layer 1: 粘性会话优先 ============
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
if
err
==
nil
&&
accountID
>
0
&&
!
isExcluded
(
accountID
)
{
// 粘性命中仅在当前可调度候选集中生效。
// 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用)
accountByID
:=
make
(
map
[
int64
]
*
Account
,
len
(
accounts
))
for
i
:=
range
accounts
{
accountByID
[
accounts
[
i
]
.
ID
]
=
&
accounts
[
i
]
}
// 获取模型路由配置(仅 anthropic 平台)
var
routingAccountIDs
[]
int64
if
group
!=
nil
&&
requestedModel
!=
""
&&
group
.
Platform
==
PlatformAnthropic
{
routingAccountIDs
=
group
.
GetRoutingAccountIDs
(
requestedModel
)
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] context group routing: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v session=%s sticky_account=%d"
,
group
.
ID
,
requestedModel
,
group
.
ModelRoutingEnabled
,
len
(
group
.
ModelRouting
),
routingAccountIDs
,
shortSessionHash
(
sessionHash
),
stickyAccountID
)
if
len
(
routingAccountIDs
)
==
0
&&
group
.
ModelRoutingEnabled
&&
len
(
group
.
ModelRouting
)
>
0
{
keys
:=
make
([]
string
,
0
,
len
(
group
.
ModelRouting
))
for
k
:=
range
group
.
ModelRouting
{
keys
=
append
(
keys
,
k
)
}
sort
.
Strings
(
keys
)
const
maxKeys
=
20
if
len
(
keys
)
>
maxKeys
{
keys
=
keys
[
:
maxKeys
]
}
log
.
Printf
(
"[ModelRoutingDebug] context group routing miss: group_id=%d model=%s patterns(sample)=%v"
,
group
.
ID
,
requestedModel
,
keys
)
}
}
}
// ============ Layer 1: 模型路由优先选择(优先级高于粘性会话) ============
if
len
(
routingAccountIDs
)
>
0
&&
s
.
concurrencyService
!=
nil
{
// 1. 过滤出路由列表中可调度的账号
var
routingCandidates
[]
*
Account
var
filteredExcluded
,
filteredMissing
,
filteredUnsched
,
filteredPlatform
,
filteredModelScope
,
filteredModelMapping
,
filteredWindowCost
int
for
_
,
routingAccountID
:=
range
routingAccountIDs
{
if
isExcluded
(
routingAccountID
)
{
filteredExcluded
++
continue
}
account
,
ok
:=
accountByID
[
routingAccountID
]
if
!
ok
||
!
account
.
IsSchedulable
()
{
if
!
ok
{
filteredMissing
++
}
else
{
filteredUnsched
++
}
continue
}
if
!
s
.
isAccountAllowedForPlatform
(
account
,
platform
,
useMixed
)
{
filteredPlatform
++
continue
}
if
!
account
.
IsSchedulableForModel
(
requestedModel
)
{
filteredModelScope
++
continue
}
if
requestedModel
!=
""
&&
!
s
.
isModelSupportedByAccount
(
account
,
requestedModel
)
{
filteredModelMapping
++
continue
}
// 窗口费用检查(非粘性会话路径)
if
!
s
.
isAccountSchedulableForWindowCost
(
ctx
,
account
,
false
)
{
filteredWindowCost
++
continue
}
routingCandidates
=
append
(
routingCandidates
,
account
)
}
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d window_cost=%d)"
,
derefGroupID
(
groupID
),
requestedModel
,
len
(
routingAccountIDs
),
len
(
routingCandidates
),
filteredExcluded
,
filteredMissing
,
filteredUnsched
,
filteredPlatform
,
filteredModelScope
,
filteredModelMapping
,
filteredWindowCost
)
}
if
len
(
routingCandidates
)
>
0
{
// 1.5. 在路由账号范围内检查粘性会话
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
stickyAccountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
if
err
==
nil
&&
stickyAccountID
>
0
&&
containsInt64
(
routingAccountIDs
,
stickyAccountID
)
&&
!
isExcluded
(
stickyAccountID
)
{
// 粘性账号在路由列表中,优先使用
if
stickyAccount
,
ok
:=
accountByID
[
stickyAccountID
];
ok
{
if
stickyAccount
.
IsSchedulable
()
&&
s
.
isAccountAllowedForPlatform
(
stickyAccount
,
platform
,
useMixed
)
&&
stickyAccount
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
stickyAccount
,
requestedModel
))
&&
s
.
isAccountSchedulableForWindowCost
(
ctx
,
stickyAccount
,
true
)
{
// 粘性会话窗口费用检查
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
stickyAccountID
,
stickyAccount
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
// 会话数量限制检查
if
!
s
.
checkAndRegisterSession
(
ctx
,
stickyAccount
,
sessionUUID
)
{
result
.
ReleaseFunc
()
// 释放槽位
// 继续到负载感知选择
}
else
{
_
=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
)
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
stickyAccountID
)
}
return
&
AccountSelectionResult
{
Account
:
stickyAccount
,
Acquired
:
true
,
ReleaseFunc
:
result
.
ReleaseFunc
,
},
nil
}
}
waitingCount
,
_
:=
s
.
concurrencyService
.
GetAccountWaitingCount
(
ctx
,
stickyAccountID
)
if
waitingCount
<
cfg
.
StickySessionMaxWaiting
{
return
&
AccountSelectionResult
{
Account
:
stickyAccount
,
WaitPlan
:
&
AccountWaitPlan
{
AccountID
:
stickyAccountID
,
MaxConcurrency
:
stickyAccount
.
Concurrency
,
Timeout
:
cfg
.
StickySessionWaitTimeout
,
MaxWaiting
:
cfg
.
StickySessionMaxWaiting
,
},
},
nil
}
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
}
}
}
}
// 2. 批量获取负载信息
routingLoads
:=
make
([]
AccountWithConcurrency
,
0
,
len
(
routingCandidates
))
for
_
,
acc
:=
range
routingCandidates
{
routingLoads
=
append
(
routingLoads
,
AccountWithConcurrency
{
ID
:
acc
.
ID
,
MaxConcurrency
:
acc
.
Concurrency
,
})
}
routingLoadMap
,
_
:=
s
.
concurrencyService
.
GetAccountsLoadBatch
(
ctx
,
routingLoads
)
// 3. 按负载感知排序
type
accountWithLoad
struct
{
account
*
Account
loadInfo
*
AccountLoadInfo
}
var
routingAvailable
[]
accountWithLoad
for
_
,
acc
:=
range
routingCandidates
{
loadInfo
:=
routingLoadMap
[
acc
.
ID
]
if
loadInfo
==
nil
{
loadInfo
=
&
AccountLoadInfo
{
AccountID
:
acc
.
ID
}
}
if
loadInfo
.
LoadRate
<
100
{
routingAvailable
=
append
(
routingAvailable
,
accountWithLoad
{
account
:
acc
,
loadInfo
:
loadInfo
})
}
}
if
len
(
routingAvailable
)
>
0
{
// 排序:优先级 > 负载率 > 最后使用时间
sort
.
SliceStable
(
routingAvailable
,
func
(
i
,
j
int
)
bool
{
a
,
b
:=
routingAvailable
[
i
],
routingAvailable
[
j
]
if
a
.
account
.
Priority
!=
b
.
account
.
Priority
{
return
a
.
account
.
Priority
<
b
.
account
.
Priority
}
if
a
.
loadInfo
.
LoadRate
!=
b
.
loadInfo
.
LoadRate
{
return
a
.
loadInfo
.
LoadRate
<
b
.
loadInfo
.
LoadRate
}
switch
{
case
a
.
account
.
LastUsedAt
==
nil
&&
b
.
account
.
LastUsedAt
!=
nil
:
return
true
case
a
.
account
.
LastUsedAt
!=
nil
&&
b
.
account
.
LastUsedAt
==
nil
:
return
false
case
a
.
account
.
LastUsedAt
==
nil
&&
b
.
account
.
LastUsedAt
==
nil
:
return
false
default
:
return
a
.
account
.
LastUsedAt
.
Before
(
*
b
.
account
.
LastUsedAt
)
}
})
// 4. 尝试获取槽位
for
_
,
item
:=
range
routingAvailable
{
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
item
.
account
.
ID
,
item
.
account
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
// 会话数量限制检查
if
!
s
.
checkAndRegisterSession
(
ctx
,
item
.
account
,
sessionUUID
)
{
result
.
ReleaseFunc
()
// 释放槽位,继续尝试下一个账号
continue
}
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
_
=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
item
.
account
.
ID
,
stickySessionTTL
)
}
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
item
.
account
.
ID
)
}
return
&
AccountSelectionResult
{
Account
:
item
.
account
,
Acquired
:
true
,
ReleaseFunc
:
result
.
ReleaseFunc
,
},
nil
}
}
// 5. 所有路由账号槽位满,返回等待计划(选择负载最低的)
acc
:=
routingAvailable
[
0
]
.
account
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
acc
.
ID
)
}
return
&
AccountSelectionResult
{
Account
:
acc
,
WaitPlan
:
&
AccountWaitPlan
{
AccountID
:
acc
.
ID
,
MaxConcurrency
:
acc
.
Concurrency
,
Timeout
:
cfg
.
StickySessionWaitTimeout
,
MaxWaiting
:
cfg
.
StickySessionMaxWaiting
,
},
},
nil
}
// 路由列表中的账号都不可用(负载率 >= 100),继续到 Layer 2 回退
log
.
Printf
(
"[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection"
,
requestedModel
)
}
}
// ============ Layer 1.5: 粘性会话(仅在无模型路由配置时生效) ============
if
len
(
routingAccountIDs
)
==
0
&&
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
if
err
==
nil
&&
accountID
>
0
&&
!
isExcluded
(
accountID
)
{
account
,
ok
:=
accountByID
[
accountID
]
if
ok
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
s
.
isAccountAllowedForPlatform
(
account
,
platform
,
useMixed
)
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
&&
s
.
isAccountSchedulableForWindowCost
(
ctx
,
account
,
true
)
{
// 粘性会话窗口费用检查
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
accountID
,
account
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
// 会话数量限制检查
if
!
s
.
checkAndRegisterSession
(
ctx
,
account
,
sessionUUID
)
{
result
.
ReleaseFunc
()
// 释放槽位,继续到 Layer 2
}
else
{
_
=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
)
return
&
AccountSelectionResult
{
Account
:
account
,
...
...
@@ -896,6 +1147,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
ReleaseFunc
:
result
.
ReleaseFunc
,
},
nil
}
}
waitingCount
,
_
:=
s
.
concurrencyService
.
GetAccountWaitingCount
(
ctx
,
accountID
)
if
waitingCount
<
cfg
.
StickySessionMaxWaiting
{
...
...
@@ -935,6 +1187,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if
requestedModel
!=
""
&&
!
s
.
isModelSupportedByAccount
(
acc
,
requestedModel
)
{
continue
}
// 窗口费用检查(非粘性会话路径)
if
!
s
.
isAccountSchedulableForWindowCost
(
ctx
,
acc
,
false
)
{
continue
}
candidates
=
append
(
candidates
,
acc
)
}
...
...
@@ -952,7 +1208,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
loadMap
,
err
:=
s
.
concurrencyService
.
GetAccountsLoadBatch
(
ctx
,
accountLoads
)
if
err
!=
nil
{
if
result
,
ok
:=
s
.
tryAcquireByLegacyOrder
(
ctx
,
candidates
,
groupID
,
sessionHash
,
preferOAuth
);
ok
{
if
result
,
ok
:=
s
.
tryAcquireByLegacyOrder
(
ctx
,
candidates
,
groupID
,
sessionHash
,
preferOAuth
,
sessionUUID
);
ok
{
return
result
,
nil
}
}
else
{
...
...
@@ -1001,6 +1257,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
for
_
,
item
:=
range
available
{
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
item
.
account
.
ID
,
item
.
account
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
// 会话数量限制检查
if
!
s
.
checkAndRegisterSession
(
ctx
,
item
.
account
,
sessionUUID
)
{
result
.
ReleaseFunc
()
// 释放槽位,继续尝试下一个账号
continue
}
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
_
=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
item
.
account
.
ID
,
stickySessionTTL
)
}
...
...
@@ -1030,13 +1291,18 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return
nil
,
errors
.
New
(
"no available accounts"
)
}
func
(
s
*
GatewayService
)
tryAcquireByLegacyOrder
(
ctx
context
.
Context
,
candidates
[]
*
Account
,
groupID
*
int64
,
sessionHash
string
,
preferOAuth
bool
)
(
*
AccountSelectionResult
,
bool
)
{
func
(
s
*
GatewayService
)
tryAcquireByLegacyOrder
(
ctx
context
.
Context
,
candidates
[]
*
Account
,
groupID
*
int64
,
sessionHash
string
,
preferOAuth
bool
,
sessionUUID
string
)
(
*
AccountSelectionResult
,
bool
)
{
ordered
:=
append
([]
*
Account
(
nil
),
candidates
...
)
sortAccountsByPriorityAndLastUsed
(
ordered
,
preferOAuth
)
for
_
,
acc
:=
range
ordered
{
result
,
err
:=
s
.
tryAcquireAccountSlot
(
ctx
,
acc
.
ID
,
acc
.
Concurrency
)
if
err
==
nil
&&
result
.
Acquired
{
// 会话数量限制检查
if
!
s
.
checkAndRegisterSession
(
ctx
,
acc
,
sessionUUID
)
{
result
.
ReleaseFunc
()
// 释放槽位,继续尝试下一个账号
continue
}
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
_
=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
acc
.
ID
,
stickySessionTTL
)
}
...
...
@@ -1093,6 +1359,32 @@ func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (*
return
group
,
nil
}
func
(
s
*
GatewayService
)
routingAccountIDsForRequest
(
ctx
context
.
Context
,
groupID
*
int64
,
requestedModel
string
,
platform
string
)
[]
int64
{
if
groupID
==
nil
||
requestedModel
==
""
||
platform
!=
PlatformAnthropic
{
return
nil
}
group
,
err
:=
s
.
resolveGroupByID
(
ctx
,
*
groupID
)
if
err
!=
nil
||
group
==
nil
{
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] resolve group failed: group_id=%v model=%s platform=%s err=%v"
,
derefGroupID
(
groupID
),
requestedModel
,
platform
,
err
)
}
return
nil
}
// Preserve existing behavior: model routing only applies to anthropic groups.
if
group
.
Platform
!=
PlatformAnthropic
{
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] skip: non-anthropic group platform: group_id=%d group_platform=%s model=%s"
,
group
.
ID
,
group
.
Platform
,
requestedModel
)
}
return
nil
}
ids
:=
group
.
GetRoutingAccountIDs
(
requestedModel
)
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] routing lookup: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v"
,
group
.
ID
,
requestedModel
,
group
.
ModelRoutingEnabled
,
len
(
group
.
ModelRouting
),
ids
)
}
return
ids
}
func
(
s
*
GatewayService
)
resolveGatewayGroup
(
ctx
context
.
Context
,
groupID
*
int64
)
(
*
Group
,
*
int64
,
error
)
{
if
groupID
==
nil
{
return
nil
,
nil
,
nil
...
...
@@ -1242,6 +1534,107 @@ func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID in
return
s
.
concurrencyService
.
AcquireAccountSlot
(
ctx
,
accountID
,
maxConcurrency
)
}
// isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度
// 仅适用于 Anthropic OAuth/SetupToken 账号
// 返回 true 表示可调度,false 表示不可调度
func
(
s
*
GatewayService
)
isAccountSchedulableForWindowCost
(
ctx
context
.
Context
,
account
*
Account
,
isSticky
bool
)
bool
{
// 只检查 Anthropic OAuth/SetupToken 账号
if
!
account
.
IsAnthropicOAuthOrSetupToken
()
{
return
true
}
limit
:=
account
.
GetWindowCostLimit
()
if
limit
<=
0
{
return
true
// 未启用窗口费用限制
}
// 尝试从缓存获取窗口费用
var
currentCost
float64
if
s
.
sessionLimitCache
!=
nil
{
if
cost
,
hit
,
err
:=
s
.
sessionLimitCache
.
GetWindowCost
(
ctx
,
account
.
ID
);
err
==
nil
&&
hit
{
currentCost
=
cost
goto
checkSchedulability
}
}
// 缓存未命中,从数据库查询
{
var
startTime
time
.
Time
if
account
.
SessionWindowStart
!=
nil
{
startTime
=
*
account
.
SessionWindowStart
}
else
{
startTime
=
time
.
Now
()
.
Add
(
-
5
*
time
.
Hour
)
}
stats
,
err
:=
s
.
usageLogRepo
.
GetAccountWindowStats
(
ctx
,
account
.
ID
,
startTime
)
if
err
!=
nil
{
// 失败开放:查询失败时允许调度
return
true
}
// 使用标准费用(不含账号倍率)
currentCost
=
stats
.
StandardCost
// 设置缓存(忽略错误)
if
s
.
sessionLimitCache
!=
nil
{
_
=
s
.
sessionLimitCache
.
SetWindowCost
(
ctx
,
account
.
ID
,
currentCost
)
}
}
checkSchedulability
:
schedulability
:=
account
.
CheckWindowCostSchedulability
(
currentCost
)
switch
schedulability
{
case
WindowCostSchedulable
:
return
true
case
WindowCostStickyOnly
:
return
isSticky
case
WindowCostNotSchedulable
:
return
false
}
return
true
}
// checkAndRegisterSession 检查并注册会话,用于会话数量限制
// 仅适用于 Anthropic OAuth/SetupToken 账号
// 返回 true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话)
func
(
s
*
GatewayService
)
checkAndRegisterSession
(
ctx
context
.
Context
,
account
*
Account
,
sessionUUID
string
)
bool
{
// 只检查 Anthropic OAuth/SetupToken 账号
if
!
account
.
IsAnthropicOAuthOrSetupToken
()
{
return
true
}
maxSessions
:=
account
.
GetMaxSessions
()
if
maxSessions
<=
0
||
sessionUUID
==
""
{
return
true
// 未启用会话限制或无会话ID
}
if
s
.
sessionLimitCache
==
nil
{
return
true
// 缓存不可用时允许通过
}
idleTimeout
:=
time
.
Duration
(
account
.
GetSessionIdleTimeoutMinutes
())
*
time
.
Minute
allowed
,
err
:=
s
.
sessionLimitCache
.
RegisterSession
(
ctx
,
account
.
ID
,
sessionUUID
,
maxSessions
,
idleTimeout
)
if
err
!=
nil
{
// 失败开放:缓存错误时允许通过
return
true
}
return
allowed
}
// extractSessionUUID 从 metadata.user_id 中提取会话 UUID
// 格式: user_{64位hex}_account__session_{uuid}
func
extractSessionUUID
(
metadataUserID
string
)
string
{
if
metadataUserID
==
""
{
return
""
}
if
match
:=
sessionIDRegex
.
FindStringSubmatch
(
metadataUserID
);
len
(
match
)
>
1
{
return
match
[
1
]
}
return
""
}
func
(
s
*
GatewayService
)
getSchedulableAccount
(
ctx
context
.
Context
,
accountID
int64
)
(
*
Account
,
error
)
{
if
s
.
schedulerSnapshot
!=
nil
{
return
s
.
schedulerSnapshot
.
GetAccount
(
ctx
,
accountID
)
...
...
@@ -1274,6 +1667,116 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func
(
s
*
GatewayService
)
selectAccountForModelWithPlatform
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{},
platform
string
)
(
*
Account
,
error
)
{
preferOAuth
:=
platform
==
PlatformGemini
routingAccountIDs
:=
s
.
routingAccountIDsForRequest
(
ctx
,
groupID
,
requestedModel
,
platform
)
var
accounts
[]
Account
accountsLoaded
:=
false
// ============ Model Routing (legacy path): apply before sticky session ============
// When load-awareness is disabled (e.g. concurrency service not configured), we still honor model routing
// so switching model can switch upstream account within the same sticky session.
if
len
(
routingAccountIDs
)
>
0
{
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] legacy routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v"
,
derefGroupID
(
groupID
),
requestedModel
,
platform
,
shortSessionHash
(
sessionHash
),
routingAccountIDs
)
}
// 1) Sticky session only applies if the bound account is within the routing set.
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
if
err
==
nil
&&
accountID
>
0
&&
containsInt64
(
routingAccountIDs
,
accountID
)
{
if
_
,
excluded
:=
excludedIDs
[
accountID
];
!
excluded
{
account
,
err
:=
s
.
getSchedulableAccount
(
ctx
,
accountID
)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if
err
==
nil
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
.
Platform
==
platform
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
if
err
:=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
);
err
!=
nil
{
log
.
Printf
(
"refresh session ttl failed: session=%s err=%v"
,
sessionHash
,
err
)
}
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
accountID
)
}
return
account
,
nil
}
}
}
}
// 2) Select an account from the routed candidates.
forcePlatform
,
hasForcePlatform
:=
ctx
.
Value
(
ctxkey
.
ForcePlatform
)
.
(
string
)
if
hasForcePlatform
&&
forcePlatform
==
""
{
hasForcePlatform
=
false
}
var
err
error
accounts
,
_
,
err
=
s
.
listSchedulableAccounts
(
ctx
,
groupID
,
platform
,
hasForcePlatform
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
}
accountsLoaded
=
true
routingSet
:=
make
(
map
[
int64
]
struct
{},
len
(
routingAccountIDs
))
for
_
,
id
:=
range
routingAccountIDs
{
if
id
>
0
{
routingSet
[
id
]
=
struct
{}{}
}
}
var
selected
*
Account
for
i
:=
range
accounts
{
acc
:=
&
accounts
[
i
]
if
_
,
ok
:=
routingSet
[
acc
.
ID
];
!
ok
{
continue
}
if
_
,
excluded
:=
excludedIDs
[
acc
.
ID
];
excluded
{
continue
}
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
if
!
acc
.
IsSchedulable
()
{
continue
}
if
!
acc
.
IsSchedulableForModel
(
requestedModel
)
{
continue
}
if
requestedModel
!=
""
&&
!
s
.
isModelSupportedByAccount
(
acc
,
requestedModel
)
{
continue
}
if
selected
==
nil
{
selected
=
acc
continue
}
if
acc
.
Priority
<
selected
.
Priority
{
selected
=
acc
}
else
if
acc
.
Priority
==
selected
.
Priority
{
switch
{
case
acc
.
LastUsedAt
==
nil
&&
selected
.
LastUsedAt
!=
nil
:
selected
=
acc
case
acc
.
LastUsedAt
!=
nil
&&
selected
.
LastUsedAt
==
nil
:
// keep selected (never used is preferred)
case
acc
.
LastUsedAt
==
nil
&&
selected
.
LastUsedAt
==
nil
:
if
preferOAuth
&&
acc
.
Type
!=
selected
.
Type
&&
acc
.
Type
==
AccountTypeOAuth
{
selected
=
acc
}
default
:
if
acc
.
LastUsedAt
.
Before
(
*
selected
.
LastUsedAt
)
{
selected
=
acc
}
}
}
}
if
selected
!=
nil
{
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
if
err
:=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
selected
.
ID
,
stickySessionTTL
);
err
!=
nil
{
log
.
Printf
(
"set session account failed: session=%s account_id=%d err=%v"
,
sessionHash
,
selected
.
ID
,
err
)
}
}
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] legacy routed select: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
selected
.
ID
)
}
return
selected
,
nil
}
log
.
Printf
(
"[ModelRouting] No routed accounts available for model=%s, falling back to normal selection"
,
requestedModel
)
}
// 1. 查询粘性会话
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
...
...
@@ -1292,14 +1795,17 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
}
// 2. 获取可调度账号列表(单平台)
if
!
accountsLoaded
{
forcePlatform
,
hasForcePlatform
:=
ctx
.
Value
(
ctxkey
.
ForcePlatform
)
.
(
string
)
if
hasForcePlatform
&&
forcePlatform
==
""
{
hasForcePlatform
=
false
}
accounts
,
_
,
err
:=
s
.
listSchedulableAccounts
(
ctx
,
groupID
,
platform
,
hasForcePlatform
)
var
err
error
accounts
,
_
,
err
=
s
.
listSchedulableAccounts
(
ctx
,
groupID
,
platform
,
hasForcePlatform
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
}
}
// 3. 按优先级+最久未用选择(考虑模型支持)
var
selected
*
Account
...
...
@@ -1364,6 +1870,115 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
// 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户
func
(
s
*
GatewayService
)
selectAccountWithMixedScheduling
(
ctx
context
.
Context
,
groupID
*
int64
,
sessionHash
string
,
requestedModel
string
,
excludedIDs
map
[
int64
]
struct
{},
nativePlatform
string
)
(
*
Account
,
error
)
{
preferOAuth
:=
nativePlatform
==
PlatformGemini
routingAccountIDs
:=
s
.
routingAccountIDsForRequest
(
ctx
,
groupID
,
requestedModel
,
nativePlatform
)
var
accounts
[]
Account
accountsLoaded
:=
false
// ============ Model Routing (legacy path): apply before sticky session ============
if
len
(
routingAccountIDs
)
>
0
{
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] legacy mixed routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v"
,
derefGroupID
(
groupID
),
requestedModel
,
nativePlatform
,
shortSessionHash
(
sessionHash
),
routingAccountIDs
)
}
// 1) Sticky session only applies if the bound account is within the routing set.
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
accountID
,
err
:=
s
.
cache
.
GetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
)
if
err
==
nil
&&
accountID
>
0
&&
containsInt64
(
routingAccountIDs
,
accountID
)
{
if
_
,
excluded
:=
excludedIDs
[
accountID
];
!
excluded
{
account
,
err
:=
s
.
getSchedulableAccount
(
ctx
,
accountID
)
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
if
err
==
nil
&&
s
.
isAccountInGroup
(
account
,
groupID
)
&&
account
.
IsSchedulableForModel
(
requestedModel
)
&&
(
requestedModel
==
""
||
s
.
isModelSupportedByAccount
(
account
,
requestedModel
))
{
if
account
.
Platform
==
nativePlatform
||
(
account
.
Platform
==
PlatformAntigravity
&&
account
.
IsMixedSchedulingEnabled
())
{
if
err
:=
s
.
cache
.
RefreshSessionTTL
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
stickySessionTTL
);
err
!=
nil
{
log
.
Printf
(
"refresh session ttl failed: session=%s err=%v"
,
sessionHash
,
err
)
}
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
accountID
)
}
return
account
,
nil
}
}
}
}
}
// 2) Select an account from the routed candidates.
var
err
error
accounts
,
_
,
err
=
s
.
listSchedulableAccounts
(
ctx
,
groupID
,
nativePlatform
,
false
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
}
accountsLoaded
=
true
routingSet
:=
make
(
map
[
int64
]
struct
{},
len
(
routingAccountIDs
))
for
_
,
id
:=
range
routingAccountIDs
{
if
id
>
0
{
routingSet
[
id
]
=
struct
{}{}
}
}
var
selected
*
Account
for
i
:=
range
accounts
{
acc
:=
&
accounts
[
i
]
if
_
,
ok
:=
routingSet
[
acc
.
ID
];
!
ok
{
continue
}
if
_
,
excluded
:=
excludedIDs
[
acc
.
ID
];
excluded
{
continue
}
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
if
!
acc
.
IsSchedulable
()
{
continue
}
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
if
acc
.
Platform
==
PlatformAntigravity
&&
!
acc
.
IsMixedSchedulingEnabled
()
{
continue
}
if
!
acc
.
IsSchedulableForModel
(
requestedModel
)
{
continue
}
if
requestedModel
!=
""
&&
!
s
.
isModelSupportedByAccount
(
acc
,
requestedModel
)
{
continue
}
if
selected
==
nil
{
selected
=
acc
continue
}
if
acc
.
Priority
<
selected
.
Priority
{
selected
=
acc
}
else
if
acc
.
Priority
==
selected
.
Priority
{
switch
{
case
acc
.
LastUsedAt
==
nil
&&
selected
.
LastUsedAt
!=
nil
:
selected
=
acc
case
acc
.
LastUsedAt
!=
nil
&&
selected
.
LastUsedAt
==
nil
:
// keep selected (never used is preferred)
case
acc
.
LastUsedAt
==
nil
&&
selected
.
LastUsedAt
==
nil
:
if
preferOAuth
&&
acc
.
Platform
==
PlatformGemini
&&
selected
.
Platform
==
PlatformGemini
&&
acc
.
Type
!=
selected
.
Type
&&
acc
.
Type
==
AccountTypeOAuth
{
selected
=
acc
}
default
:
if
acc
.
LastUsedAt
.
Before
(
*
selected
.
LastUsedAt
)
{
selected
=
acc
}
}
}
}
if
selected
!=
nil
{
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
if
err
:=
s
.
cache
.
SetSessionAccountID
(
ctx
,
derefGroupID
(
groupID
),
sessionHash
,
selected
.
ID
,
stickySessionTTL
);
err
!=
nil
{
log
.
Printf
(
"set session account failed: session=%s account_id=%d err=%v"
,
sessionHash
,
selected
.
ID
,
err
)
}
}
if
s
.
debugModelRoutingEnabled
()
{
log
.
Printf
(
"[ModelRoutingDebug] legacy mixed routed select: group_id=%v model=%s session=%s account=%d"
,
derefGroupID
(
groupID
),
requestedModel
,
shortSessionHash
(
sessionHash
),
selected
.
ID
)
}
return
selected
,
nil
}
log
.
Printf
(
"[ModelRouting] No routed accounts available for model=%s, falling back to normal selection"
,
requestedModel
)
}
// 1. 查询粘性会话
if
sessionHash
!=
""
&&
s
.
cache
!=
nil
{
...
...
@@ -1385,10 +2000,13 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
// 2. 获取可调度账号列表
accounts
,
_
,
err
:=
s
.
listSchedulableAccounts
(
ctx
,
groupID
,
nativePlatform
,
false
)
if
!
accountsLoaded
{
var
err
error
accounts
,
_
,
err
=
s
.
listSchedulableAccounts
(
ctx
,
groupID
,
nativePlatform
,
false
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"query accounts failed: %w"
,
err
)
}
}
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
var
selected
*
Account
...
...
@@ -1488,6 +2106,16 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
}
func
(
s
*
GatewayService
)
getOAuthToken
(
ctx
context
.
Context
,
account
*
Account
)
(
string
,
string
,
error
)
{
// 对于 Anthropic OAuth 账号,使用 ClaudeTokenProvider 获取缓存的 token
if
account
.
Platform
==
PlatformAnthropic
&&
account
.
Type
==
AccountTypeOAuth
&&
s
.
claudeTokenProvider
!=
nil
{
accessToken
,
err
:=
s
.
claudeTokenProvider
.
GetAccessToken
(
ctx
,
account
)
if
err
!=
nil
{
return
""
,
""
,
err
}
return
accessToken
,
"oauth"
,
nil
}
// 其他情况(Gemini 有自己的 TokenProvider,setup-token 类型等)直接从账号读取
accessToken
:=
account
.
GetCredential
(
"access_token"
)
if
accessToken
==
""
{
return
""
,
""
,
errors
.
New
(
"access_token not found in credentials"
)
...
...
@@ -1901,6 +2529,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
retryStart
:=
time
.
Now
()
for
attempt
:=
1
;
attempt
<=
maxRetryAttempts
;
attempt
++
{
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
// Capture upstream request body for ops retry of this attempt.
c
.
Set
(
OpsUpstreamRequestBodyKey
,
string
(
body
))
upstreamReq
,
err
:=
s
.
buildUpstreamRequest
(
ctx
,
c
,
account
,
body
,
token
,
tokenType
,
reqModel
,
reqStream
,
shouldMimicClaudeCode
)
if
err
!=
nil
{
return
nil
,
err
...
...
@@ -1918,6 +2548,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
0
,
Kind
:
"request_error"
,
Message
:
safeErr
,
...
...
@@ -1942,6 +2573,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"signature_error"
,
...
...
@@ -1993,6 +2625,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
retryResp
.
StatusCode
,
UpstreamRequestID
:
retryResp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"signature_retry_thinking"
,
...
...
@@ -2021,6 +2654,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
0
,
Kind
:
"signature_retry_tools_request_error"
,
Message
:
sanitizeUpstreamErrorMessage
(
retryErr2
.
Error
()),
...
...
@@ -2079,6 +2713,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"retry"
,
...
...
@@ -2127,6 +2762,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"retry_exhausted_failover"
,
...
...
@@ -2193,6 +2829,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"failover_on_400"
,
...
...
@@ -3283,6 +3920,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
if
result
.
ImageSize
!=
""
{
imageSize
=
&
result
.
ImageSize
}
accountRateMultiplier
:=
account
.
BillingRateMultiplier
()
usageLog
:=
&
UsageLog
{
UserID
:
user
.
ID
,
APIKeyID
:
apiKey
.
ID
,
...
...
@@ -3300,6 +3938,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
TotalCost
:
cost
.
TotalCost
,
ActualCost
:
cost
.
ActualCost
,
RateMultiplier
:
multiplier
,
AccountRateMultiplier
:
&
accountRateMultiplier
,
BillingType
:
billingType
,
Stream
:
result
.
Stream
,
DurationMs
:
&
durationMs
,
...
...
Prev
1
2
3
4
5
6
7
8
…
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment