Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
b9b4db3d
Commit
b9b4db3d
authored
Jan 17, 2026
by
song
Browse files
Merge upstream/main
parents
5a6f60a9
dae0d532
Changes
237
Show whitespace changes
Inline
Side-by-side
Too many changes to show.
To preserve performance only
237 of 237+
files are displayed.
Plain diff
Email patch
backend/internal/server/router.go
View file @
b9b4db3d
package
server
import
(
"log"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
middleware2
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
...
...
@@ -9,6 +11,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/web"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
)
// SetupRouter 配置路由器中间件和路由
...
...
@@ -20,20 +23,31 @@ func SetupRouter(
apiKeyAuth
middleware2
.
APIKeyAuthMiddleware
,
apiKeyService
*
service
.
APIKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
opsService
*
service
.
OpsService
,
settingService
*
service
.
SettingService
,
cfg
*
config
.
Config
,
redisClient
*
redis
.
Client
,
)
*
gin
.
Engine
{
// 应用中间件
r
.
Use
(
middleware2
.
Logger
())
r
.
Use
(
middleware2
.
CORS
(
cfg
.
CORS
))
r
.
Use
(
middleware2
.
SecurityHeaders
(
cfg
.
Security
.
CSP
))
// Serve embedded frontend if available
// Serve embedded frontend
with settings injection
if available
if
web
.
HasEmbeddedFrontend
()
{
frontendServer
,
err
:=
web
.
NewFrontendServer
(
settingService
)
if
err
!=
nil
{
log
.
Printf
(
"Warning: Failed to create frontend server with settings injection: %v, using legacy mode"
,
err
)
r
.
Use
(
web
.
ServeEmbeddedFrontend
())
}
else
{
// Register cache invalidation callback
settingService
.
SetOnUpdateCallback
(
frontendServer
.
InvalidateCache
)
r
.
Use
(
frontendServer
.
Middleware
())
}
}
// 注册路由
registerRoutes
(
r
,
handlers
,
jwtAuth
,
adminAuth
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
cfg
)
registerRoutes
(
r
,
handlers
,
jwtAuth
,
adminAuth
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
opsService
,
cfg
,
redisClient
)
return
r
}
...
...
@@ -47,7 +61,9 @@ func registerRoutes(
apiKeyAuth
middleware2
.
APIKeyAuthMiddleware
,
apiKeyService
*
service
.
APIKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
opsService
*
service
.
OpsService
,
cfg
*
config
.
Config
,
redisClient
*
redis
.
Client
,
)
{
// 通用路由(健康检查、状态等)
routes
.
RegisterCommonRoutes
(
r
)
...
...
@@ -56,8 +72,8 @@ func registerRoutes(
v1
:=
r
.
Group
(
"/api/v1"
)
// 注册各模块路由
routes
.
RegisterAuthRoutes
(
v1
,
h
,
jwtAuth
)
routes
.
RegisterAuthRoutes
(
v1
,
h
,
jwtAuth
,
redisClient
)
routes
.
RegisterUserRoutes
(
v1
,
h
,
jwtAuth
)
routes
.
RegisterAdminRoutes
(
v1
,
h
,
adminAuth
)
routes
.
RegisterGatewayRoutes
(
r
,
h
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
cfg
)
routes
.
RegisterGatewayRoutes
(
r
,
h
,
apiKeyAuth
,
apiKeyService
,
subscriptionService
,
opsService
,
cfg
)
}
backend/internal/server/routes/admin.go
View file @
b9b4db3d
...
...
@@ -44,9 +44,15 @@ func RegisterAdminRoutes(
// 卡密管理
registerRedeemCodeRoutes
(
admin
,
h
)
// 优惠码管理
registerPromoCodeRoutes
(
admin
,
h
)
// 系统设置
registerSettingsRoutes
(
admin
,
h
)
// 运维监控(Ops)
registerOpsRoutes
(
admin
,
h
)
// 系统管理
registerSystemRoutes
(
admin
,
h
)
...
...
@@ -61,6 +67,85 @@ func RegisterAdminRoutes(
}
}
func
registerOpsRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
ops
:=
admin
.
Group
(
"/ops"
)
{
// Realtime ops signals
ops
.
GET
(
"/concurrency"
,
h
.
Admin
.
Ops
.
GetConcurrencyStats
)
ops
.
GET
(
"/account-availability"
,
h
.
Admin
.
Ops
.
GetAccountAvailability
)
ops
.
GET
(
"/realtime-traffic"
,
h
.
Admin
.
Ops
.
GetRealtimeTrafficSummary
)
// Alerts (rules + events)
ops
.
GET
(
"/alert-rules"
,
h
.
Admin
.
Ops
.
ListAlertRules
)
ops
.
POST
(
"/alert-rules"
,
h
.
Admin
.
Ops
.
CreateAlertRule
)
ops
.
PUT
(
"/alert-rules/:id"
,
h
.
Admin
.
Ops
.
UpdateAlertRule
)
ops
.
DELETE
(
"/alert-rules/:id"
,
h
.
Admin
.
Ops
.
DeleteAlertRule
)
ops
.
GET
(
"/alert-events"
,
h
.
Admin
.
Ops
.
ListAlertEvents
)
ops
.
GET
(
"/alert-events/:id"
,
h
.
Admin
.
Ops
.
GetAlertEvent
)
ops
.
PUT
(
"/alert-events/:id/status"
,
h
.
Admin
.
Ops
.
UpdateAlertEventStatus
)
ops
.
POST
(
"/alert-silences"
,
h
.
Admin
.
Ops
.
CreateAlertSilence
)
// Email notification config (DB-backed)
ops
.
GET
(
"/email-notification/config"
,
h
.
Admin
.
Ops
.
GetEmailNotificationConfig
)
ops
.
PUT
(
"/email-notification/config"
,
h
.
Admin
.
Ops
.
UpdateEmailNotificationConfig
)
// Runtime settings (DB-backed)
runtime
:=
ops
.
Group
(
"/runtime"
)
{
runtime
.
GET
(
"/alert"
,
h
.
Admin
.
Ops
.
GetAlertRuntimeSettings
)
runtime
.
PUT
(
"/alert"
,
h
.
Admin
.
Ops
.
UpdateAlertRuntimeSettings
)
}
// Advanced settings (DB-backed)
ops
.
GET
(
"/advanced-settings"
,
h
.
Admin
.
Ops
.
GetAdvancedSettings
)
ops
.
PUT
(
"/advanced-settings"
,
h
.
Admin
.
Ops
.
UpdateAdvancedSettings
)
// Settings group (DB-backed)
settings
:=
ops
.
Group
(
"/settings"
)
{
settings
.
GET
(
"/metric-thresholds"
,
h
.
Admin
.
Ops
.
GetMetricThresholds
)
settings
.
PUT
(
"/metric-thresholds"
,
h
.
Admin
.
Ops
.
UpdateMetricThresholds
)
}
// WebSocket realtime (QPS/TPS)
ws
:=
ops
.
Group
(
"/ws"
)
{
ws
.
GET
(
"/qps"
,
h
.
Admin
.
Ops
.
QPSWSHandler
)
}
// Error logs (legacy)
ops
.
GET
(
"/errors"
,
h
.
Admin
.
Ops
.
GetErrorLogs
)
ops
.
GET
(
"/errors/:id"
,
h
.
Admin
.
Ops
.
GetErrorLogByID
)
ops
.
GET
(
"/errors/:id/retries"
,
h
.
Admin
.
Ops
.
ListRetryAttempts
)
ops
.
POST
(
"/errors/:id/retry"
,
h
.
Admin
.
Ops
.
RetryErrorRequest
)
ops
.
PUT
(
"/errors/:id/resolve"
,
h
.
Admin
.
Ops
.
UpdateErrorResolution
)
// Request errors (client-visible failures)
ops
.
GET
(
"/request-errors"
,
h
.
Admin
.
Ops
.
ListRequestErrors
)
ops
.
GET
(
"/request-errors/:id"
,
h
.
Admin
.
Ops
.
GetRequestError
)
ops
.
GET
(
"/request-errors/:id/upstream-errors"
,
h
.
Admin
.
Ops
.
ListRequestErrorUpstreamErrors
)
ops
.
POST
(
"/request-errors/:id/retry-client"
,
h
.
Admin
.
Ops
.
RetryRequestErrorClient
)
ops
.
POST
(
"/request-errors/:id/upstream-errors/:idx/retry"
,
h
.
Admin
.
Ops
.
RetryRequestErrorUpstreamEvent
)
ops
.
PUT
(
"/request-errors/:id/resolve"
,
h
.
Admin
.
Ops
.
ResolveRequestError
)
// Upstream errors (independent upstream failures)
ops
.
GET
(
"/upstream-errors"
,
h
.
Admin
.
Ops
.
ListUpstreamErrors
)
ops
.
GET
(
"/upstream-errors/:id"
,
h
.
Admin
.
Ops
.
GetUpstreamError
)
ops
.
POST
(
"/upstream-errors/:id/retry"
,
h
.
Admin
.
Ops
.
RetryUpstreamError
)
ops
.
PUT
(
"/upstream-errors/:id/resolve"
,
h
.
Admin
.
Ops
.
ResolveUpstreamError
)
// Request drilldown (success + error)
ops
.
GET
(
"/requests"
,
h
.
Admin
.
Ops
.
ListRequestDetails
)
// Dashboard (vNext - raw path for MVP)
ops
.
GET
(
"/dashboard/overview"
,
h
.
Admin
.
Ops
.
GetDashboardOverview
)
ops
.
GET
(
"/dashboard/throughput-trend"
,
h
.
Admin
.
Ops
.
GetDashboardThroughputTrend
)
ops
.
GET
(
"/dashboard/latency-histogram"
,
h
.
Admin
.
Ops
.
GetDashboardLatencyHistogram
)
ops
.
GET
(
"/dashboard/error-trend"
,
h
.
Admin
.
Ops
.
GetDashboardErrorTrend
)
ops
.
GET
(
"/dashboard/error-distribution"
,
h
.
Admin
.
Ops
.
GetDashboardErrorDistribution
)
}
}
func
registerDashboardRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
dashboard
:=
admin
.
Group
(
"/dashboard"
)
{
...
...
@@ -72,6 +157,7 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
dashboard
.
GET
(
"/users-trend"
,
h
.
Admin
.
Dashboard
.
GetUserUsageTrend
)
dashboard
.
POST
(
"/users-usage"
,
h
.
Admin
.
Dashboard
.
GetBatchUsersUsage
)
dashboard
.
POST
(
"/api-keys-usage"
,
h
.
Admin
.
Dashboard
.
GetBatchAPIKeysUsage
)
dashboard
.
POST
(
"/aggregation/backfill"
,
h
.
Admin
.
Dashboard
.
BackfillAggregation
)
}
}
...
...
@@ -183,6 +269,7 @@ func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
proxies
.
POST
(
"/:id/test"
,
h
.
Admin
.
Proxy
.
Test
)
proxies
.
GET
(
"/:id/stats"
,
h
.
Admin
.
Proxy
.
GetStats
)
proxies
.
GET
(
"/:id/accounts"
,
h
.
Admin
.
Proxy
.
GetProxyAccounts
)
proxies
.
POST
(
"/batch-delete"
,
h
.
Admin
.
Proxy
.
BatchDelete
)
proxies
.
POST
(
"/batch"
,
h
.
Admin
.
Proxy
.
BatchCreate
)
}
}
...
...
@@ -201,6 +288,18 @@ func registerRedeemCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
func
registerPromoCodeRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
promoCodes
:=
admin
.
Group
(
"/promo-codes"
)
{
promoCodes
.
GET
(
""
,
h
.
Admin
.
Promo
.
List
)
promoCodes
.
GET
(
"/:id"
,
h
.
Admin
.
Promo
.
GetByID
)
promoCodes
.
POST
(
""
,
h
.
Admin
.
Promo
.
Create
)
promoCodes
.
PUT
(
"/:id"
,
h
.
Admin
.
Promo
.
Update
)
promoCodes
.
DELETE
(
"/:id"
,
h
.
Admin
.
Promo
.
Delete
)
promoCodes
.
GET
(
"/:id/usages"
,
h
.
Admin
.
Promo
.
GetUsages
)
}
}
func
registerSettingsRoutes
(
admin
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
)
{
adminSettings
:=
admin
.
Group
(
"/settings"
)
{
...
...
@@ -212,6 +311,9 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
adminSettings
.
GET
(
"/admin-api-key"
,
h
.
Admin
.
Setting
.
GetAdminAPIKey
)
adminSettings
.
POST
(
"/admin-api-key/regenerate"
,
h
.
Admin
.
Setting
.
RegenerateAdminAPIKey
)
adminSettings
.
DELETE
(
"/admin-api-key"
,
h
.
Admin
.
Setting
.
DeleteAdminAPIKey
)
// 流超时处理配置
adminSettings
.
GET
(
"/stream-timeout"
,
h
.
Admin
.
Setting
.
GetStreamTimeoutSettings
)
adminSettings
.
PUT
(
"/stream-timeout"
,
h
.
Admin
.
Setting
.
UpdateStreamTimeoutSettings
)
}
}
...
...
backend/internal/server/routes/auth.go
View file @
b9b4db3d
package
routes
import
(
"time"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/middleware"
servermiddleware
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
)
// RegisterAuthRoutes 注册认证相关路由
func
RegisterAuthRoutes
(
v1
*
gin
.
RouterGroup
,
h
*
handler
.
Handlers
,
jwtAuth
middleware
.
JWTAuthMiddleware
,
jwtAuth
servermiddleware
.
JWTAuthMiddleware
,
redisClient
*
redis
.
Client
,
)
{
// 创建速率限制器
rateLimiter
:=
middleware
.
NewRateLimiter
(
redisClient
)
// 公开接口
auth
:=
v1
.
Group
(
"/auth"
)
{
auth
.
POST
(
"/register"
,
h
.
Auth
.
Register
)
auth
.
POST
(
"/login"
,
h
.
Auth
.
Login
)
auth
.
POST
(
"/send-verify-code"
,
h
.
Auth
.
SendVerifyCode
)
// 优惠码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
auth
.
POST
(
"/validate-promo-code"
,
rateLimiter
.
LimitWithOptions
(
"validate-promo"
,
10
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
FailureMode
:
middleware
.
RateLimitFailClose
,
}),
h
.
Auth
.
ValidatePromoCode
)
auth
.
GET
(
"/oauth/linuxdo/start"
,
h
.
Auth
.
LinuxDoOAuthStart
)
auth
.
GET
(
"/oauth/linuxdo/callback"
,
h
.
Auth
.
LinuxDoOAuthCallback
)
}
...
...
backend/internal/server/routes/gateway.go
View file @
b9b4db3d
...
...
@@ -16,13 +16,18 @@ func RegisterGatewayRoutes(
apiKeyAuth
middleware
.
APIKeyAuthMiddleware
,
apiKeyService
*
service
.
APIKeyService
,
subscriptionService
*
service
.
SubscriptionService
,
opsService
*
service
.
OpsService
,
cfg
*
config
.
Config
,
)
{
bodyLimit
:=
middleware
.
RequestBodyLimit
(
cfg
.
Gateway
.
MaxBodySize
)
clientRequestID
:=
middleware
.
ClientRequestID
()
opsErrorLogger
:=
handler
.
OpsErrorLoggerMiddleware
(
opsService
)
// API网关(Claude API兼容)
gateway
:=
r
.
Group
(
"/v1"
)
gateway
.
Use
(
bodyLimit
)
gateway
.
Use
(
clientRequestID
)
gateway
.
Use
(
opsErrorLogger
)
gateway
.
Use
(
gin
.
HandlerFunc
(
apiKeyAuth
))
{
gateway
.
POST
(
"/messages"
,
h
.
Gateway
.
Messages
)
...
...
@@ -36,6 +41,8 @@ func RegisterGatewayRoutes(
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
gemini
:=
r
.
Group
(
"/v1beta"
)
gemini
.
Use
(
bodyLimit
)
gemini
.
Use
(
clientRequestID
)
gemini
.
Use
(
opsErrorLogger
)
gemini
.
Use
(
middleware
.
APIKeyAuthWithSubscriptionGoogle
(
apiKeyService
,
subscriptionService
,
cfg
))
{
gemini
.
GET
(
"/models"
,
h
.
Gateway
.
GeminiV1BetaListModels
)
...
...
@@ -45,7 +52,7 @@ func RegisterGatewayRoutes(
}
// OpenAI Responses API(不带v1前缀的别名)
r
.
POST
(
"/responses"
,
bodyLimit
,
gin
.
HandlerFunc
(
apiKeyAuth
),
h
.
OpenAIGateway
.
Responses
)
r
.
POST
(
"/responses"
,
bodyLimit
,
clientRequestID
,
opsErrorLogger
,
gin
.
HandlerFunc
(
apiKeyAuth
),
h
.
OpenAIGateway
.
Responses
)
// Antigravity 模型列表
r
.
GET
(
"/antigravity/models"
,
gin
.
HandlerFunc
(
apiKeyAuth
),
h
.
Gateway
.
AntigravityModels
)
...
...
@@ -53,6 +60,8 @@ func RegisterGatewayRoutes(
// Antigravity 专用路由(仅使用 antigravity 账户,不混合调度)
antigravityV1
:=
r
.
Group
(
"/antigravity/v1"
)
antigravityV1
.
Use
(
bodyLimit
)
antigravityV1
.
Use
(
clientRequestID
)
antigravityV1
.
Use
(
opsErrorLogger
)
antigravityV1
.
Use
(
middleware
.
ForcePlatform
(
service
.
PlatformAntigravity
))
antigravityV1
.
Use
(
gin
.
HandlerFunc
(
apiKeyAuth
))
{
...
...
@@ -64,6 +73,8 @@ func RegisterGatewayRoutes(
antigravityV1Beta
:=
r
.
Group
(
"/antigravity/v1beta"
)
antigravityV1Beta
.
Use
(
bodyLimit
)
antigravityV1Beta
.
Use
(
clientRequestID
)
antigravityV1Beta
.
Use
(
opsErrorLogger
)
antigravityV1Beta
.
Use
(
middleware
.
ForcePlatform
(
service
.
PlatformAntigravity
))
antigravityV1Beta
.
Use
(
middleware
.
APIKeyAuthWithSubscriptionGoogle
(
apiKeyService
,
subscriptionService
,
cfg
))
{
...
...
backend/internal/service/account.go
View file @
b9b4db3d
...
...
@@ -19,6 +19,9 @@ type Account struct {
ProxyID
*
int64
Concurrency
int
Priority
int
// RateMultiplier 账号计费倍率(>=0,允许 0 表示该账号计费为 0)。
// 使用指针用于兼容旧版本调度缓存(Redis)中缺字段的情况:nil 表示按 1.0 处理。
RateMultiplier
*
float64
Status
string
ErrorMessage
string
LastUsedAt
*
time
.
Time
...
...
@@ -57,6 +60,20 @@ func (a *Account) IsActive() bool {
return
a
.
Status
==
StatusActive
}
// BillingRateMultiplier 返回账号计费倍率。
// - nil 表示未配置/旧缓存缺字段,按 1.0 处理
// - 允许 0,表示该账号计费为 0
// - 负数属于非法数据,出于安全考虑按 1.0 处理
func
(
a
*
Account
)
BillingRateMultiplier
()
float64
{
if
a
==
nil
||
a
.
RateMultiplier
==
nil
{
return
1.0
}
if
*
a
.
RateMultiplier
<
0
{
return
1.0
}
return
*
a
.
RateMultiplier
}
func
(
a
*
Account
)
IsSchedulable
()
bool
{
if
!
a
.
IsActive
()
||
!
a
.
Schedulable
{
return
false
...
...
@@ -540,3 +557,141 @@ func (a *Account) IsMixedSchedulingEnabled() bool {
}
return
false
}
// WindowCostSchedulability 窗口费用调度状态
type
WindowCostSchedulability
int
const
(
// WindowCostSchedulable 可正常调度
WindowCostSchedulable
WindowCostSchedulability
=
iota
// WindowCostStickyOnly 仅允许粘性会话
WindowCostStickyOnly
// WindowCostNotSchedulable 完全不可调度
WindowCostNotSchedulable
)
// IsAnthropicOAuthOrSetupToken 判断是否为 Anthropic OAuth 或 SetupToken 类型账号
// 仅这两类账号支持 5h 窗口额度控制和会话数量控制
func
(
a
*
Account
)
IsAnthropicOAuthOrSetupToken
()
bool
{
return
a
.
Platform
==
PlatformAnthropic
&&
(
a
.
Type
==
AccountTypeOAuth
||
a
.
Type
==
AccountTypeSetupToken
)
}
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
// 返回 0 表示未启用
func
(
a
*
Account
)
GetWindowCostLimit
()
float64
{
if
a
.
Extra
==
nil
{
return
0
}
if
v
,
ok
:=
a
.
Extra
[
"window_cost_limit"
];
ok
{
return
parseExtraFloat64
(
v
)
}
return
0
}
// GetWindowCostStickyReserve 获取粘性会话预留额度(美元)
// 默认值为 10
func
(
a
*
Account
)
GetWindowCostStickyReserve
()
float64
{
if
a
.
Extra
==
nil
{
return
10.0
}
if
v
,
ok
:=
a
.
Extra
[
"window_cost_sticky_reserve"
];
ok
{
val
:=
parseExtraFloat64
(
v
)
if
val
>
0
{
return
val
}
}
return
10.0
}
// GetMaxSessions 获取最大并发会话数
// 返回 0 表示未启用
func
(
a
*
Account
)
GetMaxSessions
()
int
{
if
a
.
Extra
==
nil
{
return
0
}
if
v
,
ok
:=
a
.
Extra
[
"max_sessions"
];
ok
{
return
parseExtraInt
(
v
)
}
return
0
}
// GetSessionIdleTimeoutMinutes 获取会话空闲超时分钟数
// 默认值为 5 分钟
func
(
a
*
Account
)
GetSessionIdleTimeoutMinutes
()
int
{
if
a
.
Extra
==
nil
{
return
5
}
if
v
,
ok
:=
a
.
Extra
[
"session_idle_timeout_minutes"
];
ok
{
val
:=
parseExtraInt
(
v
)
if
val
>
0
{
return
val
}
}
return
5
}
// CheckWindowCostSchedulability 根据当前窗口费用检查调度状态
// - 费用 < 阈值: WindowCostSchedulable(可正常调度)
// - 费用 >= 阈值 且 < 阈值+预留: WindowCostStickyOnly(仅粘性会话)
// - 费用 >= 阈值+预留: WindowCostNotSchedulable(不可调度)
func
(
a
*
Account
)
CheckWindowCostSchedulability
(
currentWindowCost
float64
)
WindowCostSchedulability
{
limit
:=
a
.
GetWindowCostLimit
()
if
limit
<=
0
{
return
WindowCostSchedulable
}
if
currentWindowCost
<
limit
{
return
WindowCostSchedulable
}
stickyReserve
:=
a
.
GetWindowCostStickyReserve
()
if
currentWindowCost
<
limit
+
stickyReserve
{
return
WindowCostStickyOnly
}
return
WindowCostNotSchedulable
}
// parseExtraFloat64 从 extra 字段解析 float64 值
func
parseExtraFloat64
(
value
any
)
float64
{
switch
v
:=
value
.
(
type
)
{
case
float64
:
return
v
case
float32
:
return
float64
(
v
)
case
int
:
return
float64
(
v
)
case
int64
:
return
float64
(
v
)
case
json
.
Number
:
if
f
,
err
:=
v
.
Float64
();
err
==
nil
{
return
f
}
case
string
:
if
f
,
err
:=
strconv
.
ParseFloat
(
strings
.
TrimSpace
(
v
),
64
);
err
==
nil
{
return
f
}
}
return
0
}
// parseExtraInt 从 extra 字段解析 int 值
func
parseExtraInt
(
value
any
)
int
{
switch
v
:=
value
.
(
type
)
{
case
int
:
return
v
case
int64
:
return
int
(
v
)
case
float64
:
return
int
(
v
)
case
json
.
Number
:
if
i
,
err
:=
v
.
Int64
();
err
==
nil
{
return
int
(
i
)
}
case
string
:
if
i
,
err
:=
strconv
.
Atoi
(
strings
.
TrimSpace
(
v
));
err
==
nil
{
return
i
}
}
return
0
}
backend/internal/service/account_billing_rate_multiplier_test.go
0 → 100644
View file @
b9b4db3d
package
service
import
(
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func
TestAccount_BillingRateMultiplier_DefaultsToOneWhenNil
(
t
*
testing
.
T
)
{
var
a
Account
require
.
NoError
(
t
,
json
.
Unmarshal
([]
byte
(
`{"id":1,"name":"acc","status":"active"}`
),
&
a
))
require
.
Nil
(
t
,
a
.
RateMultiplier
)
require
.
Equal
(
t
,
1.0
,
a
.
BillingRateMultiplier
())
}
func
TestAccount_BillingRateMultiplier_AllowsZero
(
t
*
testing
.
T
)
{
v
:=
0.0
a
:=
Account
{
RateMultiplier
:
&
v
}
require
.
Equal
(
t
,
0.0
,
a
.
BillingRateMultiplier
())
}
func
TestAccount_BillingRateMultiplier_NegativeFallsBackToOne
(
t
*
testing
.
T
)
{
v
:=
-
1.0
a
:=
Account
{
RateMultiplier
:
&
v
}
require
.
Equal
(
t
,
1.0
,
a
.
BillingRateMultiplier
())
}
backend/internal/service/account_service.go
View file @
b9b4db3d
...
...
@@ -51,11 +51,13 @@ type AccountRepository interface {
SetRateLimited
(
ctx
context
.
Context
,
id
int64
,
resetAt
time
.
Time
)
error
SetAntigravityQuotaScopeLimit
(
ctx
context
.
Context
,
id
int64
,
scope
AntigravityQuotaScope
,
resetAt
time
.
Time
)
error
SetModelRateLimit
(
ctx
context
.
Context
,
id
int64
,
scope
string
,
resetAt
time
.
Time
)
error
SetOverloaded
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
)
error
SetTempUnschedulable
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
,
reason
string
)
error
ClearTempUnschedulable
(
ctx
context
.
Context
,
id
int64
)
error
ClearRateLimit
(
ctx
context
.
Context
,
id
int64
)
error
ClearAntigravityQuotaScopes
(
ctx
context
.
Context
,
id
int64
)
error
ClearModelRateLimits
(
ctx
context
.
Context
,
id
int64
)
error
UpdateSessionWindow
(
ctx
context
.
Context
,
id
int64
,
start
,
end
*
time
.
Time
,
status
string
)
error
UpdateExtra
(
ctx
context
.
Context
,
id
int64
,
updates
map
[
string
]
any
)
error
BulkUpdate
(
ctx
context
.
Context
,
ids
[]
int64
,
updates
AccountBulkUpdate
)
(
int64
,
error
)
...
...
@@ -68,6 +70,7 @@ type AccountBulkUpdate struct {
ProxyID
*
int64
Concurrency
*
int
Priority
*
int
RateMultiplier
*
float64
Status
*
string
Schedulable
*
bool
Credentials
map
[
string
]
any
...
...
backend/internal/service/account_service_delete_test.go
View file @
b9b4db3d
...
...
@@ -147,6 +147,10 @@ func (s *accountRepoStub) SetAntigravityQuotaScopeLimit(ctx context.Context, id
panic
(
"unexpected SetAntigravityQuotaScopeLimit call"
)
}
func
(
s
*
accountRepoStub
)
SetModelRateLimit
(
ctx
context
.
Context
,
id
int64
,
scope
string
,
resetAt
time
.
Time
)
error
{
panic
(
"unexpected SetModelRateLimit call"
)
}
func
(
s
*
accountRepoStub
)
SetOverloaded
(
ctx
context
.
Context
,
id
int64
,
until
time
.
Time
)
error
{
panic
(
"unexpected SetOverloaded call"
)
}
...
...
@@ -167,6 +171,10 @@ func (s *accountRepoStub) ClearAntigravityQuotaScopes(ctx context.Context, id in
panic
(
"unexpected ClearAntigravityQuotaScopes call"
)
}
func
(
s
*
accountRepoStub
)
ClearModelRateLimits
(
ctx
context
.
Context
,
id
int64
)
error
{
panic
(
"unexpected ClearModelRateLimits call"
)
}
func
(
s
*
accountRepoStub
)
UpdateSessionWindow
(
ctx
context
.
Context
,
id
int64
,
start
,
end
*
time
.
Time
,
status
string
)
error
{
panic
(
"unexpected UpdateSessionWindow call"
)
}
...
...
backend/internal/service/account_usage_service.go
View file @
b9b4db3d
...
...
@@ -32,8 +32,8 @@ type UsageLogRepository interface {
// Admin dashboard stats
GetDashboardStats
(
ctx
context
.
Context
)
(
*
usagestats
.
DashboardStats
,
error
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
int64
)
([]
usagestats
.
TrendDataPoint
,
error
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
int64
)
([]
usagestats
.
ModelStat
,
error
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
stream
*
bool
)
([]
usagestats
.
TrendDataPoint
,
error
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
stream
*
bool
)
([]
usagestats
.
ModelStat
,
error
)
GetAPIKeyUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
([]
usagestats
.
APIKeyUsageTrendPoint
,
error
)
GetUserUsageTrend
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
limit
int
)
([]
usagestats
.
UserUsageTrendPoint
,
error
)
GetBatchUserUsageStats
(
ctx
context
.
Context
,
userIDs
[]
int64
)
(
map
[
int64
]
*
usagestats
.
BatchUserUsageStats
,
error
)
...
...
@@ -96,10 +96,16 @@ func NewUsageCache() *UsageCache {
}
// WindowStats 窗口期统计
//
// cost: 账号口径费用(total_cost * account_rate_multiplier)
// standard_cost: 标准费用(total_cost,不含倍率)
// user_cost: 用户/API Key 口径费用(actual_cost,受分组倍率影响)
type
WindowStats
struct
{
Requests
int64
`json:"requests"`
Tokens
int64
`json:"tokens"`
Cost
float64
`json:"cost"`
StandardCost
float64
`json:"standard_cost"`
UserCost
float64
`json:"user_cost"`
}
// UsageProgress 使用量进度
...
...
@@ -266,7 +272,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
}
dayStart
:=
geminiDailyWindowStart
(
now
)
stats
,
err
:=
s
.
usageLogRepo
.
GetModelStatsWithFilters
(
ctx
,
dayStart
,
now
,
0
,
0
,
account
.
ID
)
stats
,
err
:=
s
.
usageLogRepo
.
GetModelStatsWithFilters
(
ctx
,
dayStart
,
now
,
0
,
0
,
account
.
ID
,
0
,
nil
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get gemini usage stats failed: %w"
,
err
)
}
...
...
@@ -288,7 +294,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
minuteStart
:=
now
.
Truncate
(
time
.
Minute
)
minuteResetAt
:=
minuteStart
.
Add
(
time
.
Minute
)
minuteStats
,
err
:=
s
.
usageLogRepo
.
GetModelStatsWithFilters
(
ctx
,
minuteStart
,
now
,
0
,
0
,
account
.
ID
)
minuteStats
,
err
:=
s
.
usageLogRepo
.
GetModelStatsWithFilters
(
ctx
,
minuteStart
,
now
,
0
,
0
,
account
.
ID
,
0
,
nil
)
if
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"get gemini minute usage stats failed: %w"
,
err
)
}
...
...
@@ -380,6 +386,8 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
Requests
:
stats
.
Requests
,
Tokens
:
stats
.
Tokens
,
Cost
:
stats
.
Cost
,
StandardCost
:
stats
.
StandardCost
,
UserCost
:
stats
.
UserCost
,
}
// 缓存窗口统计(1 分钟)
...
...
@@ -406,6 +414,8 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64
Requests
:
stats
.
Requests
,
Tokens
:
stats
.
Tokens
,
Cost
:
stats
.
Cost
,
StandardCost
:
stats
.
StandardCost
,
UserCost
:
stats
.
UserCost
,
},
nil
}
...
...
@@ -565,3 +575,9 @@ func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64
},
}
}
// GetAccountWindowStats 获取账号在指定时间窗口内的使用统计
// 用于账号列表页面显示当前窗口费用
func
(
s
*
AccountUsageService
)
GetAccountWindowStats
(
ctx
context
.
Context
,
accountID
int64
,
startTime
time
.
Time
)
(
*
usagestats
.
AccountStats
,
error
)
{
return
s
.
usageLogRepo
.
GetAccountWindowStats
(
ctx
,
accountID
,
startTime
)
}
backend/internal/service/admin_service.go
View file @
b9b4db3d
...
...
@@ -55,7 +55,8 @@ type AdminService interface {
CreateProxy
(
ctx
context
.
Context
,
input
*
CreateProxyInput
)
(
*
Proxy
,
error
)
UpdateProxy
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateProxyInput
)
(
*
Proxy
,
error
)
DeleteProxy
(
ctx
context
.
Context
,
id
int64
)
error
GetProxyAccounts
(
ctx
context
.
Context
,
proxyID
int64
,
page
,
pageSize
int
)
([]
Account
,
int64
,
error
)
BatchDeleteProxies
(
ctx
context
.
Context
,
ids
[]
int64
)
(
*
ProxyBatchDeleteResult
,
error
)
GetProxyAccounts
(
ctx
context
.
Context
,
proxyID
int64
)
([]
ProxyAccountSummary
,
error
)
CheckProxyExists
(
ctx
context
.
Context
,
host
string
,
port
int
,
username
,
password
string
)
(
bool
,
error
)
TestProxy
(
ctx
context
.
Context
,
id
int64
)
(
*
ProxyTestResult
,
error
)
...
...
@@ -106,6 +107,9 @@ type CreateGroupInput struct {
ImagePrice4K
*
float64
ClaudeCodeOnly
bool
// 仅允许 Claude Code 客户端
FallbackGroupID
*
int64
// 降级分组 ID
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting
map
[
string
][]
int64
ModelRoutingEnabled
bool
// 是否启用模型路由
}
type
UpdateGroupInput
struct
{
...
...
@@ -125,6 +129,9 @@ type UpdateGroupInput struct {
ImagePrice4K
*
float64
ClaudeCodeOnly
*
bool
// 仅允许 Claude Code 客户端
FallbackGroupID
*
int64
// 降级分组 ID
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting
map
[
string
][]
int64
ModelRoutingEnabled
*
bool
// 是否启用模型路由
}
type
CreateAccountInput
struct
{
...
...
@@ -137,6 +144,7 @@ type CreateAccountInput struct {
ProxyID
*
int64
Concurrency
int
Priority
int
RateMultiplier
*
float64
// 账号计费倍率(>=0,允许 0)
GroupIDs
[]
int64
ExpiresAt
*
int64
AutoPauseOnExpired
*
bool
...
...
@@ -154,6 +162,7 @@ type UpdateAccountInput struct {
ProxyID
*
int64
Concurrency
*
int
// 使用指针区分"未提供"和"设置为0"
Priority
*
int
// 使用指针区分"未提供"和"设置为0"
RateMultiplier
*
float64
// 账号计费倍率(>=0,允许 0)
Status
string
GroupIDs
*
[]
int64
ExpiresAt
*
int64
...
...
@@ -168,6 +177,7 @@ type BulkUpdateAccountsInput struct {
ProxyID
*
int64
Concurrency
*
int
Priority
*
int
RateMultiplier
*
float64
// 账号计费倍率(>=0,允许 0)
Status
string
Schedulable
*
bool
GroupIDs
*
[]
int64
...
...
@@ -189,6 +199,8 @@ type BulkUpdateAccountResult struct {
type
BulkUpdateAccountsResult
struct
{
Success
int
`json:"success"`
Failed
int
`json:"failed"`
SuccessIDs
[]
int64
`json:"success_ids"`
FailedIDs
[]
int64
`json:"failed_ids"`
Results
[]
BulkUpdateAccountResult
`json:"results"`
}
...
...
@@ -219,6 +231,16 @@ type GenerateRedeemCodesInput struct {
ValidityDays
int
// 订阅类型专用:有效天数
}
type
ProxyBatchDeleteResult
struct
{
DeletedIDs
[]
int64
`json:"deleted_ids"`
Skipped
[]
ProxyBatchDeleteSkipped
`json:"skipped"`
}
type
ProxyBatchDeleteSkipped
struct
{
ID
int64
`json:"id"`
Reason
string
`json:"reason"`
}
// ProxyTestResult represents the result of testing a proxy
type
ProxyTestResult
struct
{
Success
bool
`json:"success"`
...
...
@@ -228,14 +250,16 @@ type ProxyTestResult struct {
City
string
`json:"city,omitempty"`
Region
string
`json:"region,omitempty"`
Country
string
`json:"country,omitempty"`
CountryCode
string
`json:"country_code,omitempty"`
}
// ProxyExitInfo represents proxy exit information from ip
info.io
// ProxyExitInfo represents proxy exit information from ip
-api.com
type
ProxyExitInfo
struct
{
IP
string
City
string
Region
string
Country
string
CountryCode
string
}
// ProxyExitInfoProber tests proxy connectivity and retrieves exit information
...
...
@@ -253,6 +277,8 @@ type adminServiceImpl struct {
redeemCodeRepo
RedeemCodeRepository
billingCacheService
*
BillingCacheService
proxyProber
ProxyExitInfoProber
proxyLatencyCache
ProxyLatencyCache
authCacheInvalidator
APIKeyAuthCacheInvalidator
}
// NewAdminService creates a new AdminService
...
...
@@ -265,6 +291,8 @@ func NewAdminService(
redeemCodeRepo
RedeemCodeRepository
,
billingCacheService
*
BillingCacheService
,
proxyProber
ProxyExitInfoProber
,
proxyLatencyCache
ProxyLatencyCache
,
authCacheInvalidator
APIKeyAuthCacheInvalidator
,
)
AdminService
{
return
&
adminServiceImpl
{
userRepo
:
userRepo
,
...
...
@@ -275,6 +303,8 @@ func NewAdminService(
redeemCodeRepo
:
redeemCodeRepo
,
billingCacheService
:
billingCacheService
,
proxyProber
:
proxyProber
,
proxyLatencyCache
:
proxyLatencyCache
,
authCacheInvalidator
:
authCacheInvalidator
,
}
}
...
...
@@ -324,6 +354,8 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
}
oldConcurrency
:=
user
.
Concurrency
oldStatus
:=
user
.
Status
oldRole
:=
user
.
Role
if
input
.
Email
!=
""
{
user
.
Email
=
input
.
Email
...
...
@@ -356,6 +388,11 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
if
err
:=
s
.
userRepo
.
Update
(
ctx
,
user
);
err
!=
nil
{
return
nil
,
err
}
if
s
.
authCacheInvalidator
!=
nil
{
if
user
.
Concurrency
!=
oldConcurrency
||
user
.
Status
!=
oldStatus
||
user
.
Role
!=
oldRole
{
s
.
authCacheInvalidator
.
InvalidateAuthCacheByUserID
(
ctx
,
user
.
ID
)
}
}
concurrencyDiff
:=
user
.
Concurrency
-
oldConcurrency
if
concurrencyDiff
!=
0
{
...
...
@@ -394,6 +431,9 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
log
.
Printf
(
"delete user failed: user_id=%d err=%v"
,
id
,
err
)
return
err
}
if
s
.
authCacheInvalidator
!=
nil
{
s
.
authCacheInvalidator
.
InvalidateAuthCacheByUserID
(
ctx
,
id
)
}
return
nil
}
...
...
@@ -421,6 +461,10 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
if
err
:=
s
.
userRepo
.
Update
(
ctx
,
user
);
err
!=
nil
{
return
nil
,
err
}
balanceDiff
:=
user
.
Balance
-
oldBalance
if
s
.
authCacheInvalidator
!=
nil
&&
balanceDiff
!=
0
{
s
.
authCacheInvalidator
.
InvalidateAuthCacheByUserID
(
ctx
,
userID
)
}
if
s
.
billingCacheService
!=
nil
{
go
func
()
{
...
...
@@ -432,7 +476,6 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
}()
}
balanceDiff
:=
user
.
Balance
-
oldBalance
if
balanceDiff
!=
0
{
code
,
err
:=
GenerateRedeemCode
()
if
err
!=
nil
{
...
...
@@ -545,6 +588,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
ImagePrice4K
:
imagePrice4K
,
ClaudeCodeOnly
:
input
.
ClaudeCodeOnly
,
FallbackGroupID
:
input
.
FallbackGroupID
,
ModelRouting
:
input
.
ModelRouting
,
}
if
err
:=
s
.
groupRepo
.
Create
(
ctx
,
group
);
err
!=
nil
{
return
nil
,
err
...
...
@@ -577,18 +621,33 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
return
fmt
.
Errorf
(
"cannot set self as fallback group"
)
}
visited
:=
map
[
int64
]
struct
{}{}
nextID
:=
fallbackGroupID
for
{
if
_
,
seen
:=
visited
[
nextID
];
seen
{
return
fmt
.
Errorf
(
"fallback group cycle detected"
)
}
visited
[
nextID
]
=
struct
{}{}
if
currentGroupID
>
0
&&
nextID
==
currentGroupID
{
return
fmt
.
Errorf
(
"fallback group cycle detected"
)
}
// 检查降级分组是否存在
fallbackGroup
,
err
:=
s
.
groupRepo
.
GetByID
(
ctx
,
fallbackGroup
ID
)
fallbackGroup
,
err
:=
s
.
groupRepo
.
GetByID
Lite
(
ctx
,
next
ID
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"fallback group not found: %w"
,
err
)
}
// 降级分组不能启用 claude_code_only,否则会造成死循环
if
fallbackGroup
.
ClaudeCodeOnly
{
if
nextID
==
fallbackGroupID
&&
fallbackGroup
.
ClaudeCodeOnly
{
return
fmt
.
Errorf
(
"fallback group cannot have claude_code_only enabled"
)
}
if
fallbackGroup
.
FallbackGroupID
==
nil
{
return
nil
}
nextID
=
*
fallbackGroup
.
FallbackGroupID
}
}
func
(
s
*
adminServiceImpl
)
UpdateGroup
(
ctx
context
.
Context
,
id
int64
,
input
*
UpdateGroupInput
)
(
*
Group
,
error
)
{
...
...
@@ -658,13 +717,32 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
}
}
// 模型路由配置
if
input
.
ModelRouting
!=
nil
{
group
.
ModelRouting
=
input
.
ModelRouting
}
if
input
.
ModelRoutingEnabled
!=
nil
{
group
.
ModelRoutingEnabled
=
*
input
.
ModelRoutingEnabled
}
if
err
:=
s
.
groupRepo
.
Update
(
ctx
,
group
);
err
!=
nil
{
return
nil
,
err
}
if
s
.
authCacheInvalidator
!=
nil
{
s
.
authCacheInvalidator
.
InvalidateAuthCacheByGroupID
(
ctx
,
id
)
}
return
group
,
nil
}
func
(
s
*
adminServiceImpl
)
DeleteGroup
(
ctx
context
.
Context
,
id
int64
)
error
{
var
groupKeys
[]
string
if
s
.
authCacheInvalidator
!=
nil
{
keys
,
err
:=
s
.
apiKeyRepo
.
ListKeysByGroupID
(
ctx
,
id
)
if
err
==
nil
{
groupKeys
=
keys
}
}
affectedUserIDs
,
err
:=
s
.
groupRepo
.
DeleteCascade
(
ctx
,
id
)
if
err
!=
nil
{
return
err
...
...
@@ -683,6 +761,11 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
}
}()
}
if
s
.
authCacheInvalidator
!=
nil
{
for
_
,
key
:=
range
groupKeys
{
s
.
authCacheInvalidator
.
InvalidateAuthCacheByKey
(
ctx
,
key
)
}
}
return
nil
}
...
...
@@ -769,6 +852,12 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
}
else
{
account
.
AutoPauseOnExpired
=
true
}
if
input
.
RateMultiplier
!=
nil
{
if
*
input
.
RateMultiplier
<
0
{
return
nil
,
errors
.
New
(
"rate_multiplier must be >= 0"
)
}
account
.
RateMultiplier
=
input
.
RateMultiplier
}
if
err
:=
s
.
accountRepo
.
Create
(
ctx
,
account
);
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -821,6 +910,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
if
input
.
Priority
!=
nil
{
account
.
Priority
=
*
input
.
Priority
}
if
input
.
RateMultiplier
!=
nil
{
if
*
input
.
RateMultiplier
<
0
{
return
nil
,
errors
.
New
(
"rate_multiplier must be >= 0"
)
}
account
.
RateMultiplier
=
input
.
RateMultiplier
}
if
input
.
Status
!=
""
{
account
.
Status
=
input
.
Status
}
...
...
@@ -871,6 +966,8 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
// It merges credentials/extra keys instead of overwriting the whole object.
func
(
s
*
adminServiceImpl
)
BulkUpdateAccounts
(
ctx
context
.
Context
,
input
*
BulkUpdateAccountsInput
)
(
*
BulkUpdateAccountsResult
,
error
)
{
result
:=
&
BulkUpdateAccountsResult
{
SuccessIDs
:
make
([]
int64
,
0
,
len
(
input
.
AccountIDs
)),
FailedIDs
:
make
([]
int64
,
0
,
len
(
input
.
AccountIDs
)),
Results
:
make
([]
BulkUpdateAccountResult
,
0
,
len
(
input
.
AccountIDs
)),
}
...
...
@@ -892,6 +989,12 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
}
}
if
input
.
RateMultiplier
!=
nil
{
if
*
input
.
RateMultiplier
<
0
{
return
nil
,
errors
.
New
(
"rate_multiplier must be >= 0"
)
}
}
// Prepare bulk updates for columns and JSONB fields.
repoUpdates
:=
AccountBulkUpdate
{
Credentials
:
input
.
Credentials
,
...
...
@@ -909,6 +1012,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
if
input
.
Priority
!=
nil
{
repoUpdates
.
Priority
=
input
.
Priority
}
if
input
.
RateMultiplier
!=
nil
{
repoUpdates
.
RateMultiplier
=
input
.
RateMultiplier
}
if
input
.
Status
!=
""
{
repoUpdates
.
Status
=
&
input
.
Status
}
...
...
@@ -935,6 +1041,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
entry
.
Success
=
false
entry
.
Error
=
err
.
Error
()
result
.
Failed
++
result
.
FailedIDs
=
append
(
result
.
FailedIDs
,
accountID
)
result
.
Results
=
append
(
result
.
Results
,
entry
)
continue
}
...
...
@@ -944,6 +1051,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
entry
.
Success
=
false
entry
.
Error
=
err
.
Error
()
result
.
Failed
++
result
.
FailedIDs
=
append
(
result
.
FailedIDs
,
accountID
)
result
.
Results
=
append
(
result
.
Results
,
entry
)
continue
}
...
...
@@ -953,6 +1061,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
entry
.
Success
=
false
entry
.
Error
=
err
.
Error
()
result
.
Failed
++
result
.
FailedIDs
=
append
(
result
.
FailedIDs
,
accountID
)
result
.
Results
=
append
(
result
.
Results
,
entry
)
continue
}
...
...
@@ -960,6 +1069,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
entry
.
Success
=
true
result
.
Success
++
result
.
SuccessIDs
=
append
(
result
.
SuccessIDs
,
accountID
)
result
.
Results
=
append
(
result
.
Results
,
entry
)
}
...
...
@@ -1019,6 +1129,7 @@ func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page
if
err
!=
nil
{
return
nil
,
0
,
err
}
s
.
attachProxyLatency
(
ctx
,
proxies
)
return
proxies
,
result
.
Total
,
nil
}
...
...
@@ -1027,7 +1138,12 @@ func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) {
}
func
(
s
*
adminServiceImpl
)
GetAllProxiesWithAccountCount
(
ctx
context
.
Context
)
([]
ProxyWithAccountCount
,
error
)
{
return
s
.
proxyRepo
.
ListActiveWithAccountCount
(
ctx
)
proxies
,
err
:=
s
.
proxyRepo
.
ListActiveWithAccountCount
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
s
.
attachProxyLatency
(
ctx
,
proxies
)
return
proxies
,
nil
}
func
(
s
*
adminServiceImpl
)
GetProxy
(
ctx
context
.
Context
,
id
int64
)
(
*
Proxy
,
error
)
{
...
...
@@ -1047,6 +1163,8 @@ func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyIn
if
err
:=
s
.
proxyRepo
.
Create
(
ctx
,
proxy
);
err
!=
nil
{
return
nil
,
err
}
// Probe latency asynchronously so creation isn't blocked by network timeout.
go
s
.
probeProxyLatency
(
context
.
Background
(),
proxy
)
return
proxy
,
nil
}
...
...
@@ -1085,12 +1203,53 @@ func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *Upd
}
func
(
s
*
adminServiceImpl
)
DeleteProxy
(
ctx
context
.
Context
,
id
int64
)
error
{
count
,
err
:=
s
.
proxyRepo
.
CountAccountsByProxyID
(
ctx
,
id
)
if
err
!=
nil
{
return
err
}
if
count
>
0
{
return
ErrProxyInUse
}
return
s
.
proxyRepo
.
Delete
(
ctx
,
id
)
}
func
(
s
*
adminServiceImpl
)
GetProxyAccounts
(
ctx
context
.
Context
,
proxyID
int64
,
page
,
pageSize
int
)
([]
Account
,
int64
,
error
)
{
// Return mock data for now - would need a dedicated repository method
return
[]
Account
{},
0
,
nil
func
(
s
*
adminServiceImpl
)
BatchDeleteProxies
(
ctx
context
.
Context
,
ids
[]
int64
)
(
*
ProxyBatchDeleteResult
,
error
)
{
result
:=
&
ProxyBatchDeleteResult
{}
if
len
(
ids
)
==
0
{
return
result
,
nil
}
for
_
,
id
:=
range
ids
{
count
,
err
:=
s
.
proxyRepo
.
CountAccountsByProxyID
(
ctx
,
id
)
if
err
!=
nil
{
result
.
Skipped
=
append
(
result
.
Skipped
,
ProxyBatchDeleteSkipped
{
ID
:
id
,
Reason
:
err
.
Error
(),
})
continue
}
if
count
>
0
{
result
.
Skipped
=
append
(
result
.
Skipped
,
ProxyBatchDeleteSkipped
{
ID
:
id
,
Reason
:
ErrProxyInUse
.
Error
(),
})
continue
}
if
err
:=
s
.
proxyRepo
.
Delete
(
ctx
,
id
);
err
!=
nil
{
result
.
Skipped
=
append
(
result
.
Skipped
,
ProxyBatchDeleteSkipped
{
ID
:
id
,
Reason
:
err
.
Error
(),
})
continue
}
result
.
DeletedIDs
=
append
(
result
.
DeletedIDs
,
id
)
}
return
result
,
nil
}
func
(
s
*
adminServiceImpl
)
GetProxyAccounts
(
ctx
context
.
Context
,
proxyID
int64
)
([]
ProxyAccountSummary
,
error
)
{
return
s
.
proxyRepo
.
ListAccountSummariesByProxyID
(
ctx
,
proxyID
)
}
func
(
s
*
adminServiceImpl
)
CheckProxyExists
(
ctx
context
.
Context
,
host
string
,
port
int
,
username
,
password
string
)
(
bool
,
error
)
{
...
...
@@ -1190,12 +1349,29 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
proxyURL
:=
proxy
.
URL
()
exitInfo
,
latencyMs
,
err
:=
s
.
proxyProber
.
ProbeProxy
(
ctx
,
proxyURL
)
if
err
!=
nil
{
s
.
saveProxyLatency
(
ctx
,
id
,
&
ProxyLatencyInfo
{
Success
:
false
,
Message
:
err
.
Error
(),
UpdatedAt
:
time
.
Now
(),
})
return
&
ProxyTestResult
{
Success
:
false
,
Message
:
err
.
Error
(),
},
nil
}
latency
:=
latencyMs
s
.
saveProxyLatency
(
ctx
,
id
,
&
ProxyLatencyInfo
{
Success
:
true
,
LatencyMs
:
&
latency
,
Message
:
"Proxy is accessible"
,
IPAddress
:
exitInfo
.
IP
,
Country
:
exitInfo
.
Country
,
CountryCode
:
exitInfo
.
CountryCode
,
Region
:
exitInfo
.
Region
,
City
:
exitInfo
.
City
,
UpdatedAt
:
time
.
Now
(),
})
return
&
ProxyTestResult
{
Success
:
true
,
Message
:
"Proxy is accessible"
,
...
...
@@ -1204,9 +1380,38 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
City
:
exitInfo
.
City
,
Region
:
exitInfo
.
Region
,
Country
:
exitInfo
.
Country
,
CountryCode
:
exitInfo
.
CountryCode
,
},
nil
}
func
(
s
*
adminServiceImpl
)
probeProxyLatency
(
ctx
context
.
Context
,
proxy
*
Proxy
)
{
if
s
.
proxyProber
==
nil
||
proxy
==
nil
{
return
}
exitInfo
,
latencyMs
,
err
:=
s
.
proxyProber
.
ProbeProxy
(
ctx
,
proxy
.
URL
())
if
err
!=
nil
{
s
.
saveProxyLatency
(
ctx
,
proxy
.
ID
,
&
ProxyLatencyInfo
{
Success
:
false
,
Message
:
err
.
Error
(),
UpdatedAt
:
time
.
Now
(),
})
return
}
latency
:=
latencyMs
s
.
saveProxyLatency
(
ctx
,
proxy
.
ID
,
&
ProxyLatencyInfo
{
Success
:
true
,
LatencyMs
:
&
latency
,
Message
:
"Proxy is accessible"
,
IPAddress
:
exitInfo
.
IP
,
Country
:
exitInfo
.
Country
,
CountryCode
:
exitInfo
.
CountryCode
,
Region
:
exitInfo
.
Region
,
City
:
exitInfo
.
City
,
UpdatedAt
:
time
.
Now
(),
})
}
// checkMixedChannelRisk 检查分组中是否存在混合渠道(Antigravity + Anthropic)
// 如果存在混合,返回错误提示用户确认
func
(
s
*
adminServiceImpl
)
checkMixedChannelRisk
(
ctx
context
.
Context
,
currentAccountID
int64
,
currentAccountPlatform
string
,
groupIDs
[]
int64
)
error
{
...
...
@@ -1256,6 +1461,51 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc
return
nil
}
func
(
s
*
adminServiceImpl
)
attachProxyLatency
(
ctx
context
.
Context
,
proxies
[]
ProxyWithAccountCount
)
{
if
s
.
proxyLatencyCache
==
nil
||
len
(
proxies
)
==
0
{
return
}
ids
:=
make
([]
int64
,
0
,
len
(
proxies
))
for
i
:=
range
proxies
{
ids
=
append
(
ids
,
proxies
[
i
]
.
ID
)
}
latencies
,
err
:=
s
.
proxyLatencyCache
.
GetProxyLatencies
(
ctx
,
ids
)
if
err
!=
nil
{
log
.
Printf
(
"Warning: load proxy latency cache failed: %v"
,
err
)
return
}
for
i
:=
range
proxies
{
info
:=
latencies
[
proxies
[
i
]
.
ID
]
if
info
==
nil
{
continue
}
if
info
.
Success
{
proxies
[
i
]
.
LatencyStatus
=
"success"
proxies
[
i
]
.
LatencyMs
=
info
.
LatencyMs
}
else
{
proxies
[
i
]
.
LatencyStatus
=
"failed"
}
proxies
[
i
]
.
LatencyMessage
=
info
.
Message
proxies
[
i
]
.
IPAddress
=
info
.
IPAddress
proxies
[
i
]
.
Country
=
info
.
Country
proxies
[
i
]
.
CountryCode
=
info
.
CountryCode
proxies
[
i
]
.
Region
=
info
.
Region
proxies
[
i
]
.
City
=
info
.
City
}
}
func
(
s
*
adminServiceImpl
)
saveProxyLatency
(
ctx
context
.
Context
,
proxyID
int64
,
info
*
ProxyLatencyInfo
)
{
if
s
.
proxyLatencyCache
==
nil
||
info
==
nil
{
return
}
if
err
:=
s
.
proxyLatencyCache
.
SetProxyLatency
(
ctx
,
proxyID
,
info
);
err
!=
nil
{
log
.
Printf
(
"Warning: store proxy latency cache failed: %v"
,
err
)
}
}
// getAccountPlatform 根据账号 platform 判断混合渠道检查用的平台标识
func
getAccountPlatform
(
accountPlatform
string
)
string
{
switch
strings
.
ToLower
(
strings
.
TrimSpace
(
accountPlatform
))
{
...
...
backend/internal/service/admin_service_bulk_update_test.go
0 → 100644
View file @
b9b4db3d
//go:build unit
package
service
import
(
"context"
"errors"
"testing"
"github.com/stretchr/testify/require"
)
type
accountRepoStubForBulkUpdate
struct
{
accountRepoStub
bulkUpdateErr
error
bulkUpdateIDs
[]
int64
bindGroupErrByID
map
[
int64
]
error
}
func
(
s
*
accountRepoStubForBulkUpdate
)
BulkUpdate
(
_
context
.
Context
,
ids
[]
int64
,
_
AccountBulkUpdate
)
(
int64
,
error
)
{
s
.
bulkUpdateIDs
=
append
([]
int64
{},
ids
...
)
if
s
.
bulkUpdateErr
!=
nil
{
return
0
,
s
.
bulkUpdateErr
}
return
int64
(
len
(
ids
)),
nil
}
func
(
s
*
accountRepoStubForBulkUpdate
)
BindGroups
(
_
context
.
Context
,
accountID
int64
,
_
[]
int64
)
error
{
if
err
,
ok
:=
s
.
bindGroupErrByID
[
accountID
];
ok
{
return
err
}
return
nil
}
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
func
TestAdminService_BulkUpdateAccounts_AllSuccessIDs
(
t
*
testing
.
T
)
{
repo
:=
&
accountRepoStubForBulkUpdate
{}
svc
:=
&
adminServiceImpl
{
accountRepo
:
repo
}
schedulable
:=
true
input
:=
&
BulkUpdateAccountsInput
{
AccountIDs
:
[]
int64
{
1
,
2
,
3
},
Schedulable
:
&
schedulable
,
}
result
,
err
:=
svc
.
BulkUpdateAccounts
(
context
.
Background
(),
input
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
3
,
result
.
Success
)
require
.
Equal
(
t
,
0
,
result
.
Failed
)
require
.
ElementsMatch
(
t
,
[]
int64
{
1
,
2
,
3
},
result
.
SuccessIDs
)
require
.
Empty
(
t
,
result
.
FailedIDs
)
require
.
Len
(
t
,
result
.
Results
,
3
)
}
// TestAdminService_BulkUpdateAccounts_PartialFailureIDs 验证部分失败时 success_ids/failed_ids 正确。
func
TestAdminService_BulkUpdateAccounts_PartialFailureIDs
(
t
*
testing
.
T
)
{
repo
:=
&
accountRepoStubForBulkUpdate
{
bindGroupErrByID
:
map
[
int64
]
error
{
2
:
errors
.
New
(
"bind failed"
),
},
}
svc
:=
&
adminServiceImpl
{
accountRepo
:
repo
}
groupIDs
:=
[]
int64
{
10
}
schedulable
:=
false
input
:=
&
BulkUpdateAccountsInput
{
AccountIDs
:
[]
int64
{
1
,
2
,
3
},
GroupIDs
:
&
groupIDs
,
Schedulable
:
&
schedulable
,
SkipMixedChannelCheck
:
true
,
}
result
,
err
:=
svc
.
BulkUpdateAccounts
(
context
.
Background
(),
input
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
2
,
result
.
Success
)
require
.
Equal
(
t
,
1
,
result
.
Failed
)
require
.
ElementsMatch
(
t
,
[]
int64
{
1
,
3
},
result
.
SuccessIDs
)
require
.
ElementsMatch
(
t
,
[]
int64
{
2
},
result
.
FailedIDs
)
require
.
Len
(
t
,
result
.
Results
,
3
)
}
backend/internal/service/admin_service_delete_test.go
View file @
b9b4db3d
...
...
@@ -107,6 +107,10 @@ func (s *groupRepoStub) GetByID(ctx context.Context, id int64) (*Group, error) {
panic
(
"unexpected GetByID call"
)
}
func
(
s
*
groupRepoStub
)
GetByIDLite
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
panic
(
"unexpected GetByIDLite call"
)
}
func
(
s
*
groupRepoStub
)
Update
(
ctx
context
.
Context
,
group
*
Group
)
error
{
panic
(
"unexpected Update call"
)
}
...
...
@@ -150,6 +154,8 @@ func (s *groupRepoStub) DeleteAccountGroupsByGroupID(ctx context.Context, groupI
type
proxyRepoStub
struct
{
deleteErr
error
countErr
error
accountCount
int64
deletedIDs
[]
int64
}
...
...
@@ -195,7 +201,14 @@ func (s *proxyRepoStub) ExistsByHostPortAuth(ctx context.Context, host string, p
}
func
(
s
*
proxyRepoStub
)
CountAccountsByProxyID
(
ctx
context
.
Context
,
proxyID
int64
)
(
int64
,
error
)
{
panic
(
"unexpected CountAccountsByProxyID call"
)
if
s
.
countErr
!=
nil
{
return
0
,
s
.
countErr
}
return
s
.
accountCount
,
nil
}
func
(
s
*
proxyRepoStub
)
ListAccountSummariesByProxyID
(
ctx
context
.
Context
,
proxyID
int64
)
([]
ProxyAccountSummary
,
error
)
{
panic
(
"unexpected ListAccountSummariesByProxyID call"
)
}
type
redeemRepoStub
struct
{
...
...
@@ -405,6 +418,15 @@ func TestAdminService_DeleteProxy_Idempotent(t *testing.T) {
require
.
Equal
(
t
,
[]
int64
{
404
},
repo
.
deletedIDs
)
}
func
TestAdminService_DeleteProxy_InUse
(
t
*
testing
.
T
)
{
repo
:=
&
proxyRepoStub
{
accountCount
:
2
}
svc
:=
&
adminServiceImpl
{
proxyRepo
:
repo
}
err
:=
svc
.
DeleteProxy
(
context
.
Background
(),
77
)
require
.
ErrorIs
(
t
,
err
,
ErrProxyInUse
)
require
.
Empty
(
t
,
repo
.
deletedIDs
)
}
func
TestAdminService_DeleteProxy_Error
(
t
*
testing
.
T
)
{
deleteErr
:=
errors
.
New
(
"delete failed"
)
repo
:=
&
proxyRepoStub
{
deleteErr
:
deleteErr
}
...
...
backend/internal/service/admin_service_group_test.go
View file @
b9b4db3d
...
...
@@ -45,6 +45,13 @@ func (s *groupRepoStubForAdmin) GetByID(_ context.Context, _ int64) (*Group, err
return
s
.
getByID
,
nil
}
func
(
s
*
groupRepoStubForAdmin
)
GetByIDLite
(
_
context
.
Context
,
_
int64
)
(
*
Group
,
error
)
{
if
s
.
getErr
!=
nil
{
return
nil
,
s
.
getErr
}
return
s
.
getByID
,
nil
}
func
(
s
*
groupRepoStubForAdmin
)
Delete
(
_
context
.
Context
,
_
int64
)
error
{
panic
(
"unexpected Delete call"
)
}
...
...
@@ -290,3 +297,84 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
require
.
True
(
t
,
*
repo
.
listWithFiltersIsExclusive
)
})
}
func
TestAdminService_ValidateFallbackGroup_DetectsCycle
(
t
*
testing
.
T
)
{
groupID
:=
int64
(
1
)
fallbackID
:=
int64
(
2
)
repo
:=
&
groupRepoStubForFallbackCycle
{
groups
:
map
[
int64
]
*
Group
{
groupID
:
{
ID
:
groupID
,
FallbackGroupID
:
&
fallbackID
,
},
fallbackID
:
{
ID
:
fallbackID
,
FallbackGroupID
:
&
groupID
,
},
},
}
svc
:=
&
adminServiceImpl
{
groupRepo
:
repo
}
err
:=
svc
.
validateFallbackGroup
(
context
.
Background
(),
groupID
,
fallbackID
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"fallback group cycle"
)
}
type
groupRepoStubForFallbackCycle
struct
{
groups
map
[
int64
]
*
Group
}
func
(
s
*
groupRepoStubForFallbackCycle
)
Create
(
_
context
.
Context
,
_
*
Group
)
error
{
panic
(
"unexpected Create call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
Update
(
_
context
.
Context
,
_
*
Group
)
error
{
panic
(
"unexpected Update call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
return
s
.
GetByIDLite
(
ctx
,
id
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
GetByIDLite
(
_
context
.
Context
,
id
int64
)
(
*
Group
,
error
)
{
if
g
,
ok
:=
s
.
groups
[
id
];
ok
{
return
g
,
nil
}
return
nil
,
ErrGroupNotFound
}
func
(
s
*
groupRepoStubForFallbackCycle
)
Delete
(
_
context
.
Context
,
_
int64
)
error
{
panic
(
"unexpected Delete call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
DeleteCascade
(
_
context
.
Context
,
_
int64
)
([]
int64
,
error
)
{
panic
(
"unexpected DeleteCascade call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
List
(
_
context
.
Context
,
_
pagination
.
PaginationParams
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected List call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
ListWithFilters
(
_
context
.
Context
,
_
pagination
.
PaginationParams
,
_
,
_
,
_
string
,
_
*
bool
)
([]
Group
,
*
pagination
.
PaginationResult
,
error
)
{
panic
(
"unexpected ListWithFilters call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
ListActive
(
_
context
.
Context
)
([]
Group
,
error
)
{
panic
(
"unexpected ListActive call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
ListActiveByPlatform
(
_
context
.
Context
,
_
string
)
([]
Group
,
error
)
{
panic
(
"unexpected ListActiveByPlatform call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
ExistsByName
(
_
context
.
Context
,
_
string
)
(
bool
,
error
)
{
panic
(
"unexpected ExistsByName call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
GetAccountCount
(
_
context
.
Context
,
_
int64
)
(
int64
,
error
)
{
panic
(
"unexpected GetAccountCount call"
)
}
func
(
s
*
groupRepoStubForFallbackCycle
)
DeleteAccountGroupsByGroupID
(
_
context
.
Context
,
_
int64
)
(
int64
,
error
)
{
panic
(
"unexpected DeleteAccountGroupsByGroupID call"
)
}
backend/internal/service/admin_service_update_balance_test.go
0 → 100644
View file @
b9b4db3d
//go:build unit
package
service
import
(
"context"
"testing"
"github.com/stretchr/testify/require"
)
type
balanceUserRepoStub
struct
{
*
userRepoStub
updateErr
error
updated
[]
*
User
}
func
(
s
*
balanceUserRepoStub
)
Update
(
ctx
context
.
Context
,
user
*
User
)
error
{
if
s
.
updateErr
!=
nil
{
return
s
.
updateErr
}
if
user
==
nil
{
return
nil
}
clone
:=
*
user
s
.
updated
=
append
(
s
.
updated
,
&
clone
)
if
s
.
userRepoStub
!=
nil
{
s
.
userRepoStub
.
user
=
&
clone
}
return
nil
}
type
balanceRedeemRepoStub
struct
{
*
redeemRepoStub
created
[]
*
RedeemCode
}
func
(
s
*
balanceRedeemRepoStub
)
Create
(
ctx
context
.
Context
,
code
*
RedeemCode
)
error
{
if
code
==
nil
{
return
nil
}
clone
:=
*
code
s
.
created
=
append
(
s
.
created
,
&
clone
)
return
nil
}
type
authCacheInvalidatorStub
struct
{
userIDs
[]
int64
groupIDs
[]
int64
keys
[]
string
}
func
(
s
*
authCacheInvalidatorStub
)
InvalidateAuthCacheByKey
(
ctx
context
.
Context
,
key
string
)
{
s
.
keys
=
append
(
s
.
keys
,
key
)
}
func
(
s
*
authCacheInvalidatorStub
)
InvalidateAuthCacheByUserID
(
ctx
context
.
Context
,
userID
int64
)
{
s
.
userIDs
=
append
(
s
.
userIDs
,
userID
)
}
func
(
s
*
authCacheInvalidatorStub
)
InvalidateAuthCacheByGroupID
(
ctx
context
.
Context
,
groupID
int64
)
{
s
.
groupIDs
=
append
(
s
.
groupIDs
,
groupID
)
}
func
TestAdminService_UpdateUserBalance_InvalidatesAuthCache
(
t
*
testing
.
T
)
{
baseRepo
:=
&
userRepoStub
{
user
:
&
User
{
ID
:
7
,
Balance
:
10
}}
repo
:=
&
balanceUserRepoStub
{
userRepoStub
:
baseRepo
}
redeemRepo
:=
&
balanceRedeemRepoStub
{
redeemRepoStub
:
&
redeemRepoStub
{}}
invalidator
:=
&
authCacheInvalidatorStub
{}
svc
:=
&
adminServiceImpl
{
userRepo
:
repo
,
redeemCodeRepo
:
redeemRepo
,
authCacheInvalidator
:
invalidator
,
}
_
,
err
:=
svc
.
UpdateUserBalance
(
context
.
Background
(),
7
,
5
,
"add"
,
""
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
[]
int64
{
7
},
invalidator
.
userIDs
)
require
.
Len
(
t
,
redeemRepo
.
created
,
1
)
}
func
TestAdminService_UpdateUserBalance_NoChangeNoInvalidate
(
t
*
testing
.
T
)
{
baseRepo
:=
&
userRepoStub
{
user
:
&
User
{
ID
:
7
,
Balance
:
10
}}
repo
:=
&
balanceUserRepoStub
{
userRepoStub
:
baseRepo
}
redeemRepo
:=
&
balanceRedeemRepoStub
{
redeemRepoStub
:
&
redeemRepoStub
{}}
invalidator
:=
&
authCacheInvalidatorStub
{}
svc
:=
&
adminServiceImpl
{
userRepo
:
repo
,
redeemCodeRepo
:
redeemRepo
,
authCacheInvalidator
:
invalidator
,
}
_
,
err
:=
svc
.
UpdateUserBalance
(
context
.
Background
(),
7
,
10
,
"set"
,
""
)
require
.
NoError
(
t
,
err
)
require
.
Empty
(
t
,
invalidator
.
userIDs
)
require
.
Empty
(
t
,
redeemRepo
.
created
)
}
backend/internal/service/antigravity_gateway_service.go
View file @
b9b4db3d
...
...
@@ -12,6 +12,7 @@ import (
mathrand
"math/rand"
"net"
"net/http"
"os"
"strings"
"sync/atomic"
"time"
...
...
@@ -28,6 +29,8 @@ const (
antigravityRetryMaxDelay
=
16
*
time
.
Second
)
const
antigravityScopeRateLimitEnv
=
"GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT"
// antigravityRetryLoopParams 重试循环的参数
type
antigravityRetryLoopParams
struct
{
ctx
context
.
Context
...
...
@@ -38,7 +41,9 @@ type antigravityRetryLoopParams struct {
action
string
body
[]
byte
quotaScope
AntigravityQuotaScope
c
*
gin
.
Context
httpUpstream
HTTPUpstream
settingService
*
SettingService
handleError
func
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
)
}
...
...
@@ -56,6 +61,17 @@ func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopRe
var
resp
*
http
.
Response
var
usedBaseURL
string
logBody
:=
p
.
settingService
!=
nil
&&
p
.
settingService
.
cfg
!=
nil
&&
p
.
settingService
.
cfg
.
Gateway
.
LogUpstreamErrorBody
maxBytes
:=
2048
if
p
.
settingService
!=
nil
&&
p
.
settingService
.
cfg
!=
nil
&&
p
.
settingService
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
>
0
{
maxBytes
=
p
.
settingService
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
}
getUpstreamDetail
:=
func
(
body
[]
byte
)
string
{
if
!
logBody
{
return
""
}
return
truncateString
(
string
(
body
),
maxBytes
)
}
urlFallbackLoop
:
for
urlIdx
,
baseURL
:=
range
availableURLs
{
...
...
@@ -73,8 +89,22 @@ urlFallbackLoop:
return
nil
,
err
}
// Capture upstream request body for ops retry of this attempt.
if
p
.
c
!=
nil
&&
len
(
p
.
body
)
>
0
{
p
.
c
.
Set
(
OpsUpstreamRequestBodyKey
,
string
(
p
.
body
))
}
resp
,
err
=
p
.
httpUpstream
.
Do
(
upstreamReq
,
p
.
proxyURL
,
p
.
account
.
ID
,
p
.
account
.
Concurrency
)
if
err
!=
nil
{
safeErr
:=
sanitizeUpstreamErrorMessage
(
err
.
Error
())
appendOpsUpstreamError
(
p
.
c
,
OpsUpstreamErrorEvent
{
Platform
:
p
.
account
.
Platform
,
AccountID
:
p
.
account
.
ID
,
AccountName
:
p
.
account
.
Name
,
UpstreamStatusCode
:
0
,
Kind
:
"request_error"
,
Message
:
safeErr
,
})
if
shouldAntigravityFallbackToNextURL
(
err
,
0
)
&&
urlIdx
<
len
(
availableURLs
)
-
1
{
antigravity
.
DefaultURLAvailability
.
MarkUnavailable
(
baseURL
)
log
.
Printf
(
"%s URL fallback (connection error): %s -> %s"
,
p
.
prefix
,
baseURL
,
availableURLs
[
urlIdx
+
1
])
...
...
@@ -89,6 +119,7 @@ urlFallbackLoop:
continue
}
log
.
Printf
(
"%s status=request_failed retries_exhausted error=%v"
,
p
.
prefix
,
err
)
setOpsUpstreamError
(
p
.
c
,
0
,
safeErr
,
""
)
return
nil
,
fmt
.
Errorf
(
"upstream request failed after retries: %w"
,
err
)
}
...
...
@@ -99,13 +130,37 @@ urlFallbackLoop:
// "Resource has been exhausted" 是 URL 级别限流,切换 URL
if
isURLLevelRateLimit
(
respBody
)
&&
urlIdx
<
len
(
availableURLs
)
-
1
{
upstreamMsg
:=
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
appendOpsUpstreamError
(
p
.
c
,
OpsUpstreamErrorEvent
{
Platform
:
p
.
account
.
Platform
,
AccountID
:
p
.
account
.
ID
,
AccountName
:
p
.
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"retry"
,
Message
:
upstreamMsg
,
Detail
:
getUpstreamDetail
(
respBody
),
})
antigravity
.
DefaultURLAvailability
.
MarkUnavailable
(
baseURL
)
log
.
Printf
(
"%s URL fallback (429): %s -> %s"
,
p
.
prefix
,
baseURL
,
availableURLs
[
urlIdx
+
1
])
log
.
Printf
(
"%s URL fallback (
HTTP
429): %s -> %s
body=%s
"
,
p
.
prefix
,
baseURL
,
availableURLs
[
urlIdx
+
1
]
,
truncateForLog
(
respBody
,
200
)
)
continue
urlFallbackLoop
}
// 账户/模型配额限流,重试 3 次(指数退避)
if
attempt
<
antigravityMaxRetries
{
upstreamMsg
:=
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
appendOpsUpstreamError
(
p
.
c
,
OpsUpstreamErrorEvent
{
Platform
:
p
.
account
.
Platform
,
AccountID
:
p
.
account
.
ID
,
AccountName
:
p
.
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"retry"
,
Message
:
upstreamMsg
,
Detail
:
getUpstreamDetail
(
respBody
),
})
log
.
Printf
(
"%s status=429 retry=%d/%d body=%s"
,
p
.
prefix
,
attempt
,
antigravityMaxRetries
,
truncateForLog
(
respBody
,
200
))
if
!
sleepAntigravityBackoffWithContext
(
p
.
ctx
,
attempt
)
{
log
.
Printf
(
"%s status=context_canceled_during_backoff"
,
p
.
prefix
)
...
...
@@ -131,6 +186,18 @@ urlFallbackLoop:
_
=
resp
.
Body
.
Close
()
if
attempt
<
antigravityMaxRetries
{
upstreamMsg
:=
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
appendOpsUpstreamError
(
p
.
c
,
OpsUpstreamErrorEvent
{
Platform
:
p
.
account
.
Platform
,
AccountID
:
p
.
account
.
ID
,
AccountName
:
p
.
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"retry"
,
Message
:
upstreamMsg
,
Detail
:
getUpstreamDetail
(
respBody
),
})
log
.
Printf
(
"%s status=%d retry=%d/%d body=%s"
,
p
.
prefix
,
resp
.
StatusCode
,
attempt
,
antigravityMaxRetries
,
truncateForLog
(
respBody
,
500
))
if
!
sleepAntigravityBackoffWithContext
(
p
.
ctx
,
attempt
)
{
log
.
Printf
(
"%s status=context_canceled_during_backoff"
,
p
.
prefix
)
...
...
@@ -679,6 +746,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
proxyURL
=
account
.
Proxy
.
URL
()
}
// Sanitize thinking blocks (clean cache_control and flatten history thinking)
sanitizeThinkingBlocks
(
&
claudeReq
)
// 获取转换选项
// Antigravity 上游要求必须包含身份提示词,否则会返回 429
transformOpts
:=
s
.
getClaudeTransformOptions
(
ctx
)
...
...
@@ -690,6 +760,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return
nil
,
fmt
.
Errorf
(
"transform request: %w"
,
err
)
}
// Safety net: ensure no cache_control leaked into Gemini request
geminiBody
=
cleanCacheControlFromGeminiJSON
(
geminiBody
)
// Antigravity 上游只支持流式请求,统一使用 streamGenerateContent
// 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回
action
:=
"streamGenerateContent"
...
...
@@ -704,7 +777,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
action
:
action
,
body
:
geminiBody
,
quotaScope
:
quotaScope
,
c
:
c
,
httpUpstream
:
s
.
httpUpstream
,
settingService
:
s
.
settingService
,
handleError
:
s
.
handleUpstreamError
,
})
if
err
!=
nil
{
...
...
@@ -720,6 +795,28 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
// Antigravity /v1internal 链路在部分场景会对 thought/thinking signature 做严格校验,
// 当历史消息携带的 signature 不合法时会直接 400;去除 thinking 后可继续完成请求。
if
resp
.
StatusCode
==
http
.
StatusBadRequest
&&
isSignatureRelatedError
(
respBody
)
{
upstreamMsg
:=
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
logBody
:=
s
.
settingService
!=
nil
&&
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
LogUpstreamErrorBody
maxBytes
:=
2048
if
s
.
settingService
!=
nil
&&
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
>
0
{
maxBytes
=
s
.
settingService
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
}
upstreamDetail
:=
""
if
logBody
{
upstreamDetail
=
truncateString
(
string
(
respBody
),
maxBytes
)
}
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"signature_error"
,
Message
:
upstreamMsg
,
Detail
:
upstreamDetail
,
})
// Conservative two-stage fallback:
// 1) Disable top-level thinking + thinking->text
// 2) Only if still signature-related 400: also downgrade tool_use/tool_result to text.
...
...
@@ -753,6 +850,14 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
}
retryResp
,
retryErr
:=
s
.
httpUpstream
.
Do
(
retryReq
,
proxyURL
,
account
.
ID
,
account
.
Concurrency
)
if
retryErr
!=
nil
{
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
0
,
Kind
:
"signature_retry_request_error"
,
Message
:
sanitizeUpstreamErrorMessage
(
retryErr
.
Error
()),
})
log
.
Printf
(
"Antigravity account %d: signature retry request failed (%s): %v"
,
account
.
ID
,
stage
.
name
,
retryErr
)
continue
}
...
...
@@ -766,6 +871,26 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
retryBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
retryResp
.
Body
,
2
<<
20
))
_
=
retryResp
.
Body
.
Close
()
kind
:=
"signature_retry"
if
strings
.
TrimSpace
(
stage
.
name
)
!=
""
{
kind
=
"signature_retry_"
+
strings
.
ReplaceAll
(
stage
.
name
,
"+"
,
"_"
)
}
retryUpstreamMsg
:=
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
retryBody
))
retryUpstreamMsg
=
sanitizeUpstreamErrorMessage
(
retryUpstreamMsg
)
retryUpstreamDetail
:=
""
if
logBody
{
retryUpstreamDetail
=
truncateString
(
string
(
retryBody
),
maxBytes
)
}
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
retryResp
.
StatusCode
,
UpstreamRequestID
:
retryResp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
kind
,
Message
:
retryUpstreamMsg
,
Detail
:
retryUpstreamDetail
,
})
// If this stage fixed the signature issue, we stop; otherwise we may try the next stage.
if
retryResp
.
StatusCode
!=
http
.
StatusBadRequest
||
!
isSignatureRelatedError
(
retryBody
)
{
...
...
@@ -793,10 +918,31 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
,
quotaScope
)
if
s
.
shouldFailoverUpstreamError
(
resp
.
StatusCode
)
{
upstreamMsg
:=
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
respBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
logBody
:=
s
.
settingService
!=
nil
&&
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
LogUpstreamErrorBody
maxBytes
:=
2048
if
s
.
settingService
!=
nil
&&
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
>
0
{
maxBytes
=
s
.
settingService
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
}
upstreamDetail
:=
""
if
logBody
{
upstreamDetail
=
truncateString
(
string
(
respBody
),
maxBytes
)
}
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
resp
.
Header
.
Get
(
"x-request-id"
),
Kind
:
"failover"
,
Message
:
upstreamMsg
,
Detail
:
upstreamDetail
,
})
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
resp
.
StatusCode
}
}
return
nil
,
s
.
writeMappedClaudeError
(
c
,
resp
.
StatusCode
,
respBody
)
return
nil
,
s
.
writeMappedClaudeError
(
c
,
account
,
resp
.
StatusCode
,
resp
.
Header
.
Get
(
"x-request-id"
),
respBody
)
}
}
...
...
@@ -879,6 +1025,143 @@ func extractAntigravityErrorMessage(body []byte) string {
return
""
}
// cleanCacheControlFromGeminiJSON removes cache_control from Gemini JSON (emergency fix)
// This should not be needed if transformation is correct, but serves as a safety net
func
cleanCacheControlFromGeminiJSON
(
body
[]
byte
)
[]
byte
{
// Try a more robust approach: parse and clean
var
data
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
body
,
&
data
);
err
!=
nil
{
log
.
Printf
(
"[Antigravity] Failed to parse Gemini JSON for cache_control cleaning: %v"
,
err
)
return
body
}
cleaned
:=
removeCacheControlFromAny
(
data
)
if
!
cleaned
{
return
body
}
if
result
,
err
:=
json
.
Marshal
(
data
);
err
==
nil
{
log
.
Printf
(
"[Antigravity] Successfully cleaned cache_control from Gemini JSON"
)
return
result
}
return
body
}
// removeCacheControlFromAny recursively removes cache_control fields
func
removeCacheControlFromAny
(
v
any
)
bool
{
cleaned
:=
false
switch
val
:=
v
.
(
type
)
{
case
map
[
string
]
any
:
for
k
,
child
:=
range
val
{
if
k
==
"cache_control"
{
delete
(
val
,
k
)
cleaned
=
true
}
else
if
removeCacheControlFromAny
(
child
)
{
cleaned
=
true
}
}
case
[]
any
:
for
_
,
item
:=
range
val
{
if
removeCacheControlFromAny
(
item
)
{
cleaned
=
true
}
}
}
return
cleaned
}
// sanitizeThinkingBlocks cleans cache_control and flattens history thinking blocks
// Thinking blocks do NOT support cache_control field (Anthropic API/Vertex AI requirement)
// Additionally, history thinking blocks are flattened to text to avoid upstream validation errors
func
sanitizeThinkingBlocks
(
req
*
antigravity
.
ClaudeRequest
)
{
if
req
==
nil
{
return
}
log
.
Printf
(
"[Antigravity] sanitizeThinkingBlocks: processing request with %d messages"
,
len
(
req
.
Messages
))
// Clean system blocks
if
len
(
req
.
System
)
>
0
{
var
systemBlocks
[]
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
req
.
System
,
&
systemBlocks
);
err
==
nil
{
for
i
:=
range
systemBlocks
{
if
blockType
,
_
:=
systemBlocks
[
i
][
"type"
]
.
(
string
);
blockType
==
"thinking"
||
systemBlocks
[
i
][
"thinking"
]
!=
nil
{
if
removeCacheControlFromAny
(
systemBlocks
[
i
])
{
log
.
Printf
(
"[Antigravity] Deep cleaned cache_control from thinking block in system[%d]"
,
i
)
}
}
}
// Marshal back
if
cleaned
,
err
:=
json
.
Marshal
(
systemBlocks
);
err
==
nil
{
req
.
System
=
cleaned
}
}
}
// Clean message content blocks and flatten history
lastMsgIdx
:=
len
(
req
.
Messages
)
-
1
for
msgIdx
:=
range
req
.
Messages
{
raw
:=
req
.
Messages
[
msgIdx
]
.
Content
if
len
(
raw
)
==
0
{
continue
}
// Try to parse as blocks array
var
blocks
[]
map
[
string
]
any
if
err
:=
json
.
Unmarshal
(
raw
,
&
blocks
);
err
!=
nil
{
continue
}
cleaned
:=
false
for
blockIdx
:=
range
blocks
{
blockType
,
_
:=
blocks
[
blockIdx
][
"type"
]
.
(
string
)
// Check for thinking blocks (typed or untyped)
if
blockType
==
"thinking"
||
blocks
[
blockIdx
][
"thinking"
]
!=
nil
{
// 1. Clean cache_control
if
removeCacheControlFromAny
(
blocks
[
blockIdx
])
{
log
.
Printf
(
"[Antigravity] Deep cleaned cache_control from thinking block in messages[%d].content[%d]"
,
msgIdx
,
blockIdx
)
cleaned
=
true
}
// 2. Flatten to text if it's a history message (not the last one)
if
msgIdx
<
lastMsgIdx
{
log
.
Printf
(
"[Antigravity] Flattening history thinking block to text at messages[%d].content[%d]"
,
msgIdx
,
blockIdx
)
// Extract thinking content
var
textContent
string
if
t
,
ok
:=
blocks
[
blockIdx
][
"thinking"
]
.
(
string
);
ok
{
textContent
=
t
}
else
{
// Fallback for non-string content (marshal it)
if
b
,
err
:=
json
.
Marshal
(
blocks
[
blockIdx
][
"thinking"
]);
err
==
nil
{
textContent
=
string
(
b
)
}
}
// Convert to text block
blocks
[
blockIdx
][
"type"
]
=
"text"
blocks
[
blockIdx
][
"text"
]
=
textContent
delete
(
blocks
[
blockIdx
],
"thinking"
)
delete
(
blocks
[
blockIdx
],
"signature"
)
delete
(
blocks
[
blockIdx
],
"cache_control"
)
// Ensure it's gone
cleaned
=
true
}
}
}
// Marshal back if modified
if
cleaned
{
if
marshaled
,
err
:=
json
.
Marshal
(
blocks
);
err
==
nil
{
req
.
Messages
[
msgIdx
]
.
Content
=
marshaled
}
}
}
}
// stripThinkingFromClaudeRequest converts thinking blocks to text blocks in a Claude Messages request.
// This preserves the thinking content while avoiding signature validation errors.
// Note: redacted_thinking blocks are removed because they cannot be converted to text.
...
...
@@ -1184,7 +1467,9 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
action
:
upstreamAction
,
body
:
wrappedBody
,
quotaScope
:
quotaScope
,
c
:
c
,
httpUpstream
:
s
.
httpUpstream
,
settingService
:
s
.
settingService
,
handleError
:
s
.
handleUpstreamError
,
})
if
err
!=
nil
{
...
...
@@ -1234,22 +1519,62 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
s
.
handleUpstreamError
(
ctx
,
prefix
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
,
quotaScope
)
if
s
.
shouldFailoverUpstreamError
(
resp
.
StatusCode
)
{
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
resp
.
StatusCode
}
}
// 解包并返回错误
requestID
:=
resp
.
Header
.
Get
(
"x-request-id"
)
if
requestID
!=
""
{
c
.
Header
(
"x-request-id"
,
requestID
)
}
unwrapped
,
_
:=
s
.
unwrapV1InternalResponse
(
respBody
)
unwrapped
,
unwrapErr
:=
s
.
unwrapV1InternalResponse
(
respBody
)
unwrappedForOps
:=
unwrapped
if
unwrapErr
!=
nil
||
len
(
unwrappedForOps
)
==
0
{
unwrappedForOps
=
respBody
}
upstreamMsg
:=
strings
.
TrimSpace
(
extractAntigravityErrorMessage
(
unwrappedForOps
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
logBody
:=
s
.
settingService
!=
nil
&&
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
LogUpstreamErrorBody
maxBytes
:=
2048
if
s
.
settingService
!=
nil
&&
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
>
0
{
maxBytes
=
s
.
settingService
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
}
upstreamDetail
:=
""
if
logBody
{
upstreamDetail
=
truncateString
(
string
(
unwrappedForOps
),
maxBytes
)
}
// Always record upstream context for Ops error logs, even when we will failover.
setOpsUpstreamError
(
c
,
resp
.
StatusCode
,
upstreamMsg
,
upstreamDetail
)
if
s
.
shouldFailoverUpstreamError
(
resp
.
StatusCode
)
{
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
requestID
,
Kind
:
"failover"
,
Message
:
upstreamMsg
,
Detail
:
upstreamDetail
,
})
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
resp
.
StatusCode
}
}
contentType
:=
resp
.
Header
.
Get
(
"Content-Type"
)
if
contentType
==
""
{
contentType
=
"application/json"
}
log
.
Printf
(
"[antigravity-Forward] upstream error status=%d body=%s"
,
resp
.
StatusCode
,
truncateForLog
(
respBody
,
500
))
c
.
Data
(
resp
.
StatusCode
,
contentType
,
unwrapped
)
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
resp
.
StatusCode
,
UpstreamRequestID
:
requestID
,
Kind
:
"http_error"
,
Message
:
upstreamMsg
,
Detail
:
upstreamDetail
,
})
log
.
Printf
(
"[antigravity-Forward] upstream error status=%d body=%s"
,
resp
.
StatusCode
,
truncateForLog
(
unwrappedForOps
,
500
))
c
.
Data
(
resp
.
StatusCode
,
contentType
,
unwrappedForOps
)
return
nil
,
fmt
.
Errorf
(
"antigravity upstream error: %d"
,
resp
.
StatusCode
)
}
...
...
@@ -1338,9 +1663,15 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
}
}
func
antigravityUseScopeRateLimit
()
bool
{
v
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
os
.
Getenv
(
antigravityScopeRateLimitEnv
)))
return
v
==
"1"
||
v
==
"true"
||
v
==
"yes"
||
v
==
"on"
}
func
(
s
*
AntigravityGatewayService
)
handleUpstreamError
(
ctx
context
.
Context
,
prefix
string
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
body
[]
byte
,
quotaScope
AntigravityQuotaScope
)
{
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
if
statusCode
==
429
{
useScopeLimit
:=
antigravityUseScopeRateLimit
()
&&
quotaScope
!=
""
resetAt
:=
ParseGeminiRateLimitResetTime
(
body
)
if
resetAt
==
nil
{
// 解析失败:使用配置的 fallback 时间,直接限流整个账户
...
...
@@ -1350,20 +1681,31 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre
}
defaultDur
:=
time
.
Duration
(
fallbackMinutes
)
*
time
.
Minute
ra
:=
time
.
Now
()
.
Add
(
defaultDur
)
if
useScopeLimit
{
log
.
Printf
(
"%s status=429 rate_limited scope=%s reset_in=%v (fallback)"
,
prefix
,
quotaScope
,
defaultDur
)
if
err
:=
s
.
accountRepo
.
SetAntigravityQuotaScopeLimit
(
ctx
,
account
.
ID
,
quotaScope
,
ra
);
err
!=
nil
{
log
.
Printf
(
"%s status=429 rate_limit_set_failed scope=%s error=%v"
,
prefix
,
quotaScope
,
err
)
}
}
else
{
log
.
Printf
(
"%s status=429 rate_limited account=%d reset_in=%v (fallback)"
,
prefix
,
account
.
ID
,
defaultDur
)
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
ra
);
err
!=
nil
{
log
.
Printf
(
"%s status=429 rate_limit_set_failed account=%d error=%v"
,
prefix
,
account
.
ID
,
err
)
}
}
return
}
resetTime
:=
time
.
Unix
(
*
resetAt
,
0
)
if
useScopeLimit
{
log
.
Printf
(
"%s status=429 rate_limited scope=%s reset_at=%v reset_in=%v"
,
prefix
,
quotaScope
,
resetTime
.
Format
(
"15:04:05"
),
time
.
Until
(
resetTime
)
.
Truncate
(
time
.
Second
))
if
quotaScope
==
""
{
return
}
if
err
:=
s
.
accountRepo
.
SetAntigravityQuotaScopeLimit
(
ctx
,
account
.
ID
,
quotaScope
,
resetTime
);
err
!=
nil
{
log
.
Printf
(
"%s status=429 rate_limit_set_failed scope=%s error=%v"
,
prefix
,
quotaScope
,
err
)
}
}
else
{
log
.
Printf
(
"%s status=429 rate_limited account=%d reset_at=%v reset_in=%v"
,
prefix
,
account
.
ID
,
resetTime
.
Format
(
"15:04:05"
),
time
.
Until
(
resetTime
)
.
Truncate
(
time
.
Second
))
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
resetTime
);
err
!=
nil
{
log
.
Printf
(
"%s status=429 rate_limit_set_failed account=%d error=%v"
,
prefix
,
account
.
ID
,
err
)
}
}
return
}
// 其他错误码继续使用 rateLimitService
...
...
@@ -1533,6 +1875,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
continue
}
log
.
Printf
(
"Stream data interval timeout (antigravity)"
)
// 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout
sendErrorEvent
(
"stream_timeout"
)
return
&
antigravityStreamResult
{
usage
:
usage
,
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream data interval timeout"
)
}
...
...
@@ -1824,9 +2167,36 @@ func (s *AntigravityGatewayService) writeClaudeError(c *gin.Context, status int,
return
fmt
.
Errorf
(
"%s"
,
message
)
}
func
(
s
*
AntigravityGatewayService
)
writeMappedClaudeError
(
c
*
gin
.
Context
,
upstreamStatus
int
,
body
[]
byte
)
error
{
// 记录上游错误详情便于调试
log
.
Printf
(
"[antigravity-Forward] upstream_error status=%d body=%s"
,
upstreamStatus
,
string
(
body
))
func
(
s
*
AntigravityGatewayService
)
writeMappedClaudeError
(
c
*
gin
.
Context
,
account
*
Account
,
upstreamStatus
int
,
upstreamRequestID
string
,
body
[]
byte
)
error
{
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
body
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
logBody
:=
s
.
settingService
!=
nil
&&
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
LogUpstreamErrorBody
maxBytes
:=
2048
if
s
.
settingService
!=
nil
&&
s
.
settingService
.
cfg
!=
nil
&&
s
.
settingService
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
>
0
{
maxBytes
=
s
.
settingService
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
}
upstreamDetail
:=
""
if
logBody
{
upstreamDetail
=
truncateString
(
string
(
body
),
maxBytes
)
}
setOpsUpstreamError
(
c
,
upstreamStatus
,
upstreamMsg
,
upstreamDetail
)
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
Platform
:
account
.
Platform
,
AccountID
:
account
.
ID
,
AccountName
:
account
.
Name
,
UpstreamStatusCode
:
upstreamStatus
,
UpstreamRequestID
:
upstreamRequestID
,
Kind
:
"http_error"
,
Message
:
upstreamMsg
,
Detail
:
upstreamDetail
,
})
// 记录上游错误详情便于排障(可选:由配置控制;不回显到客户端)
if
logBody
{
log
.
Printf
(
"[antigravity-Forward] upstream_error status=%d body=%s"
,
upstreamStatus
,
truncateForLog
(
body
,
maxBytes
))
}
var
statusCode
int
var
errType
,
errMsg
string
...
...
@@ -1862,7 +2232,10 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, upstr
"type"
:
"error"
,
"error"
:
gin
.
H
{
"type"
:
errType
,
"message"
:
errMsg
},
})
if
upstreamMsg
==
""
{
return
fmt
.
Errorf
(
"upstream error: %d"
,
upstreamStatus
)
}
return
fmt
.
Errorf
(
"upstream error: %d message=%s"
,
upstreamStatus
,
upstreamMsg
)
}
func
(
s
*
AntigravityGatewayService
)
writeGoogleError
(
c
*
gin
.
Context
,
status
int
,
message
string
)
error
{
...
...
@@ -2189,6 +2562,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
continue
}
log
.
Printf
(
"Stream data interval timeout (antigravity)"
)
// 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout
sendErrorEvent
(
"stream_timeout"
)
return
&
antigravityStreamResult
{
usage
:
convertUsage
(
nil
),
firstTokenMs
:
firstTokenMs
},
fmt
.
Errorf
(
"stream data interval timeout"
)
}
...
...
backend/internal/service/antigravity_quota_scope.go
View file @
b9b4db3d
...
...
@@ -49,6 +49,9 @@ func (a *Account) IsSchedulableForModel(requestedModel string) bool {
if
!
a
.
IsSchedulable
()
{
return
false
}
if
a
.
isModelRateLimited
(
requestedModel
)
{
return
false
}
if
a
.
Platform
!=
PlatformAntigravity
{
return
true
}
...
...
backend/internal/service/antigravity_token_provider.go
View file @
b9b4db3d
...
...
@@ -45,7 +45,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return
""
,
errors
.
New
(
"not an antigravity oauth account"
)
}
cacheKey
:=
a
ntigravityTokenCacheKey
(
account
)
cacheKey
:=
A
ntigravityTokenCacheKey
(
account
)
// 1. 先尝试缓存
if
p
.
tokenCache
!=
nil
{
...
...
@@ -121,7 +121,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return
accessToken
,
nil
}
func
a
ntigravityTokenCacheKey
(
account
*
Account
)
string
{
func
A
ntigravityTokenCacheKey
(
account
*
Account
)
string
{
projectID
:=
strings
.
TrimSpace
(
account
.
GetCredential
(
"project_id"
))
if
projectID
!=
""
{
return
"ag:"
+
projectID
...
...
backend/internal/service/api_key.go
View file @
b9b4db3d
...
...
@@ -9,6 +9,8 @@ type APIKey struct {
Name
string
GroupID
*
int64
Status
string
IPWhitelist
[]
string
IPBlacklist
[]
string
CreatedAt
time
.
Time
UpdatedAt
time
.
Time
User
*
User
...
...
backend/internal/service/api_key_auth_cache.go
0 → 100644
View file @
b9b4db3d
package
service
// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段)
type
APIKeyAuthSnapshot
struct
{
APIKeyID
int64
`json:"api_key_id"`
UserID
int64
`json:"user_id"`
GroupID
*
int64
`json:"group_id,omitempty"`
Status
string
`json:"status"`
IPWhitelist
[]
string
`json:"ip_whitelist,omitempty"`
IPBlacklist
[]
string
`json:"ip_blacklist,omitempty"`
User
APIKeyAuthUserSnapshot
`json:"user"`
Group
*
APIKeyAuthGroupSnapshot
`json:"group,omitempty"`
}
// APIKeyAuthUserSnapshot 用户快照
type
APIKeyAuthUserSnapshot
struct
{
ID
int64
`json:"id"`
Status
string
`json:"status"`
Role
string
`json:"role"`
Balance
float64
`json:"balance"`
Concurrency
int
`json:"concurrency"`
}
// APIKeyAuthGroupSnapshot 分组快照
type
APIKeyAuthGroupSnapshot
struct
{
ID
int64
`json:"id"`
Name
string
`json:"name"`
Platform
string
`json:"platform"`
Status
string
`json:"status"`
SubscriptionType
string
`json:"subscription_type"`
RateMultiplier
float64
`json:"rate_multiplier"`
DailyLimitUSD
*
float64
`json:"daily_limit_usd,omitempty"`
WeeklyLimitUSD
*
float64
`json:"weekly_limit_usd,omitempty"`
MonthlyLimitUSD
*
float64
`json:"monthly_limit_usd,omitempty"`
ImagePrice1K
*
float64
`json:"image_price_1k,omitempty"`
ImagePrice2K
*
float64
`json:"image_price_2k,omitempty"`
ImagePrice4K
*
float64
`json:"image_price_4k,omitempty"`
ClaudeCodeOnly
bool
`json:"claude_code_only"`
FallbackGroupID
*
int64
`json:"fallback_group_id,omitempty"`
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
// Only anthropic groups use these fields; others may leave them empty.
ModelRouting
map
[
string
][]
int64
`json:"model_routing,omitempty"`
ModelRoutingEnabled
bool
`json:"model_routing_enabled"`
}
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
type
APIKeyAuthCacheEntry
struct
{
NotFound
bool
`json:"not_found"`
Snapshot
*
APIKeyAuthSnapshot
`json:"snapshot,omitempty"`
}
backend/internal/service/api_key_auth_cache_impl.go
0 → 100644
View file @
b9b4db3d
package
service
import
(
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"math/rand"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/dgraph-io/ristretto"
)
type
apiKeyAuthCacheConfig
struct
{
l1Size
int
l1TTL
time
.
Duration
l2TTL
time
.
Duration
negativeTTL
time
.
Duration
jitterPercent
int
singleflight
bool
}
var
(
jitterRandMu
sync
.
Mutex
// 认证缓存抖动使用独立随机源,避免全局 Seed
jitterRand
=
rand
.
New
(
rand
.
NewSource
(
time
.
Now
()
.
UnixNano
()))
)
func
newAPIKeyAuthCacheConfig
(
cfg
*
config
.
Config
)
apiKeyAuthCacheConfig
{
if
cfg
==
nil
{
return
apiKeyAuthCacheConfig
{}
}
auth
:=
cfg
.
APIKeyAuth
return
apiKeyAuthCacheConfig
{
l1Size
:
auth
.
L1Size
,
l1TTL
:
time
.
Duration
(
auth
.
L1TTLSeconds
)
*
time
.
Second
,
l2TTL
:
time
.
Duration
(
auth
.
L2TTLSeconds
)
*
time
.
Second
,
negativeTTL
:
time
.
Duration
(
auth
.
NegativeTTLSeconds
)
*
time
.
Second
,
jitterPercent
:
auth
.
JitterPercent
,
singleflight
:
auth
.
Singleflight
,
}
}
func
(
c
apiKeyAuthCacheConfig
)
l1Enabled
()
bool
{
return
c
.
l1Size
>
0
&&
c
.
l1TTL
>
0
}
func
(
c
apiKeyAuthCacheConfig
)
l2Enabled
()
bool
{
return
c
.
l2TTL
>
0
}
func
(
c
apiKeyAuthCacheConfig
)
negativeEnabled
()
bool
{
return
c
.
negativeTTL
>
0
}
func
(
c
apiKeyAuthCacheConfig
)
jitterTTL
(
ttl
time
.
Duration
)
time
.
Duration
{
if
ttl
<=
0
{
return
ttl
}
if
c
.
jitterPercent
<=
0
{
return
ttl
}
percent
:=
c
.
jitterPercent
if
percent
>
100
{
percent
=
100
}
delta
:=
float64
(
percent
)
/
100
jitterRandMu
.
Lock
()
randVal
:=
jitterRand
.
Float64
()
jitterRandMu
.
Unlock
()
factor
:=
1
-
delta
+
randVal
*
(
2
*
delta
)
if
factor
<=
0
{
return
ttl
}
return
time
.
Duration
(
float64
(
ttl
)
*
factor
)
}
func
(
s
*
APIKeyService
)
initAuthCache
(
cfg
*
config
.
Config
)
{
s
.
authCfg
=
newAPIKeyAuthCacheConfig
(
cfg
)
if
!
s
.
authCfg
.
l1Enabled
()
{
return
}
cache
,
err
:=
ristretto
.
NewCache
(
&
ristretto
.
Config
{
NumCounters
:
int64
(
s
.
authCfg
.
l1Size
)
*
10
,
MaxCost
:
int64
(
s
.
authCfg
.
l1Size
),
BufferItems
:
64
,
})
if
err
!=
nil
{
return
}
s
.
authCacheL1
=
cache
}
func
(
s
*
APIKeyService
)
authCacheKey
(
key
string
)
string
{
sum
:=
sha256
.
Sum256
([]
byte
(
key
))
return
hex
.
EncodeToString
(
sum
[
:
])
}
func
(
s
*
APIKeyService
)
getAuthCacheEntry
(
ctx
context
.
Context
,
cacheKey
string
)
(
*
APIKeyAuthCacheEntry
,
bool
)
{
if
s
.
authCacheL1
!=
nil
{
if
val
,
ok
:=
s
.
authCacheL1
.
Get
(
cacheKey
);
ok
{
if
entry
,
ok
:=
val
.
(
*
APIKeyAuthCacheEntry
);
ok
{
return
entry
,
true
}
}
}
if
s
.
cache
==
nil
||
!
s
.
authCfg
.
l2Enabled
()
{
return
nil
,
false
}
entry
,
err
:=
s
.
cache
.
GetAuthCache
(
ctx
,
cacheKey
)
if
err
!=
nil
{
return
nil
,
false
}
s
.
setAuthCacheL1
(
cacheKey
,
entry
)
return
entry
,
true
}
func
(
s
*
APIKeyService
)
setAuthCacheL1
(
cacheKey
string
,
entry
*
APIKeyAuthCacheEntry
)
{
if
s
.
authCacheL1
==
nil
||
entry
==
nil
{
return
}
ttl
:=
s
.
authCfg
.
l1TTL
if
entry
.
NotFound
&&
s
.
authCfg
.
negativeTTL
>
0
&&
s
.
authCfg
.
negativeTTL
<
ttl
{
ttl
=
s
.
authCfg
.
negativeTTL
}
ttl
=
s
.
authCfg
.
jitterTTL
(
ttl
)
_
=
s
.
authCacheL1
.
SetWithTTL
(
cacheKey
,
entry
,
1
,
ttl
)
}
func
(
s
*
APIKeyService
)
setAuthCacheEntry
(
ctx
context
.
Context
,
cacheKey
string
,
entry
*
APIKeyAuthCacheEntry
,
ttl
time
.
Duration
)
{
if
entry
==
nil
{
return
}
s
.
setAuthCacheL1
(
cacheKey
,
entry
)
if
s
.
cache
==
nil
||
!
s
.
authCfg
.
l2Enabled
()
{
return
}
_
=
s
.
cache
.
SetAuthCache
(
ctx
,
cacheKey
,
entry
,
s
.
authCfg
.
jitterTTL
(
ttl
))
}
func
(
s
*
APIKeyService
)
deleteAuthCache
(
ctx
context
.
Context
,
cacheKey
string
)
{
if
s
.
authCacheL1
!=
nil
{
s
.
authCacheL1
.
Del
(
cacheKey
)
}
if
s
.
cache
==
nil
{
return
}
_
=
s
.
cache
.
DeleteAuthCache
(
ctx
,
cacheKey
)
}
func
(
s
*
APIKeyService
)
loadAuthCacheEntry
(
ctx
context
.
Context
,
key
,
cacheKey
string
)
(
*
APIKeyAuthCacheEntry
,
error
)
{
apiKey
,
err
:=
s
.
apiKeyRepo
.
GetByKeyForAuth
(
ctx
,
key
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
ErrAPIKeyNotFound
)
{
entry
:=
&
APIKeyAuthCacheEntry
{
NotFound
:
true
}
if
s
.
authCfg
.
negativeEnabled
()
{
s
.
setAuthCacheEntry
(
ctx
,
cacheKey
,
entry
,
s
.
authCfg
.
negativeTTL
)
}
return
entry
,
nil
}
return
nil
,
fmt
.
Errorf
(
"get api key: %w"
,
err
)
}
apiKey
.
Key
=
key
snapshot
:=
s
.
snapshotFromAPIKey
(
apiKey
)
if
snapshot
==
nil
{
return
nil
,
fmt
.
Errorf
(
"get api key: %w"
,
ErrAPIKeyNotFound
)
}
entry
:=
&
APIKeyAuthCacheEntry
{
Snapshot
:
snapshot
}
s
.
setAuthCacheEntry
(
ctx
,
cacheKey
,
entry
,
s
.
authCfg
.
l2TTL
)
return
entry
,
nil
}
func
(
s
*
APIKeyService
)
applyAuthCacheEntry
(
key
string
,
entry
*
APIKeyAuthCacheEntry
)
(
*
APIKey
,
bool
,
error
)
{
if
entry
==
nil
{
return
nil
,
false
,
nil
}
if
entry
.
NotFound
{
return
nil
,
true
,
ErrAPIKeyNotFound
}
if
entry
.
Snapshot
==
nil
{
return
nil
,
false
,
nil
}
return
s
.
snapshotToAPIKey
(
key
,
entry
.
Snapshot
),
true
,
nil
}
func
(
s
*
APIKeyService
)
snapshotFromAPIKey
(
apiKey
*
APIKey
)
*
APIKeyAuthSnapshot
{
if
apiKey
==
nil
||
apiKey
.
User
==
nil
{
return
nil
}
snapshot
:=
&
APIKeyAuthSnapshot
{
APIKeyID
:
apiKey
.
ID
,
UserID
:
apiKey
.
UserID
,
GroupID
:
apiKey
.
GroupID
,
Status
:
apiKey
.
Status
,
IPWhitelist
:
apiKey
.
IPWhitelist
,
IPBlacklist
:
apiKey
.
IPBlacklist
,
User
:
APIKeyAuthUserSnapshot
{
ID
:
apiKey
.
User
.
ID
,
Status
:
apiKey
.
User
.
Status
,
Role
:
apiKey
.
User
.
Role
,
Balance
:
apiKey
.
User
.
Balance
,
Concurrency
:
apiKey
.
User
.
Concurrency
,
},
}
if
apiKey
.
Group
!=
nil
{
snapshot
.
Group
=
&
APIKeyAuthGroupSnapshot
{
ID
:
apiKey
.
Group
.
ID
,
Name
:
apiKey
.
Group
.
Name
,
Platform
:
apiKey
.
Group
.
Platform
,
Status
:
apiKey
.
Group
.
Status
,
SubscriptionType
:
apiKey
.
Group
.
SubscriptionType
,
RateMultiplier
:
apiKey
.
Group
.
RateMultiplier
,
DailyLimitUSD
:
apiKey
.
Group
.
DailyLimitUSD
,
WeeklyLimitUSD
:
apiKey
.
Group
.
WeeklyLimitUSD
,
MonthlyLimitUSD
:
apiKey
.
Group
.
MonthlyLimitUSD
,
ImagePrice1K
:
apiKey
.
Group
.
ImagePrice1K
,
ImagePrice2K
:
apiKey
.
Group
.
ImagePrice2K
,
ImagePrice4K
:
apiKey
.
Group
.
ImagePrice4K
,
ClaudeCodeOnly
:
apiKey
.
Group
.
ClaudeCodeOnly
,
FallbackGroupID
:
apiKey
.
Group
.
FallbackGroupID
,
ModelRouting
:
apiKey
.
Group
.
ModelRouting
,
ModelRoutingEnabled
:
apiKey
.
Group
.
ModelRoutingEnabled
,
}
}
return
snapshot
}
func
(
s
*
APIKeyService
)
snapshotToAPIKey
(
key
string
,
snapshot
*
APIKeyAuthSnapshot
)
*
APIKey
{
if
snapshot
==
nil
{
return
nil
}
apiKey
:=
&
APIKey
{
ID
:
snapshot
.
APIKeyID
,
UserID
:
snapshot
.
UserID
,
GroupID
:
snapshot
.
GroupID
,
Key
:
key
,
Status
:
snapshot
.
Status
,
IPWhitelist
:
snapshot
.
IPWhitelist
,
IPBlacklist
:
snapshot
.
IPBlacklist
,
User
:
&
User
{
ID
:
snapshot
.
User
.
ID
,
Status
:
snapshot
.
User
.
Status
,
Role
:
snapshot
.
User
.
Role
,
Balance
:
snapshot
.
User
.
Balance
,
Concurrency
:
snapshot
.
User
.
Concurrency
,
},
}
if
snapshot
.
Group
!=
nil
{
apiKey
.
Group
=
&
Group
{
ID
:
snapshot
.
Group
.
ID
,
Name
:
snapshot
.
Group
.
Name
,
Platform
:
snapshot
.
Group
.
Platform
,
Status
:
snapshot
.
Group
.
Status
,
Hydrated
:
true
,
SubscriptionType
:
snapshot
.
Group
.
SubscriptionType
,
RateMultiplier
:
snapshot
.
Group
.
RateMultiplier
,
DailyLimitUSD
:
snapshot
.
Group
.
DailyLimitUSD
,
WeeklyLimitUSD
:
snapshot
.
Group
.
WeeklyLimitUSD
,
MonthlyLimitUSD
:
snapshot
.
Group
.
MonthlyLimitUSD
,
ImagePrice1K
:
snapshot
.
Group
.
ImagePrice1K
,
ImagePrice2K
:
snapshot
.
Group
.
ImagePrice2K
,
ImagePrice4K
:
snapshot
.
Group
.
ImagePrice4K
,
ClaudeCodeOnly
:
snapshot
.
Group
.
ClaudeCodeOnly
,
FallbackGroupID
:
snapshot
.
Group
.
FallbackGroupID
,
ModelRouting
:
snapshot
.
Group
.
ModelRouting
,
ModelRoutingEnabled
:
snapshot
.
Group
.
ModelRoutingEnabled
,
}
}
return
apiKey
}
Prev
1
…
5
6
7
8
9
10
11
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment