Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
b9b4db3d
Commit
b9b4db3d
authored
Jan 17, 2026
by
song
Browse files
Merge upstream/main
parents
5a6f60a9
dae0d532
Changes
237
Show whitespace changes
Inline
Side-by-side
Too many changes to show.
To preserve performance only
237 of 237+
files are displayed.
Plain diff
Email patch
backend/internal/config/config_test.go
View file @
b9b4db3d
...
...
@@ -39,8 +39,8 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) {
if
cfg
.
Gateway
.
Scheduling
.
StickySessionMaxWaiting
!=
3
{
t
.
Fatalf
(
"StickySessionMaxWaiting = %d, want 3"
,
cfg
.
Gateway
.
Scheduling
.
StickySessionMaxWaiting
)
}
if
cfg
.
Gateway
.
Scheduling
.
StickySessionWaitTimeout
!=
45
*
time
.
Second
{
t
.
Fatalf
(
"StickySessionWaitTimeout = %v, want
45
s"
,
cfg
.
Gateway
.
Scheduling
.
StickySessionWaitTimeout
)
if
cfg
.
Gateway
.
Scheduling
.
StickySessionWaitTimeout
!=
120
*
time
.
Second
{
t
.
Fatalf
(
"StickySessionWaitTimeout = %v, want
120
s"
,
cfg
.
Gateway
.
Scheduling
.
StickySessionWaitTimeout
)
}
if
cfg
.
Gateway
.
Scheduling
.
FallbackWaitTimeout
!=
30
*
time
.
Second
{
t
.
Fatalf
(
"FallbackWaitTimeout = %v, want 30s"
,
cfg
.
Gateway
.
Scheduling
.
FallbackWaitTimeout
)
...
...
@@ -141,3 +141,142 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
t
.
Fatalf
(
"Validate() expected use_pkce error, got: %v"
,
err
)
}
}
func
TestLoadDefaultDashboardCacheConfig
(
t
*
testing
.
T
)
{
viper
.
Reset
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
t
.
Fatalf
(
"Load() error: %v"
,
err
)
}
if
!
cfg
.
Dashboard
.
Enabled
{
t
.
Fatalf
(
"Dashboard.Enabled = false, want true"
)
}
if
cfg
.
Dashboard
.
KeyPrefix
!=
"sub2api:"
{
t
.
Fatalf
(
"Dashboard.KeyPrefix = %q, want %q"
,
cfg
.
Dashboard
.
KeyPrefix
,
"sub2api:"
)
}
if
cfg
.
Dashboard
.
StatsFreshTTLSeconds
!=
15
{
t
.
Fatalf
(
"Dashboard.StatsFreshTTLSeconds = %d, want 15"
,
cfg
.
Dashboard
.
StatsFreshTTLSeconds
)
}
if
cfg
.
Dashboard
.
StatsTTLSeconds
!=
30
{
t
.
Fatalf
(
"Dashboard.StatsTTLSeconds = %d, want 30"
,
cfg
.
Dashboard
.
StatsTTLSeconds
)
}
if
cfg
.
Dashboard
.
StatsRefreshTimeoutSeconds
!=
30
{
t
.
Fatalf
(
"Dashboard.StatsRefreshTimeoutSeconds = %d, want 30"
,
cfg
.
Dashboard
.
StatsRefreshTimeoutSeconds
)
}
}
func
TestValidateDashboardCacheConfigEnabled
(
t
*
testing
.
T
)
{
viper
.
Reset
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
t
.
Fatalf
(
"Load() error: %v"
,
err
)
}
cfg
.
Dashboard
.
Enabled
=
true
cfg
.
Dashboard
.
StatsFreshTTLSeconds
=
10
cfg
.
Dashboard
.
StatsTTLSeconds
=
5
err
=
cfg
.
Validate
()
if
err
==
nil
{
t
.
Fatalf
(
"Validate() expected error for stats_fresh_ttl_seconds > stats_ttl_seconds, got nil"
)
}
if
!
strings
.
Contains
(
err
.
Error
(),
"dashboard_cache.stats_fresh_ttl_seconds"
)
{
t
.
Fatalf
(
"Validate() expected stats_fresh_ttl_seconds error, got: %v"
,
err
)
}
}
func
TestValidateDashboardCacheConfigDisabled
(
t
*
testing
.
T
)
{
viper
.
Reset
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
t
.
Fatalf
(
"Load() error: %v"
,
err
)
}
cfg
.
Dashboard
.
Enabled
=
false
cfg
.
Dashboard
.
StatsTTLSeconds
=
-
1
err
=
cfg
.
Validate
()
if
err
==
nil
{
t
.
Fatalf
(
"Validate() expected error for negative stats_ttl_seconds, got nil"
)
}
if
!
strings
.
Contains
(
err
.
Error
(),
"dashboard_cache.stats_ttl_seconds"
)
{
t
.
Fatalf
(
"Validate() expected stats_ttl_seconds error, got: %v"
,
err
)
}
}
func
TestLoadDefaultDashboardAggregationConfig
(
t
*
testing
.
T
)
{
viper
.
Reset
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
t
.
Fatalf
(
"Load() error: %v"
,
err
)
}
if
!
cfg
.
DashboardAgg
.
Enabled
{
t
.
Fatalf
(
"DashboardAgg.Enabled = false, want true"
)
}
if
cfg
.
DashboardAgg
.
IntervalSeconds
!=
60
{
t
.
Fatalf
(
"DashboardAgg.IntervalSeconds = %d, want 60"
,
cfg
.
DashboardAgg
.
IntervalSeconds
)
}
if
cfg
.
DashboardAgg
.
LookbackSeconds
!=
120
{
t
.
Fatalf
(
"DashboardAgg.LookbackSeconds = %d, want 120"
,
cfg
.
DashboardAgg
.
LookbackSeconds
)
}
if
cfg
.
DashboardAgg
.
BackfillEnabled
{
t
.
Fatalf
(
"DashboardAgg.BackfillEnabled = true, want false"
)
}
if
cfg
.
DashboardAgg
.
BackfillMaxDays
!=
31
{
t
.
Fatalf
(
"DashboardAgg.BackfillMaxDays = %d, want 31"
,
cfg
.
DashboardAgg
.
BackfillMaxDays
)
}
if
cfg
.
DashboardAgg
.
Retention
.
UsageLogsDays
!=
90
{
t
.
Fatalf
(
"DashboardAgg.Retention.UsageLogsDays = %d, want 90"
,
cfg
.
DashboardAgg
.
Retention
.
UsageLogsDays
)
}
if
cfg
.
DashboardAgg
.
Retention
.
HourlyDays
!=
180
{
t
.
Fatalf
(
"DashboardAgg.Retention.HourlyDays = %d, want 180"
,
cfg
.
DashboardAgg
.
Retention
.
HourlyDays
)
}
if
cfg
.
DashboardAgg
.
Retention
.
DailyDays
!=
730
{
t
.
Fatalf
(
"DashboardAgg.Retention.DailyDays = %d, want 730"
,
cfg
.
DashboardAgg
.
Retention
.
DailyDays
)
}
if
cfg
.
DashboardAgg
.
RecomputeDays
!=
2
{
t
.
Fatalf
(
"DashboardAgg.RecomputeDays = %d, want 2"
,
cfg
.
DashboardAgg
.
RecomputeDays
)
}
}
func
TestValidateDashboardAggregationConfigDisabled
(
t
*
testing
.
T
)
{
viper
.
Reset
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
t
.
Fatalf
(
"Load() error: %v"
,
err
)
}
cfg
.
DashboardAgg
.
Enabled
=
false
cfg
.
DashboardAgg
.
IntervalSeconds
=
-
1
err
=
cfg
.
Validate
()
if
err
==
nil
{
t
.
Fatalf
(
"Validate() expected error for negative dashboard_aggregation.interval_seconds, got nil"
)
}
if
!
strings
.
Contains
(
err
.
Error
(),
"dashboard_aggregation.interval_seconds"
)
{
t
.
Fatalf
(
"Validate() expected interval_seconds error, got: %v"
,
err
)
}
}
func
TestValidateDashboardAggregationBackfillMaxDays
(
t
*
testing
.
T
)
{
viper
.
Reset
()
cfg
,
err
:=
Load
()
if
err
!=
nil
{
t
.
Fatalf
(
"Load() error: %v"
,
err
)
}
cfg
.
DashboardAgg
.
BackfillEnabled
=
true
cfg
.
DashboardAgg
.
BackfillMaxDays
=
0
err
=
cfg
.
Validate
()
if
err
==
nil
{
t
.
Fatalf
(
"Validate() expected error for dashboard_aggregation.backfill_max_days, got nil"
)
}
if
!
strings
.
Contains
(
err
.
Error
(),
"dashboard_aggregation.backfill_max_days"
)
{
t
.
Fatalf
(
"Validate() expected backfill_max_days error, got: %v"
,
err
)
}
}
backend/internal/handler/admin/account_handler.go
View file @
b9b4db3d
...
...
@@ -44,6 +44,7 @@ type AccountHandler struct {
accountTestService
*
service
.
AccountTestService
concurrencyService
*
service
.
ConcurrencyService
crsSyncService
*
service
.
CRSSyncService
sessionLimitCache
service
.
SessionLimitCache
}
// NewAccountHandler creates a new admin account handler
...
...
@@ -58,6 +59,7 @@ func NewAccountHandler(
accountTestService
*
service
.
AccountTestService
,
concurrencyService
*
service
.
ConcurrencyService
,
crsSyncService
*
service
.
CRSSyncService
,
sessionLimitCache
service
.
SessionLimitCache
,
)
*
AccountHandler
{
return
&
AccountHandler
{
adminService
:
adminService
,
...
...
@@ -70,6 +72,7 @@ func NewAccountHandler(
accountTestService
:
accountTestService
,
concurrencyService
:
concurrencyService
,
crsSyncService
:
crsSyncService
,
sessionLimitCache
:
sessionLimitCache
,
}
}
...
...
@@ -84,6 +87,7 @@ type CreateAccountRequest struct {
ProxyID
*
int64
`json:"proxy_id"`
Concurrency
int
`json:"concurrency"`
Priority
int
`json:"priority"`
RateMultiplier
*
float64
`json:"rate_multiplier"`
GroupIDs
[]
int64
`json:"group_ids"`
ExpiresAt
*
int64
`json:"expires_at"`
AutoPauseOnExpired
*
bool
`json:"auto_pause_on_expired"`
...
...
@@ -101,6 +105,7 @@ type UpdateAccountRequest struct {
ProxyID
*
int64
`json:"proxy_id"`
Concurrency
*
int
`json:"concurrency"`
Priority
*
int
`json:"priority"`
RateMultiplier
*
float64
`json:"rate_multiplier"`
Status
string
`json:"status" binding:"omitempty,oneof=active inactive"`
GroupIDs
*
[]
int64
`json:"group_ids"`
ExpiresAt
*
int64
`json:"expires_at"`
...
...
@@ -115,6 +120,7 @@ type BulkUpdateAccountsRequest struct {
ProxyID
*
int64
`json:"proxy_id"`
Concurrency
*
int
`json:"concurrency"`
Priority
*
int
`json:"priority"`
RateMultiplier
*
float64
`json:"rate_multiplier"`
Status
string
`json:"status" binding:"omitempty,oneof=active inactive error"`
Schedulable
*
bool
`json:"schedulable"`
GroupIDs
*
[]
int64
`json:"group_ids"`
...
...
@@ -127,6 +133,9 @@ type BulkUpdateAccountsRequest struct {
type
AccountWithConcurrency
struct
{
*
dto
.
Account
CurrentConcurrency
int
`json:"current_concurrency"`
// 以下字段仅对 Anthropic OAuth/SetupToken 账号有效,且仅在启用相应功能时返回
CurrentWindowCost
*
float64
`json:"current_window_cost,omitempty"`
// 当前窗口费用
ActiveSessions
*
int
`json:"active_sessions,omitempty"`
// 当前活跃会话数
}
// List handles listing all accounts with pagination
...
...
@@ -161,15 +170,91 @@ func (h *AccountHandler) List(c *gin.Context) {
concurrencyCounts
=
make
(
map
[
int64
]
int
)
}
// 识别需要查询窗口费用和会话数的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
windowCostAccountIDs
:=
make
([]
int64
,
0
)
sessionLimitAccountIDs
:=
make
([]
int64
,
0
)
for
i
:=
range
accounts
{
acc
:=
&
accounts
[
i
]
if
acc
.
IsAnthropicOAuthOrSetupToken
()
{
if
acc
.
GetWindowCostLimit
()
>
0
{
windowCostAccountIDs
=
append
(
windowCostAccountIDs
,
acc
.
ID
)
}
if
acc
.
GetMaxSessions
()
>
0
{
sessionLimitAccountIDs
=
append
(
sessionLimitAccountIDs
,
acc
.
ID
)
}
}
}
// 并行获取窗口费用和活跃会话数
var
windowCosts
map
[
int64
]
float64
var
activeSessions
map
[
int64
]
int
// 获取活跃会话数(批量查询)
if
len
(
sessionLimitAccountIDs
)
>
0
&&
h
.
sessionLimitCache
!=
nil
{
activeSessions
,
_
=
h
.
sessionLimitCache
.
GetActiveSessionCountBatch
(
c
.
Request
.
Context
(),
sessionLimitAccountIDs
)
if
activeSessions
==
nil
{
activeSessions
=
make
(
map
[
int64
]
int
)
}
}
// 获取窗口费用(并行查询)
if
len
(
windowCostAccountIDs
)
>
0
{
windowCosts
=
make
(
map
[
int64
]
float64
)
var
mu
sync
.
Mutex
g
,
gctx
:=
errgroup
.
WithContext
(
c
.
Request
.
Context
())
g
.
SetLimit
(
10
)
// 限制并发数
for
i
:=
range
accounts
{
acc
:=
&
accounts
[
i
]
if
!
acc
.
IsAnthropicOAuthOrSetupToken
()
||
acc
.
GetWindowCostLimit
()
<=
0
{
continue
}
accCopy
:=
acc
// 闭包捕获
g
.
Go
(
func
()
error
{
var
startTime
time
.
Time
if
accCopy
.
SessionWindowStart
!=
nil
{
startTime
=
*
accCopy
.
SessionWindowStart
}
else
{
startTime
=
time
.
Now
()
.
Add
(
-
5
*
time
.
Hour
)
}
stats
,
err
:=
h
.
accountUsageService
.
GetAccountWindowStats
(
gctx
,
accCopy
.
ID
,
startTime
)
if
err
==
nil
&&
stats
!=
nil
{
mu
.
Lock
()
windowCosts
[
accCopy
.
ID
]
=
stats
.
StandardCost
// 使用标准费用
mu
.
Unlock
()
}
return
nil
// 不返回错误,允许部分失败
})
}
_
=
g
.
Wait
()
}
// Build response with concurrency info
result
:=
make
([]
AccountWithConcurrency
,
len
(
accounts
))
for
i
:=
range
accounts
{
result
[
i
]
=
AccountWithConcurrency
{
Account
:
dto
.
AccountFromService
(
&
accounts
[
i
]),
CurrentConcurrency
:
concurrencyCounts
[
accounts
[
i
]
.
ID
],
acc
:=
&
accounts
[
i
]
item
:=
AccountWithConcurrency
{
Account
:
dto
.
AccountFromService
(
acc
),
CurrentConcurrency
:
concurrencyCounts
[
acc
.
ID
],
}
// 添加窗口费用(仅当启用时)
if
windowCosts
!=
nil
{
if
cost
,
ok
:=
windowCosts
[
acc
.
ID
];
ok
{
item
.
CurrentWindowCost
=
&
cost
}
}
// 添加活跃会话数(仅当启用时)
if
activeSessions
!=
nil
{
if
count
,
ok
:=
activeSessions
[
acc
.
ID
];
ok
{
item
.
ActiveSessions
=
&
count
}
}
result
[
i
]
=
item
}
response
.
Paginated
(
c
,
result
,
total
,
page
,
pageSize
)
}
...
...
@@ -199,6 +284,10 @@ func (h *AccountHandler) Create(c *gin.Context) {
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
if
req
.
RateMultiplier
!=
nil
&&
*
req
.
RateMultiplier
<
0
{
response
.
BadRequest
(
c
,
"rate_multiplier must be >= 0"
)
return
}
// 确定是否跳过混合渠道检查
skipCheck
:=
req
.
ConfirmMixedChannelRisk
!=
nil
&&
*
req
.
ConfirmMixedChannelRisk
...
...
@@ -213,6 +302,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
ProxyID
:
req
.
ProxyID
,
Concurrency
:
req
.
Concurrency
,
Priority
:
req
.
Priority
,
RateMultiplier
:
req
.
RateMultiplier
,
GroupIDs
:
req
.
GroupIDs
,
ExpiresAt
:
req
.
ExpiresAt
,
AutoPauseOnExpired
:
req
.
AutoPauseOnExpired
,
...
...
@@ -258,6 +348,10 @@ func (h *AccountHandler) Update(c *gin.Context) {
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
if
req
.
RateMultiplier
!=
nil
&&
*
req
.
RateMultiplier
<
0
{
response
.
BadRequest
(
c
,
"rate_multiplier must be >= 0"
)
return
}
// 确定是否跳过混合渠道检查
skipCheck
:=
req
.
ConfirmMixedChannelRisk
!=
nil
&&
*
req
.
ConfirmMixedChannelRisk
...
...
@@ -271,6 +365,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
ProxyID
:
req
.
ProxyID
,
Concurrency
:
req
.
Concurrency
,
// 指针类型,nil 表示未提供
Priority
:
req
.
Priority
,
// 指针类型,nil 表示未提供
RateMultiplier
:
req
.
RateMultiplier
,
Status
:
req
.
Status
,
GroupIDs
:
req
.
GroupIDs
,
ExpiresAt
:
req
.
ExpiresAt
,
...
...
@@ -682,6 +777,10 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
if
req
.
RateMultiplier
!=
nil
&&
*
req
.
RateMultiplier
<
0
{
response
.
BadRequest
(
c
,
"rate_multiplier must be >= 0"
)
return
}
// 确定是否跳过混合渠道检查
skipCheck
:=
req
.
ConfirmMixedChannelRisk
!=
nil
&&
*
req
.
ConfirmMixedChannelRisk
...
...
@@ -690,6 +789,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
req
.
ProxyID
!=
nil
||
req
.
Concurrency
!=
nil
||
req
.
Priority
!=
nil
||
req
.
RateMultiplier
!=
nil
||
req
.
Status
!=
""
||
req
.
Schedulable
!=
nil
||
req
.
GroupIDs
!=
nil
||
...
...
@@ -707,6 +807,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
ProxyID
:
req
.
ProxyID
,
Concurrency
:
req
.
Concurrency
,
Priority
:
req
.
Priority
,
RateMultiplier
:
req
.
RateMultiplier
,
Status
:
req
.
Status
,
Schedulable
:
req
.
Schedulable
,
GroupIDs
:
req
.
GroupIDs
,
...
...
backend/internal/handler/admin/dashboard_handler.go
View file @
b9b4db3d
package
admin
import
(
"errors"
"strconv"
"time"
...
...
@@ -14,13 +15,15 @@ import (
// DashboardHandler handles admin dashboard statistics
type
DashboardHandler
struct
{
dashboardService
*
service
.
DashboardService
aggregationService
*
service
.
DashboardAggregationService
startTime
time
.
Time
// Server start time for uptime calculation
}
// NewDashboardHandler creates a new admin dashboard handler
func
NewDashboardHandler
(
dashboardService
*
service
.
DashboardService
)
*
DashboardHandler
{
func
NewDashboardHandler
(
dashboardService
*
service
.
DashboardService
,
aggregationService
*
service
.
DashboardAggregationService
)
*
DashboardHandler
{
return
&
DashboardHandler
{
dashboardService
:
dashboardService
,
aggregationService
:
aggregationService
,
startTime
:
time
.
Now
(),
}
}
...
...
@@ -114,6 +117,58 @@ func (h *DashboardHandler) GetStats(c *gin.Context) {
// 性能指标
"rpm"
:
stats
.
Rpm
,
"tpm"
:
stats
.
Tpm
,
// 预聚合新鲜度
"hourly_active_users"
:
stats
.
HourlyActiveUsers
,
"stats_updated_at"
:
stats
.
StatsUpdatedAt
,
"stats_stale"
:
stats
.
StatsStale
,
})
}
type
DashboardAggregationBackfillRequest
struct
{
Start
string
`json:"start"`
End
string
`json:"end"`
}
// BackfillAggregation handles triggering aggregation backfill
// POST /api/v1/admin/dashboard/aggregation/backfill
func
(
h
*
DashboardHandler
)
BackfillAggregation
(
c
*
gin
.
Context
)
{
if
h
.
aggregationService
==
nil
{
response
.
InternalError
(
c
,
"Aggregation service not available"
)
return
}
var
req
DashboardAggregationBackfillRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request body"
)
return
}
start
,
err
:=
time
.
Parse
(
time
.
RFC3339
,
req
.
Start
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid start time"
)
return
}
end
,
err
:=
time
.
Parse
(
time
.
RFC3339
,
req
.
End
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid end time"
)
return
}
if
err
:=
h
.
aggregationService
.
TriggerBackfill
(
start
,
end
);
err
!=
nil
{
if
errors
.
Is
(
err
,
service
.
ErrDashboardBackfillDisabled
)
{
response
.
Forbidden
(
c
,
"Backfill is disabled"
)
return
}
if
errors
.
Is
(
err
,
service
.
ErrDashboardBackfillTooLarge
)
{
response
.
BadRequest
(
c
,
"Backfill range too large"
)
return
}
response
.
InternalError
(
c
,
"Failed to trigger backfill"
)
return
}
response
.
Success
(
c
,
gin
.
H
{
"status"
:
"accepted"
,
})
}
...
...
@@ -131,13 +186,16 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
// GetUsageTrend handles getting usage trend data
// GET /api/v1/admin/dashboard/trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id
, model, account_id, group_id, stream
func
(
h
*
DashboardHandler
)
GetUsageTrend
(
c
*
gin
.
Context
)
{
startTime
,
endTime
:=
parseTimeRange
(
c
)
granularity
:=
c
.
DefaultQuery
(
"granularity"
,
"day"
)
// Parse optional filter params
var
userID
,
apiKeyID
int64
var
userID
,
apiKeyID
,
accountID
,
groupID
int64
var
model
string
var
stream
*
bool
if
userIDStr
:=
c
.
Query
(
"user_id"
);
userIDStr
!=
""
{
if
id
,
err
:=
strconv
.
ParseInt
(
userIDStr
,
10
,
64
);
err
==
nil
{
userID
=
id
...
...
@@ -148,8 +206,26 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
apiKeyID
=
id
}
}
if
accountIDStr
:=
c
.
Query
(
"account_id"
);
accountIDStr
!=
""
{
if
id
,
err
:=
strconv
.
ParseInt
(
accountIDStr
,
10
,
64
);
err
==
nil
{
accountID
=
id
}
}
if
groupIDStr
:=
c
.
Query
(
"group_id"
);
groupIDStr
!=
""
{
if
id
,
err
:=
strconv
.
ParseInt
(
groupIDStr
,
10
,
64
);
err
==
nil
{
groupID
=
id
}
}
if
modelStr
:=
c
.
Query
(
"model"
);
modelStr
!=
""
{
model
=
modelStr
}
if
streamStr
:=
c
.
Query
(
"stream"
);
streamStr
!=
""
{
if
streamVal
,
err
:=
strconv
.
ParseBool
(
streamStr
);
err
==
nil
{
stream
=
&
streamVal
}
}
trend
,
err
:=
h
.
dashboardService
.
GetUsageTrendWithFilters
(
c
.
Request
.
Context
(),
startTime
,
endTime
,
granularity
,
userID
,
apiKeyID
)
trend
,
err
:=
h
.
dashboardService
.
GetUsageTrendWithFilters
(
c
.
Request
.
Context
(),
startTime
,
endTime
,
granularity
,
userID
,
apiKeyID
,
accountID
,
groupID
,
model
,
stream
)
if
err
!=
nil
{
response
.
Error
(
c
,
500
,
"Failed to get usage trend"
)
return
...
...
@@ -165,12 +241,14 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
// GetModelStats handles getting model usage statistics
// GET /api/v1/admin/dashboard/models
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id
, account_id, group_id, stream
func
(
h
*
DashboardHandler
)
GetModelStats
(
c
*
gin
.
Context
)
{
startTime
,
endTime
:=
parseTimeRange
(
c
)
// Parse optional filter params
var
userID
,
apiKeyID
int64
var
userID
,
apiKeyID
,
accountID
,
groupID
int64
var
stream
*
bool
if
userIDStr
:=
c
.
Query
(
"user_id"
);
userIDStr
!=
""
{
if
id
,
err
:=
strconv
.
ParseInt
(
userIDStr
,
10
,
64
);
err
==
nil
{
userID
=
id
...
...
@@ -181,8 +259,23 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
apiKeyID
=
id
}
}
if
accountIDStr
:=
c
.
Query
(
"account_id"
);
accountIDStr
!=
""
{
if
id
,
err
:=
strconv
.
ParseInt
(
accountIDStr
,
10
,
64
);
err
==
nil
{
accountID
=
id
}
}
if
groupIDStr
:=
c
.
Query
(
"group_id"
);
groupIDStr
!=
""
{
if
id
,
err
:=
strconv
.
ParseInt
(
groupIDStr
,
10
,
64
);
err
==
nil
{
groupID
=
id
}
}
if
streamStr
:=
c
.
Query
(
"stream"
);
streamStr
!=
""
{
if
streamVal
,
err
:=
strconv
.
ParseBool
(
streamStr
);
err
==
nil
{
stream
=
&
streamVal
}
}
stats
,
err
:=
h
.
dashboardService
.
GetModelStatsWithFilters
(
c
.
Request
.
Context
(),
startTime
,
endTime
,
userID
,
apiKeyID
)
stats
,
err
:=
h
.
dashboardService
.
GetModelStatsWithFilters
(
c
.
Request
.
Context
(),
startTime
,
endTime
,
userID
,
apiKeyID
,
accountID
,
groupID
,
stream
)
if
err
!=
nil
{
response
.
Error
(
c
,
500
,
"Failed to get model statistics"
)
return
...
...
backend/internal/handler/admin/group_handler.go
View file @
b9b4db3d
...
...
@@ -40,6 +40,9 @@ type CreateGroupRequest struct {
ImagePrice4K
*
float64
`json:"image_price_4k"`
ClaudeCodeOnly
bool
`json:"claude_code_only"`
FallbackGroupID
*
int64
`json:"fallback_group_id"`
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting
map
[
string
][]
int64
`json:"model_routing"`
ModelRoutingEnabled
bool
`json:"model_routing_enabled"`
}
// UpdateGroupRequest represents update group request
...
...
@@ -60,6 +63,9 @@ type UpdateGroupRequest struct {
ImagePrice4K
*
float64
`json:"image_price_4k"`
ClaudeCodeOnly
*
bool
`json:"claude_code_only"`
FallbackGroupID
*
int64
`json:"fallback_group_id"`
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting
map
[
string
][]
int64
`json:"model_routing"`
ModelRoutingEnabled
*
bool
`json:"model_routing_enabled"`
}
// List handles listing all groups with pagination
...
...
@@ -163,6 +169,8 @@ func (h *GroupHandler) Create(c *gin.Context) {
ImagePrice4K
:
req
.
ImagePrice4K
,
ClaudeCodeOnly
:
req
.
ClaudeCodeOnly
,
FallbackGroupID
:
req
.
FallbackGroupID
,
ModelRouting
:
req
.
ModelRouting
,
ModelRoutingEnabled
:
req
.
ModelRoutingEnabled
,
})
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
...
...
@@ -203,6 +211,8 @@ func (h *GroupHandler) Update(c *gin.Context) {
ImagePrice4K
:
req
.
ImagePrice4K
,
ClaudeCodeOnly
:
req
.
ClaudeCodeOnly
,
FallbackGroupID
:
req
.
FallbackGroupID
,
ModelRouting
:
req
.
ModelRouting
,
ModelRoutingEnabled
:
req
.
ModelRoutingEnabled
,
})
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
...
...
backend/internal/handler/admin/ops_alerts_handler.go
0 → 100644
View file @
b9b4db3d
package
admin
import
(
"encoding/json"
"fmt"
"math"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
)
var
validOpsAlertMetricTypes
=
[]
string
{
"success_rate"
,
"error_rate"
,
"upstream_error_rate"
,
"cpu_usage_percent"
,
"memory_usage_percent"
,
"concurrency_queue_depth"
,
}
var
validOpsAlertMetricTypeSet
=
func
()
map
[
string
]
struct
{}
{
set
:=
make
(
map
[
string
]
struct
{},
len
(
validOpsAlertMetricTypes
))
for
_
,
v
:=
range
validOpsAlertMetricTypes
{
set
[
v
]
=
struct
{}{}
}
return
set
}()
var
validOpsAlertOperators
=
[]
string
{
">"
,
"<"
,
">="
,
"<="
,
"=="
,
"!="
}
var
validOpsAlertOperatorSet
=
func
()
map
[
string
]
struct
{}
{
set
:=
make
(
map
[
string
]
struct
{},
len
(
validOpsAlertOperators
))
for
_
,
v
:=
range
validOpsAlertOperators
{
set
[
v
]
=
struct
{}{}
}
return
set
}()
var
validOpsAlertSeverities
=
[]
string
{
"P0"
,
"P1"
,
"P2"
,
"P3"
}
var
validOpsAlertSeveritySet
=
func
()
map
[
string
]
struct
{}
{
set
:=
make
(
map
[
string
]
struct
{},
len
(
validOpsAlertSeverities
))
for
_
,
v
:=
range
validOpsAlertSeverities
{
set
[
v
]
=
struct
{}{}
}
return
set
}()
type
opsAlertRuleValidatedInput
struct
{
Name
string
MetricType
string
Operator
string
Threshold
float64
Severity
string
WindowMinutes
int
SustainedMinutes
int
CooldownMinutes
int
Enabled
bool
NotifyEmail
bool
WindowProvided
bool
SustainedProvided
bool
CooldownProvided
bool
SeverityProvided
bool
EnabledProvided
bool
NotifyProvided
bool
}
func
isPercentOrRateMetric
(
metricType
string
)
bool
{
switch
metricType
{
case
"success_rate"
,
"error_rate"
,
"upstream_error_rate"
,
"cpu_usage_percent"
,
"memory_usage_percent"
:
return
true
default
:
return
false
}
}
func
validateOpsAlertRulePayload
(
raw
map
[
string
]
json
.
RawMessage
)
(
*
opsAlertRuleValidatedInput
,
error
)
{
if
raw
==
nil
{
return
nil
,
fmt
.
Errorf
(
"invalid request body"
)
}
requiredFields
:=
[]
string
{
"name"
,
"metric_type"
,
"operator"
,
"threshold"
}
for
_
,
field
:=
range
requiredFields
{
if
_
,
ok
:=
raw
[
field
];
!
ok
{
return
nil
,
fmt
.
Errorf
(
"%s is required"
,
field
)
}
}
var
name
string
if
err
:=
json
.
Unmarshal
(
raw
[
"name"
],
&
name
);
err
!=
nil
||
strings
.
TrimSpace
(
name
)
==
""
{
return
nil
,
fmt
.
Errorf
(
"name is required"
)
}
name
=
strings
.
TrimSpace
(
name
)
var
metricType
string
if
err
:=
json
.
Unmarshal
(
raw
[
"metric_type"
],
&
metricType
);
err
!=
nil
||
strings
.
TrimSpace
(
metricType
)
==
""
{
return
nil
,
fmt
.
Errorf
(
"metric_type is required"
)
}
metricType
=
strings
.
TrimSpace
(
metricType
)
if
_
,
ok
:=
validOpsAlertMetricTypeSet
[
metricType
];
!
ok
{
return
nil
,
fmt
.
Errorf
(
"metric_type must be one of: %s"
,
strings
.
Join
(
validOpsAlertMetricTypes
,
", "
))
}
var
operator
string
if
err
:=
json
.
Unmarshal
(
raw
[
"operator"
],
&
operator
);
err
!=
nil
||
strings
.
TrimSpace
(
operator
)
==
""
{
return
nil
,
fmt
.
Errorf
(
"operator is required"
)
}
operator
=
strings
.
TrimSpace
(
operator
)
if
_
,
ok
:=
validOpsAlertOperatorSet
[
operator
];
!
ok
{
return
nil
,
fmt
.
Errorf
(
"operator must be one of: %s"
,
strings
.
Join
(
validOpsAlertOperators
,
", "
))
}
var
threshold
float64
if
err
:=
json
.
Unmarshal
(
raw
[
"threshold"
],
&
threshold
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"threshold must be a number"
)
}
if
math
.
IsNaN
(
threshold
)
||
math
.
IsInf
(
threshold
,
0
)
{
return
nil
,
fmt
.
Errorf
(
"threshold must be a finite number"
)
}
if
isPercentOrRateMetric
(
metricType
)
{
if
threshold
<
0
||
threshold
>
100
{
return
nil
,
fmt
.
Errorf
(
"threshold must be between 0 and 100 for metric_type %s"
,
metricType
)
}
}
else
if
threshold
<
0
{
return
nil
,
fmt
.
Errorf
(
"threshold must be >= 0"
)
}
validated
:=
&
opsAlertRuleValidatedInput
{
Name
:
name
,
MetricType
:
metricType
,
Operator
:
operator
,
Threshold
:
threshold
,
}
if
v
,
ok
:=
raw
[
"severity"
];
ok
{
validated
.
SeverityProvided
=
true
var
sev
string
if
err
:=
json
.
Unmarshal
(
v
,
&
sev
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"severity must be a string"
)
}
sev
=
strings
.
ToUpper
(
strings
.
TrimSpace
(
sev
))
if
sev
!=
""
{
if
_
,
ok
:=
validOpsAlertSeveritySet
[
sev
];
!
ok
{
return
nil
,
fmt
.
Errorf
(
"severity must be one of: %s"
,
strings
.
Join
(
validOpsAlertSeverities
,
", "
))
}
validated
.
Severity
=
sev
}
}
if
validated
.
Severity
==
""
{
validated
.
Severity
=
"P2"
}
if
v
,
ok
:=
raw
[
"enabled"
];
ok
{
validated
.
EnabledProvided
=
true
if
err
:=
json
.
Unmarshal
(
v
,
&
validated
.
Enabled
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"enabled must be a boolean"
)
}
}
else
{
validated
.
Enabled
=
true
}
if
v
,
ok
:=
raw
[
"notify_email"
];
ok
{
validated
.
NotifyProvided
=
true
if
err
:=
json
.
Unmarshal
(
v
,
&
validated
.
NotifyEmail
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"notify_email must be a boolean"
)
}
}
else
{
validated
.
NotifyEmail
=
true
}
if
v
,
ok
:=
raw
[
"window_minutes"
];
ok
{
validated
.
WindowProvided
=
true
if
err
:=
json
.
Unmarshal
(
v
,
&
validated
.
WindowMinutes
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"window_minutes must be an integer"
)
}
switch
validated
.
WindowMinutes
{
case
1
,
5
,
60
:
default
:
return
nil
,
fmt
.
Errorf
(
"window_minutes must be one of: 1, 5, 60"
)
}
}
else
{
validated
.
WindowMinutes
=
1
}
if
v
,
ok
:=
raw
[
"sustained_minutes"
];
ok
{
validated
.
SustainedProvided
=
true
if
err
:=
json
.
Unmarshal
(
v
,
&
validated
.
SustainedMinutes
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"sustained_minutes must be an integer"
)
}
if
validated
.
SustainedMinutes
<
1
||
validated
.
SustainedMinutes
>
1440
{
return
nil
,
fmt
.
Errorf
(
"sustained_minutes must be between 1 and 1440"
)
}
}
else
{
validated
.
SustainedMinutes
=
1
}
if
v
,
ok
:=
raw
[
"cooldown_minutes"
];
ok
{
validated
.
CooldownProvided
=
true
if
err
:=
json
.
Unmarshal
(
v
,
&
validated
.
CooldownMinutes
);
err
!=
nil
{
return
nil
,
fmt
.
Errorf
(
"cooldown_minutes must be an integer"
)
}
if
validated
.
CooldownMinutes
<
0
||
validated
.
CooldownMinutes
>
1440
{
return
nil
,
fmt
.
Errorf
(
"cooldown_minutes must be between 0 and 1440"
)
}
}
else
{
validated
.
CooldownMinutes
=
0
}
return
validated
,
nil
}
// ListAlertRules returns all ops alert rules.
// GET /api/v1/admin/ops/alert-rules
func
(
h
*
OpsHandler
)
ListAlertRules
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
rules
,
err
:=
h
.
opsService
.
ListAlertRules
(
c
.
Request
.
Context
())
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
rules
)
}
// CreateAlertRule creates an ops alert rule.
// POST /api/v1/admin/ops/alert-rules
func
(
h
*
OpsHandler
)
CreateAlertRule
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
var
raw
map
[
string
]
json
.
RawMessage
if
err
:=
c
.
ShouldBindBodyWith
(
&
raw
,
binding
.
JSON
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request body"
)
return
}
validated
,
err
:=
validateOpsAlertRulePayload
(
raw
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
err
.
Error
())
return
}
var
rule
service
.
OpsAlertRule
if
err
:=
c
.
ShouldBindBodyWith
(
&
rule
,
binding
.
JSON
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request body"
)
return
}
rule
.
Name
=
validated
.
Name
rule
.
MetricType
=
validated
.
MetricType
rule
.
Operator
=
validated
.
Operator
rule
.
Threshold
=
validated
.
Threshold
rule
.
WindowMinutes
=
validated
.
WindowMinutes
rule
.
SustainedMinutes
=
validated
.
SustainedMinutes
rule
.
CooldownMinutes
=
validated
.
CooldownMinutes
rule
.
Severity
=
validated
.
Severity
rule
.
Enabled
=
validated
.
Enabled
rule
.
NotifyEmail
=
validated
.
NotifyEmail
created
,
err
:=
h
.
opsService
.
CreateAlertRule
(
c
.
Request
.
Context
(),
&
rule
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
created
)
}
// UpdateAlertRule updates an existing ops alert rule.
// PUT /api/v1/admin/ops/alert-rules/:id
func
(
h
*
OpsHandler
)
UpdateAlertRule
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
id
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid rule ID"
)
return
}
var
raw
map
[
string
]
json
.
RawMessage
if
err
:=
c
.
ShouldBindBodyWith
(
&
raw
,
binding
.
JSON
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request body"
)
return
}
validated
,
err
:=
validateOpsAlertRulePayload
(
raw
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
err
.
Error
())
return
}
var
rule
service
.
OpsAlertRule
if
err
:=
c
.
ShouldBindBodyWith
(
&
rule
,
binding
.
JSON
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request body"
)
return
}
rule
.
ID
=
id
rule
.
Name
=
validated
.
Name
rule
.
MetricType
=
validated
.
MetricType
rule
.
Operator
=
validated
.
Operator
rule
.
Threshold
=
validated
.
Threshold
rule
.
WindowMinutes
=
validated
.
WindowMinutes
rule
.
SustainedMinutes
=
validated
.
SustainedMinutes
rule
.
CooldownMinutes
=
validated
.
CooldownMinutes
rule
.
Severity
=
validated
.
Severity
rule
.
Enabled
=
validated
.
Enabled
rule
.
NotifyEmail
=
validated
.
NotifyEmail
updated
,
err
:=
h
.
opsService
.
UpdateAlertRule
(
c
.
Request
.
Context
(),
&
rule
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
updated
)
}
// DeleteAlertRule deletes an ops alert rule.
// DELETE /api/v1/admin/ops/alert-rules/:id
func
(
h
*
OpsHandler
)
DeleteAlertRule
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
id
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid rule ID"
)
return
}
if
err
:=
h
.
opsService
.
DeleteAlertRule
(
c
.
Request
.
Context
(),
id
);
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
gin
.
H
{
"deleted"
:
true
})
}
// GetAlertEvent returns a single ops alert event.
// GET /api/v1/admin/ops/alert-events/:id
func
(
h
*
OpsHandler
)
GetAlertEvent
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
id
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid event ID"
)
return
}
ev
,
err
:=
h
.
opsService
.
GetAlertEventByID
(
c
.
Request
.
Context
(),
id
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
ev
)
}
// UpdateAlertEventStatus updates an ops alert event status.
// PUT /api/v1/admin/ops/alert-events/:id/status
func
(
h
*
OpsHandler
)
UpdateAlertEventStatus
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
id
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid event ID"
)
return
}
var
payload
struct
{
Status
string
`json:"status"`
}
if
err
:=
c
.
ShouldBindJSON
(
&
payload
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request body"
)
return
}
payload
.
Status
=
strings
.
TrimSpace
(
payload
.
Status
)
if
payload
.
Status
==
""
{
response
.
BadRequest
(
c
,
"Invalid status"
)
return
}
if
payload
.
Status
!=
service
.
OpsAlertStatusResolved
&&
payload
.
Status
!=
service
.
OpsAlertStatusManualResolved
{
response
.
BadRequest
(
c
,
"Invalid status"
)
return
}
var
resolvedAt
*
time
.
Time
if
payload
.
Status
==
service
.
OpsAlertStatusResolved
||
payload
.
Status
==
service
.
OpsAlertStatusManualResolved
{
now
:=
time
.
Now
()
.
UTC
()
resolvedAt
=
&
now
}
if
err
:=
h
.
opsService
.
UpdateAlertEventStatus
(
c
.
Request
.
Context
(),
id
,
payload
.
Status
,
resolvedAt
);
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
gin
.
H
{
"updated"
:
true
})
}
// ListAlertEvents lists recent ops alert events.
// GET /api/v1/admin/ops/alert-events
// CreateAlertSilence creates a scoped silence for ops alerts.
// POST /api/v1/admin/ops/alert-silences
func
(
h
*
OpsHandler
)
CreateAlertSilence
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
var
payload
struct
{
RuleID
int64
`json:"rule_id"`
Platform
string
`json:"platform"`
GroupID
*
int64
`json:"group_id"`
Region
*
string
`json:"region"`
Until
string
`json:"until"`
Reason
string
`json:"reason"`
}
if
err
:=
c
.
ShouldBindJSON
(
&
payload
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request body"
)
return
}
until
,
err
:=
time
.
Parse
(
time
.
RFC3339
,
strings
.
TrimSpace
(
payload
.
Until
))
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid until"
)
return
}
createdBy
:=
(
*
int64
)(
nil
)
if
subject
,
ok
:=
middleware
.
GetAuthSubjectFromContext
(
c
);
ok
{
uid
:=
subject
.
UserID
createdBy
=
&
uid
}
silence
:=
&
service
.
OpsAlertSilence
{
RuleID
:
payload
.
RuleID
,
Platform
:
strings
.
TrimSpace
(
payload
.
Platform
),
GroupID
:
payload
.
GroupID
,
Region
:
payload
.
Region
,
Until
:
until
,
Reason
:
strings
.
TrimSpace
(
payload
.
Reason
),
CreatedBy
:
createdBy
,
}
created
,
err
:=
h
.
opsService
.
CreateAlertSilence
(
c
.
Request
.
Context
(),
silence
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
created
)
}
func
(
h
*
OpsHandler
)
ListAlertEvents
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
limit
:=
20
if
raw
:=
strings
.
TrimSpace
(
c
.
Query
(
"limit"
));
raw
!=
""
{
n
,
err
:=
strconv
.
Atoi
(
raw
)
if
err
!=
nil
||
n
<=
0
{
response
.
BadRequest
(
c
,
"Invalid limit"
)
return
}
limit
=
n
}
filter
:=
&
service
.
OpsAlertEventFilter
{
Limit
:
limit
,
Status
:
strings
.
TrimSpace
(
c
.
Query
(
"status"
)),
Severity
:
strings
.
TrimSpace
(
c
.
Query
(
"severity"
)),
}
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"email_sent"
));
v
!=
""
{
vv
:=
strings
.
ToLower
(
v
)
switch
vv
{
case
"true"
,
"1"
:
b
:=
true
filter
.
EmailSent
=
&
b
case
"false"
,
"0"
:
b
:=
false
filter
.
EmailSent
=
&
b
default
:
response
.
BadRequest
(
c
,
"Invalid email_sent"
)
return
}
}
// Cursor pagination: both params must be provided together.
rawTS
:=
strings
.
TrimSpace
(
c
.
Query
(
"before_fired_at"
))
rawID
:=
strings
.
TrimSpace
(
c
.
Query
(
"before_id"
))
if
(
rawTS
==
""
)
!=
(
rawID
==
""
)
{
response
.
BadRequest
(
c
,
"before_fired_at and before_id must be provided together"
)
return
}
if
rawTS
!=
""
{
ts
,
err
:=
time
.
Parse
(
time
.
RFC3339Nano
,
rawTS
)
if
err
!=
nil
{
if
t2
,
err2
:=
time
.
Parse
(
time
.
RFC3339
,
rawTS
);
err2
==
nil
{
ts
=
t2
}
else
{
response
.
BadRequest
(
c
,
"Invalid before_fired_at"
)
return
}
}
filter
.
BeforeFiredAt
=
&
ts
}
if
rawID
!=
""
{
id
,
err
:=
strconv
.
ParseInt
(
rawID
,
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid before_id"
)
return
}
filter
.
BeforeID
=
&
id
}
// Optional global filter support (platform/group/time range).
if
platform
:=
strings
.
TrimSpace
(
c
.
Query
(
"platform"
));
platform
!=
""
{
filter
.
Platform
=
platform
}
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"group_id"
));
v
!=
""
{
id
,
err
:=
strconv
.
ParseInt
(
v
,
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid group_id"
)
return
}
filter
.
GroupID
=
&
id
}
if
startTime
,
endTime
,
err
:=
parseOpsTimeRange
(
c
,
"24h"
);
err
==
nil
{
// Only apply when explicitly provided to avoid surprising default narrowing.
if
strings
.
TrimSpace
(
c
.
Query
(
"start_time"
))
!=
""
||
strings
.
TrimSpace
(
c
.
Query
(
"end_time"
))
!=
""
||
strings
.
TrimSpace
(
c
.
Query
(
"time_range"
))
!=
""
{
filter
.
StartTime
=
&
startTime
filter
.
EndTime
=
&
endTime
}
}
else
{
response
.
BadRequest
(
c
,
err
.
Error
())
return
}
events
,
err
:=
h
.
opsService
.
ListAlertEvents
(
c
.
Request
.
Context
(),
filter
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
events
)
}
backend/internal/handler/admin/ops_dashboard_handler.go
0 → 100644
View file @
b9b4db3d
package
admin
import
(
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// GetDashboardOverview returns vNext ops dashboard overview (raw path).
// GET /api/v1/admin/ops/dashboard/overview
func
(
h
*
OpsHandler
)
GetDashboardOverview
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
startTime
,
endTime
,
err
:=
parseOpsTimeRange
(
c
,
"1h"
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
err
.
Error
())
return
}
filter
:=
&
service
.
OpsDashboardFilter
{
StartTime
:
startTime
,
EndTime
:
endTime
,
Platform
:
strings
.
TrimSpace
(
c
.
Query
(
"platform"
)),
QueryMode
:
parseOpsQueryMode
(
c
),
}
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"group_id"
));
v
!=
""
{
id
,
err
:=
strconv
.
ParseInt
(
v
,
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid group_id"
)
return
}
filter
.
GroupID
=
&
id
}
data
,
err
:=
h
.
opsService
.
GetDashboardOverview
(
c
.
Request
.
Context
(),
filter
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
data
)
}
// GetDashboardThroughputTrend returns throughput time series (raw path).
// GET /api/v1/admin/ops/dashboard/throughput-trend
func
(
h
*
OpsHandler
)
GetDashboardThroughputTrend
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
startTime
,
endTime
,
err
:=
parseOpsTimeRange
(
c
,
"1h"
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
err
.
Error
())
return
}
filter
:=
&
service
.
OpsDashboardFilter
{
StartTime
:
startTime
,
EndTime
:
endTime
,
Platform
:
strings
.
TrimSpace
(
c
.
Query
(
"platform"
)),
QueryMode
:
parseOpsQueryMode
(
c
),
}
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"group_id"
));
v
!=
""
{
id
,
err
:=
strconv
.
ParseInt
(
v
,
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid group_id"
)
return
}
filter
.
GroupID
=
&
id
}
bucketSeconds
:=
pickThroughputBucketSeconds
(
endTime
.
Sub
(
startTime
))
data
,
err
:=
h
.
opsService
.
GetThroughputTrend
(
c
.
Request
.
Context
(),
filter
,
bucketSeconds
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
data
)
}
// GetDashboardLatencyHistogram returns the latency distribution histogram (success requests).
// GET /api/v1/admin/ops/dashboard/latency-histogram
func
(
h
*
OpsHandler
)
GetDashboardLatencyHistogram
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
startTime
,
endTime
,
err
:=
parseOpsTimeRange
(
c
,
"1h"
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
err
.
Error
())
return
}
filter
:=
&
service
.
OpsDashboardFilter
{
StartTime
:
startTime
,
EndTime
:
endTime
,
Platform
:
strings
.
TrimSpace
(
c
.
Query
(
"platform"
)),
QueryMode
:
parseOpsQueryMode
(
c
),
}
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"group_id"
));
v
!=
""
{
id
,
err
:=
strconv
.
ParseInt
(
v
,
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid group_id"
)
return
}
filter
.
GroupID
=
&
id
}
data
,
err
:=
h
.
opsService
.
GetLatencyHistogram
(
c
.
Request
.
Context
(),
filter
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
data
)
}
// GetDashboardErrorTrend returns error counts time series (raw path).
// GET /api/v1/admin/ops/dashboard/error-trend
func
(
h
*
OpsHandler
)
GetDashboardErrorTrend
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
startTime
,
endTime
,
err
:=
parseOpsTimeRange
(
c
,
"1h"
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
err
.
Error
())
return
}
filter
:=
&
service
.
OpsDashboardFilter
{
StartTime
:
startTime
,
EndTime
:
endTime
,
Platform
:
strings
.
TrimSpace
(
c
.
Query
(
"platform"
)),
QueryMode
:
parseOpsQueryMode
(
c
),
}
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"group_id"
));
v
!=
""
{
id
,
err
:=
strconv
.
ParseInt
(
v
,
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid group_id"
)
return
}
filter
.
GroupID
=
&
id
}
bucketSeconds
:=
pickThroughputBucketSeconds
(
endTime
.
Sub
(
startTime
))
data
,
err
:=
h
.
opsService
.
GetErrorTrend
(
c
.
Request
.
Context
(),
filter
,
bucketSeconds
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
data
)
}
// GetDashboardErrorDistribution returns error distribution by status code (raw path).
// GET /api/v1/admin/ops/dashboard/error-distribution
func
(
h
*
OpsHandler
)
GetDashboardErrorDistribution
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
startTime
,
endTime
,
err
:=
parseOpsTimeRange
(
c
,
"1h"
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
err
.
Error
())
return
}
filter
:=
&
service
.
OpsDashboardFilter
{
StartTime
:
startTime
,
EndTime
:
endTime
,
Platform
:
strings
.
TrimSpace
(
c
.
Query
(
"platform"
)),
QueryMode
:
parseOpsQueryMode
(
c
),
}
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"group_id"
));
v
!=
""
{
id
,
err
:=
strconv
.
ParseInt
(
v
,
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid group_id"
)
return
}
filter
.
GroupID
=
&
id
}
data
,
err
:=
h
.
opsService
.
GetErrorDistribution
(
c
.
Request
.
Context
(),
filter
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
data
)
}
func
pickThroughputBucketSeconds
(
window
time
.
Duration
)
int
{
// Keep buckets predictable and avoid huge responses.
switch
{
case
window
<=
2
*
time
.
Hour
:
return
60
case
window
<=
24
*
time
.
Hour
:
return
300
default
:
return
3600
}
}
func
parseOpsQueryMode
(
c
*
gin
.
Context
)
service
.
OpsQueryMode
{
if
c
==
nil
{
return
""
}
raw
:=
strings
.
TrimSpace
(
c
.
Query
(
"mode"
))
if
raw
==
""
{
// Empty means "use server default" (DB setting ops_query_mode_default).
return
""
}
return
service
.
ParseOpsQueryMode
(
raw
)
}
backend/internal/handler/admin/ops_handler.go
0 → 100644
View file @
b9b4db3d
package
admin
import
(
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type
OpsHandler
struct
{
opsService
*
service
.
OpsService
}
// GetErrorLogByID returns ops error log detail.
// GET /api/v1/admin/ops/errors/:id
func
(
h
*
OpsHandler
)
GetErrorLogByID
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
idStr
:=
strings
.
TrimSpace
(
c
.
Param
(
"id"
))
id
,
err
:=
strconv
.
ParseInt
(
idStr
,
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid error id"
)
return
}
detail
,
err
:=
h
.
opsService
.
GetErrorLogByID
(
c
.
Request
.
Context
(),
id
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
detail
)
}
const
(
opsListViewErrors
=
"errors"
opsListViewExcluded
=
"excluded"
opsListViewAll
=
"all"
)
func
parseOpsViewParam
(
c
*
gin
.
Context
)
string
{
if
c
==
nil
{
return
""
}
v
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
c
.
Query
(
"view"
)))
switch
v
{
case
""
,
opsListViewErrors
:
return
opsListViewErrors
case
opsListViewExcluded
:
return
opsListViewExcluded
case
opsListViewAll
:
return
opsListViewAll
default
:
return
opsListViewErrors
}
}
func
NewOpsHandler
(
opsService
*
service
.
OpsService
)
*
OpsHandler
{
return
&
OpsHandler
{
opsService
:
opsService
}
}
// GetErrorLogs lists ops error logs.
// GET /api/v1/admin/ops/errors
func
(
h
*
OpsHandler
)
GetErrorLogs
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
page
,
pageSize
:=
response
.
ParsePagination
(
c
)
// Ops list can be larger than standard admin tables.
if
pageSize
>
500
{
pageSize
=
500
}
startTime
,
endTime
,
err
:=
parseOpsTimeRange
(
c
,
"1h"
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
err
.
Error
())
return
}
filter
:=
&
service
.
OpsErrorLogFilter
{
Page
:
page
,
PageSize
:
pageSize
}
if
!
startTime
.
IsZero
()
{
filter
.
StartTime
=
&
startTime
}
if
!
endTime
.
IsZero
()
{
filter
.
EndTime
=
&
endTime
}
filter
.
View
=
parseOpsViewParam
(
c
)
filter
.
Phase
=
strings
.
TrimSpace
(
c
.
Query
(
"phase"
))
filter
.
Owner
=
strings
.
TrimSpace
(
c
.
Query
(
"error_owner"
))
filter
.
Source
=
strings
.
TrimSpace
(
c
.
Query
(
"error_source"
))
filter
.
Query
=
strings
.
TrimSpace
(
c
.
Query
(
"q"
))
filter
.
UserQuery
=
strings
.
TrimSpace
(
c
.
Query
(
"user_query"
))
// Force request errors: client-visible status >= 400.
// buildOpsErrorLogsWhere already applies this for non-upstream phase.
if
strings
.
EqualFold
(
strings
.
TrimSpace
(
filter
.
Phase
),
"upstream"
)
{
filter
.
Phase
=
""
}
if
platform
:=
strings
.
TrimSpace
(
c
.
Query
(
"platform"
));
platform
!=
""
{
filter
.
Platform
=
platform
}
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"group_id"
));
v
!=
""
{
id
,
err
:=
strconv
.
ParseInt
(
v
,
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid group_id"
)
return
}
filter
.
GroupID
=
&
id
}
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"account_id"
));
v
!=
""
{
id
,
err
:=
strconv
.
ParseInt
(
v
,
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid account_id"
)
return
}
filter
.
AccountID
=
&
id
}
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"resolved"
));
v
!=
""
{
switch
strings
.
ToLower
(
v
)
{
case
"1"
,
"true"
,
"yes"
:
b
:=
true
filter
.
Resolved
=
&
b
case
"0"
,
"false"
,
"no"
:
b
:=
false
filter
.
Resolved
=
&
b
default
:
response
.
BadRequest
(
c
,
"Invalid resolved"
)
return
}
}
if
statusCodesStr
:=
strings
.
TrimSpace
(
c
.
Query
(
"status_codes"
));
statusCodesStr
!=
""
{
parts
:=
strings
.
Split
(
statusCodesStr
,
","
)
out
:=
make
([]
int
,
0
,
len
(
parts
))
for
_
,
part
:=
range
parts
{
p
:=
strings
.
TrimSpace
(
part
)
if
p
==
""
{
continue
}
n
,
err
:=
strconv
.
Atoi
(
p
)
if
err
!=
nil
||
n
<
0
{
response
.
BadRequest
(
c
,
"Invalid status_codes"
)
return
}
out
=
append
(
out
,
n
)
}
filter
.
StatusCodes
=
out
}
result
,
err
:=
h
.
opsService
.
GetErrorLogs
(
c
.
Request
.
Context
(),
filter
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Paginated
(
c
,
result
.
Errors
,
int64
(
result
.
Total
),
result
.
Page
,
result
.
PageSize
)
}
// ListRequestErrors lists client-visible request errors.
// GET /api/v1/admin/ops/request-errors
func
(
h
*
OpsHandler
)
ListRequestErrors
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
page
,
pageSize
:=
response
.
ParsePagination
(
c
)
if
pageSize
>
500
{
pageSize
=
500
}
startTime
,
endTime
,
err
:=
parseOpsTimeRange
(
c
,
"1h"
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
err
.
Error
())
return
}
filter
:=
&
service
.
OpsErrorLogFilter
{
Page
:
page
,
PageSize
:
pageSize
}
if
!
startTime
.
IsZero
()
{
filter
.
StartTime
=
&
startTime
}
if
!
endTime
.
IsZero
()
{
filter
.
EndTime
=
&
endTime
}
filter
.
View
=
parseOpsViewParam
(
c
)
filter
.
Phase
=
strings
.
TrimSpace
(
c
.
Query
(
"phase"
))
filter
.
Owner
=
strings
.
TrimSpace
(
c
.
Query
(
"error_owner"
))
filter
.
Source
=
strings
.
TrimSpace
(
c
.
Query
(
"error_source"
))
filter
.
Query
=
strings
.
TrimSpace
(
c
.
Query
(
"q"
))
filter
.
UserQuery
=
strings
.
TrimSpace
(
c
.
Query
(
"user_query"
))
// Force request errors: client-visible status >= 400.
// buildOpsErrorLogsWhere already applies this for non-upstream phase.
if
strings
.
EqualFold
(
strings
.
TrimSpace
(
filter
.
Phase
),
"upstream"
)
{
filter
.
Phase
=
""
}
if
platform
:=
strings
.
TrimSpace
(
c
.
Query
(
"platform"
));
platform
!=
""
{
filter
.
Platform
=
platform
}
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"group_id"
));
v
!=
""
{
id
,
err
:=
strconv
.
ParseInt
(
v
,
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid group_id"
)
return
}
filter
.
GroupID
=
&
id
}
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"account_id"
));
v
!=
""
{
id
,
err
:=
strconv
.
ParseInt
(
v
,
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid account_id"
)
return
}
filter
.
AccountID
=
&
id
}
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"resolved"
));
v
!=
""
{
switch
strings
.
ToLower
(
v
)
{
case
"1"
,
"true"
,
"yes"
:
b
:=
true
filter
.
Resolved
=
&
b
case
"0"
,
"false"
,
"no"
:
b
:=
false
filter
.
Resolved
=
&
b
default
:
response
.
BadRequest
(
c
,
"Invalid resolved"
)
return
}
}
if
statusCodesStr
:=
strings
.
TrimSpace
(
c
.
Query
(
"status_codes"
));
statusCodesStr
!=
""
{
parts
:=
strings
.
Split
(
statusCodesStr
,
","
)
out
:=
make
([]
int
,
0
,
len
(
parts
))
for
_
,
part
:=
range
parts
{
p
:=
strings
.
TrimSpace
(
part
)
if
p
==
""
{
continue
}
n
,
err
:=
strconv
.
Atoi
(
p
)
if
err
!=
nil
||
n
<
0
{
response
.
BadRequest
(
c
,
"Invalid status_codes"
)
return
}
out
=
append
(
out
,
n
)
}
filter
.
StatusCodes
=
out
}
result
,
err
:=
h
.
opsService
.
GetErrorLogs
(
c
.
Request
.
Context
(),
filter
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Paginated
(
c
,
result
.
Errors
,
int64
(
result
.
Total
),
result
.
Page
,
result
.
PageSize
)
}
// GetRequestError returns request error detail.
// GET /api/v1/admin/ops/request-errors/:id
func
(
h
*
OpsHandler
)
GetRequestError
(
c
*
gin
.
Context
)
{
// same storage; just proxy to existing detail
h
.
GetErrorLogByID
(
c
)
}
// ListRequestErrorUpstreamErrors lists upstream error logs correlated to a request error.
// GET /api/v1/admin/ops/request-errors/:id/upstream-errors
func
(
h
*
OpsHandler
)
ListRequestErrorUpstreamErrors
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
idStr
:=
strings
.
TrimSpace
(
c
.
Param
(
"id"
))
id
,
err
:=
strconv
.
ParseInt
(
idStr
,
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid error id"
)
return
}
// Load request error to get correlation keys.
detail
,
err
:=
h
.
opsService
.
GetErrorLogByID
(
c
.
Request
.
Context
(),
id
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
// Correlate by request_id/client_request_id.
requestID
:=
strings
.
TrimSpace
(
detail
.
RequestID
)
clientRequestID
:=
strings
.
TrimSpace
(
detail
.
ClientRequestID
)
if
requestID
==
""
&&
clientRequestID
==
""
{
response
.
Paginated
(
c
,
[]
*
service
.
OpsErrorLog
{},
0
,
1
,
10
)
return
}
page
,
pageSize
:=
response
.
ParsePagination
(
c
)
if
pageSize
>
500
{
pageSize
=
500
}
// Keep correlation window wide enough so linked upstream errors
// are discoverable even when UI defaults to 1h elsewhere.
startTime
,
endTime
,
err
:=
parseOpsTimeRange
(
c
,
"30d"
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
err
.
Error
())
return
}
filter
:=
&
service
.
OpsErrorLogFilter
{
Page
:
page
,
PageSize
:
pageSize
}
if
!
startTime
.
IsZero
()
{
filter
.
StartTime
=
&
startTime
}
if
!
endTime
.
IsZero
()
{
filter
.
EndTime
=
&
endTime
}
filter
.
View
=
"all"
filter
.
Phase
=
"upstream"
filter
.
Owner
=
"provider"
filter
.
Source
=
strings
.
TrimSpace
(
c
.
Query
(
"error_source"
))
filter
.
Query
=
strings
.
TrimSpace
(
c
.
Query
(
"q"
))
if
platform
:=
strings
.
TrimSpace
(
c
.
Query
(
"platform"
));
platform
!=
""
{
filter
.
Platform
=
platform
}
// Prefer exact match on request_id; if missing, fall back to client_request_id.
if
requestID
!=
""
{
filter
.
RequestID
=
requestID
}
else
{
filter
.
ClientRequestID
=
clientRequestID
}
result
,
err
:=
h
.
opsService
.
GetErrorLogs
(
c
.
Request
.
Context
(),
filter
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
// If client asks for details, expand each upstream error log to include upstream response fields.
includeDetail
:=
strings
.
TrimSpace
(
c
.
Query
(
"include_detail"
))
if
includeDetail
==
"1"
||
strings
.
EqualFold
(
includeDetail
,
"true"
)
||
strings
.
EqualFold
(
includeDetail
,
"yes"
)
{
details
:=
make
([]
*
service
.
OpsErrorLogDetail
,
0
,
len
(
result
.
Errors
))
for
_
,
item
:=
range
result
.
Errors
{
if
item
==
nil
{
continue
}
d
,
err
:=
h
.
opsService
.
GetErrorLogByID
(
c
.
Request
.
Context
(),
item
.
ID
)
if
err
!=
nil
||
d
==
nil
{
continue
}
details
=
append
(
details
,
d
)
}
response
.
Paginated
(
c
,
details
,
int64
(
result
.
Total
),
result
.
Page
,
result
.
PageSize
)
return
}
response
.
Paginated
(
c
,
result
.
Errors
,
int64
(
result
.
Total
),
result
.
Page
,
result
.
PageSize
)
}
// RetryRequestErrorClient retries the client request based on stored request body.
// POST /api/v1/admin/ops/request-errors/:id/retry-client
func
(
h
*
OpsHandler
)
RetryRequestErrorClient
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
subject
,
ok
:=
middleware
.
GetAuthSubjectFromContext
(
c
)
if
!
ok
||
subject
.
UserID
<=
0
{
response
.
Error
(
c
,
http
.
StatusUnauthorized
,
"Unauthorized"
)
return
}
idStr
:=
strings
.
TrimSpace
(
c
.
Param
(
"id"
))
id
,
err
:=
strconv
.
ParseInt
(
idStr
,
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid error id"
)
return
}
result
,
err
:=
h
.
opsService
.
RetryError
(
c
.
Request
.
Context
(),
subject
.
UserID
,
id
,
service
.
OpsRetryModeClient
,
nil
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
result
)
}
// RetryRequestErrorUpstreamEvent retries a specific upstream attempt using captured upstream_request_body.
// POST /api/v1/admin/ops/request-errors/:id/upstream-errors/:idx/retry
func
(
h
*
OpsHandler
)
RetryRequestErrorUpstreamEvent
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
subject
,
ok
:=
middleware
.
GetAuthSubjectFromContext
(
c
)
if
!
ok
||
subject
.
UserID
<=
0
{
response
.
Error
(
c
,
http
.
StatusUnauthorized
,
"Unauthorized"
)
return
}
idStr
:=
strings
.
TrimSpace
(
c
.
Param
(
"id"
))
id
,
err
:=
strconv
.
ParseInt
(
idStr
,
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid error id"
)
return
}
idxStr
:=
strings
.
TrimSpace
(
c
.
Param
(
"idx"
))
idx
,
err
:=
strconv
.
Atoi
(
idxStr
)
if
err
!=
nil
||
idx
<
0
{
response
.
BadRequest
(
c
,
"Invalid upstream idx"
)
return
}
result
,
err
:=
h
.
opsService
.
RetryUpstreamEvent
(
c
.
Request
.
Context
(),
subject
.
UserID
,
id
,
idx
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
result
)
}
// ResolveRequestError toggles resolved status.
// PUT /api/v1/admin/ops/request-errors/:id/resolve
func
(
h
*
OpsHandler
)
ResolveRequestError
(
c
*
gin
.
Context
)
{
h
.
UpdateErrorResolution
(
c
)
}
// ListUpstreamErrors lists independent upstream errors.
// GET /api/v1/admin/ops/upstream-errors
func
(
h
*
OpsHandler
)
ListUpstreamErrors
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
page
,
pageSize
:=
response
.
ParsePagination
(
c
)
if
pageSize
>
500
{
pageSize
=
500
}
startTime
,
endTime
,
err
:=
parseOpsTimeRange
(
c
,
"1h"
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
err
.
Error
())
return
}
filter
:=
&
service
.
OpsErrorLogFilter
{
Page
:
page
,
PageSize
:
pageSize
}
if
!
startTime
.
IsZero
()
{
filter
.
StartTime
=
&
startTime
}
if
!
endTime
.
IsZero
()
{
filter
.
EndTime
=
&
endTime
}
filter
.
View
=
parseOpsViewParam
(
c
)
filter
.
Phase
=
"upstream"
filter
.
Owner
=
"provider"
filter
.
Source
=
strings
.
TrimSpace
(
c
.
Query
(
"error_source"
))
filter
.
Query
=
strings
.
TrimSpace
(
c
.
Query
(
"q"
))
if
platform
:=
strings
.
TrimSpace
(
c
.
Query
(
"platform"
));
platform
!=
""
{
filter
.
Platform
=
platform
}
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"group_id"
));
v
!=
""
{
id
,
err
:=
strconv
.
ParseInt
(
v
,
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid group_id"
)
return
}
filter
.
GroupID
=
&
id
}
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"account_id"
));
v
!=
""
{
id
,
err
:=
strconv
.
ParseInt
(
v
,
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid account_id"
)
return
}
filter
.
AccountID
=
&
id
}
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"resolved"
));
v
!=
""
{
switch
strings
.
ToLower
(
v
)
{
case
"1"
,
"true"
,
"yes"
:
b
:=
true
filter
.
Resolved
=
&
b
case
"0"
,
"false"
,
"no"
:
b
:=
false
filter
.
Resolved
=
&
b
default
:
response
.
BadRequest
(
c
,
"Invalid resolved"
)
return
}
}
if
statusCodesStr
:=
strings
.
TrimSpace
(
c
.
Query
(
"status_codes"
));
statusCodesStr
!=
""
{
parts
:=
strings
.
Split
(
statusCodesStr
,
","
)
out
:=
make
([]
int
,
0
,
len
(
parts
))
for
_
,
part
:=
range
parts
{
p
:=
strings
.
TrimSpace
(
part
)
if
p
==
""
{
continue
}
n
,
err
:=
strconv
.
Atoi
(
p
)
if
err
!=
nil
||
n
<
0
{
response
.
BadRequest
(
c
,
"Invalid status_codes"
)
return
}
out
=
append
(
out
,
n
)
}
filter
.
StatusCodes
=
out
}
result
,
err
:=
h
.
opsService
.
GetErrorLogs
(
c
.
Request
.
Context
(),
filter
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Paginated
(
c
,
result
.
Errors
,
int64
(
result
.
Total
),
result
.
Page
,
result
.
PageSize
)
}
// GetUpstreamError returns upstream error detail.
// GET /api/v1/admin/ops/upstream-errors/:id
func
(
h
*
OpsHandler
)
GetUpstreamError
(
c
*
gin
.
Context
)
{
h
.
GetErrorLogByID
(
c
)
}
// RetryUpstreamError retries upstream error using the original account_id.
// POST /api/v1/admin/ops/upstream-errors/:id/retry
func
(
h
*
OpsHandler
)
RetryUpstreamError
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
subject
,
ok
:=
middleware
.
GetAuthSubjectFromContext
(
c
)
if
!
ok
||
subject
.
UserID
<=
0
{
response
.
Error
(
c
,
http
.
StatusUnauthorized
,
"Unauthorized"
)
return
}
idStr
:=
strings
.
TrimSpace
(
c
.
Param
(
"id"
))
id
,
err
:=
strconv
.
ParseInt
(
idStr
,
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid error id"
)
return
}
result
,
err
:=
h
.
opsService
.
RetryError
(
c
.
Request
.
Context
(),
subject
.
UserID
,
id
,
service
.
OpsRetryModeUpstream
,
nil
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
result
)
}
// ResolveUpstreamError toggles resolved status.
// PUT /api/v1/admin/ops/upstream-errors/:id/resolve
func
(
h
*
OpsHandler
)
ResolveUpstreamError
(
c
*
gin
.
Context
)
{
h
.
UpdateErrorResolution
(
c
)
}
// ==================== Existing endpoints ====================
// ListRequestDetails returns a request-level list (success + error) for drill-down.
// GET /api/v1/admin/ops/requests
func
(
h
*
OpsHandler
)
ListRequestDetails
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
page
,
pageSize
:=
response
.
ParsePagination
(
c
)
if
pageSize
>
100
{
pageSize
=
100
}
startTime
,
endTime
,
err
:=
parseOpsTimeRange
(
c
,
"1h"
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
err
.
Error
())
return
}
filter
:=
&
service
.
OpsRequestDetailFilter
{
Page
:
page
,
PageSize
:
pageSize
,
StartTime
:
&
startTime
,
EndTime
:
&
endTime
,
}
filter
.
Kind
=
strings
.
TrimSpace
(
c
.
Query
(
"kind"
))
filter
.
Platform
=
strings
.
TrimSpace
(
c
.
Query
(
"platform"
))
filter
.
Model
=
strings
.
TrimSpace
(
c
.
Query
(
"model"
))
filter
.
RequestID
=
strings
.
TrimSpace
(
c
.
Query
(
"request_id"
))
filter
.
Query
=
strings
.
TrimSpace
(
c
.
Query
(
"q"
))
filter
.
Sort
=
strings
.
TrimSpace
(
c
.
Query
(
"sort"
))
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"user_id"
));
v
!=
""
{
id
,
err
:=
strconv
.
ParseInt
(
v
,
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid user_id"
)
return
}
filter
.
UserID
=
&
id
}
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"api_key_id"
));
v
!=
""
{
id
,
err
:=
strconv
.
ParseInt
(
v
,
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid api_key_id"
)
return
}
filter
.
APIKeyID
=
&
id
}
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"account_id"
));
v
!=
""
{
id
,
err
:=
strconv
.
ParseInt
(
v
,
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid account_id"
)
return
}
filter
.
AccountID
=
&
id
}
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"group_id"
));
v
!=
""
{
id
,
err
:=
strconv
.
ParseInt
(
v
,
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid group_id"
)
return
}
filter
.
GroupID
=
&
id
}
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"min_duration_ms"
));
v
!=
""
{
parsed
,
err
:=
strconv
.
Atoi
(
v
)
if
err
!=
nil
||
parsed
<
0
{
response
.
BadRequest
(
c
,
"Invalid min_duration_ms"
)
return
}
filter
.
MinDurationMs
=
&
parsed
}
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"max_duration_ms"
));
v
!=
""
{
parsed
,
err
:=
strconv
.
Atoi
(
v
)
if
err
!=
nil
||
parsed
<
0
{
response
.
BadRequest
(
c
,
"Invalid max_duration_ms"
)
return
}
filter
.
MaxDurationMs
=
&
parsed
}
out
,
err
:=
h
.
opsService
.
ListRequestDetails
(
c
.
Request
.
Context
(),
filter
)
if
err
!=
nil
{
// Invalid sort/kind/platform etc should be a bad request; keep it simple.
if
strings
.
Contains
(
strings
.
ToLower
(
err
.
Error
()),
"invalid"
)
{
response
.
BadRequest
(
c
,
err
.
Error
())
return
}
response
.
Error
(
c
,
http
.
StatusInternalServerError
,
"Failed to list request details"
)
return
}
response
.
Paginated
(
c
,
out
.
Items
,
out
.
Total
,
out
.
Page
,
out
.
PageSize
)
}
type
opsRetryRequest
struct
{
Mode
string
`json:"mode"`
PinnedAccountID
*
int64
`json:"pinned_account_id"`
Force
bool
`json:"force"`
}
type
opsResolveRequest
struct
{
Resolved
bool
`json:"resolved"`
}
// RetryErrorRequest retries a failed request using stored request_body.
// POST /api/v1/admin/ops/errors/:id/retry
func
(
h
*
OpsHandler
)
RetryErrorRequest
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
subject
,
ok
:=
middleware
.
GetAuthSubjectFromContext
(
c
)
if
!
ok
||
subject
.
UserID
<=
0
{
response
.
Error
(
c
,
http
.
StatusUnauthorized
,
"Unauthorized"
)
return
}
idStr
:=
strings
.
TrimSpace
(
c
.
Param
(
"id"
))
id
,
err
:=
strconv
.
ParseInt
(
idStr
,
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid error id"
)
return
}
req
:=
opsRetryRequest
{
Mode
:
service
.
OpsRetryModeClient
}
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
&&
!
errors
.
Is
(
err
,
io
.
EOF
)
{
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
if
strings
.
TrimSpace
(
req
.
Mode
)
==
""
{
req
.
Mode
=
service
.
OpsRetryModeClient
}
// Force flag is currently a UI-level acknowledgement. Server may still enforce safety constraints.
_
=
req
.
Force
// Legacy endpoint safety: only allow retrying the client request here.
// Upstream retries must go through the split endpoints.
if
strings
.
EqualFold
(
strings
.
TrimSpace
(
req
.
Mode
),
service
.
OpsRetryModeUpstream
)
{
response
.
BadRequest
(
c
,
"upstream retry is not supported on this endpoint"
)
return
}
result
,
err
:=
h
.
opsService
.
RetryError
(
c
.
Request
.
Context
(),
subject
.
UserID
,
id
,
req
.
Mode
,
req
.
PinnedAccountID
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
result
)
}
// ListRetryAttempts lists retry attempts for an error log.
// GET /api/v1/admin/ops/errors/:id/retries
func
(
h
*
OpsHandler
)
ListRetryAttempts
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
idStr
:=
strings
.
TrimSpace
(
c
.
Param
(
"id"
))
id
,
err
:=
strconv
.
ParseInt
(
idStr
,
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid error id"
)
return
}
limit
:=
50
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"limit"
));
v
!=
""
{
n
,
err
:=
strconv
.
Atoi
(
v
)
if
err
!=
nil
||
n
<=
0
{
response
.
BadRequest
(
c
,
"Invalid limit"
)
return
}
limit
=
n
}
items
,
err
:=
h
.
opsService
.
ListRetryAttemptsByErrorID
(
c
.
Request
.
Context
(),
id
,
limit
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
items
)
}
// UpdateErrorResolution allows manual resolve/unresolve.
// PUT /api/v1/admin/ops/errors/:id/resolve
func
(
h
*
OpsHandler
)
UpdateErrorResolution
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
subject
,
ok
:=
middleware
.
GetAuthSubjectFromContext
(
c
)
if
!
ok
||
subject
.
UserID
<=
0
{
response
.
Error
(
c
,
http
.
StatusUnauthorized
,
"Unauthorized"
)
return
}
idStr
:=
strings
.
TrimSpace
(
c
.
Param
(
"id"
))
id
,
err
:=
strconv
.
ParseInt
(
idStr
,
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid error id"
)
return
}
var
req
opsResolveRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
uid
:=
subject
.
UserID
if
err
:=
h
.
opsService
.
UpdateErrorResolution
(
c
.
Request
.
Context
(),
id
,
req
.
Resolved
,
&
uid
,
nil
);
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
gin
.
H
{
"ok"
:
true
})
}
func
parseOpsTimeRange
(
c
*
gin
.
Context
,
defaultRange
string
)
(
time
.
Time
,
time
.
Time
,
error
)
{
startStr
:=
strings
.
TrimSpace
(
c
.
Query
(
"start_time"
))
endStr
:=
strings
.
TrimSpace
(
c
.
Query
(
"end_time"
))
parseTS
:=
func
(
s
string
)
(
time
.
Time
,
error
)
{
if
s
==
""
{
return
time
.
Time
{},
nil
}
if
t
,
err
:=
time
.
Parse
(
time
.
RFC3339Nano
,
s
);
err
==
nil
{
return
t
,
nil
}
return
time
.
Parse
(
time
.
RFC3339
,
s
)
}
start
,
err
:=
parseTS
(
startStr
)
if
err
!=
nil
{
return
time
.
Time
{},
time
.
Time
{},
err
}
end
,
err
:=
parseTS
(
endStr
)
if
err
!=
nil
{
return
time
.
Time
{},
time
.
Time
{},
err
}
// start/end explicitly provided (even partially)
if
startStr
!=
""
||
endStr
!=
""
{
if
end
.
IsZero
()
{
end
=
time
.
Now
()
}
if
start
.
IsZero
()
{
dur
,
_
:=
parseOpsDuration
(
defaultRange
)
start
=
end
.
Add
(
-
dur
)
}
if
start
.
After
(
end
)
{
return
time
.
Time
{},
time
.
Time
{},
fmt
.
Errorf
(
"invalid time range: start_time must be <= end_time"
)
}
if
end
.
Sub
(
start
)
>
30
*
24
*
time
.
Hour
{
return
time
.
Time
{},
time
.
Time
{},
fmt
.
Errorf
(
"invalid time range: max window is 30 days"
)
}
return
start
,
end
,
nil
}
// time_range fallback
tr
:=
strings
.
TrimSpace
(
c
.
Query
(
"time_range"
))
if
tr
==
""
{
tr
=
defaultRange
}
dur
,
ok
:=
parseOpsDuration
(
tr
)
if
!
ok
{
dur
,
_
=
parseOpsDuration
(
defaultRange
)
}
end
=
time
.
Now
()
start
=
end
.
Add
(
-
dur
)
if
end
.
Sub
(
start
)
>
30
*
24
*
time
.
Hour
{
return
time
.
Time
{},
time
.
Time
{},
fmt
.
Errorf
(
"invalid time range: max window is 30 days"
)
}
return
start
,
end
,
nil
}
func
parseOpsDuration
(
v
string
)
(
time
.
Duration
,
bool
)
{
switch
strings
.
TrimSpace
(
v
)
{
case
"5m"
:
return
5
*
time
.
Minute
,
true
case
"30m"
:
return
30
*
time
.
Minute
,
true
case
"1h"
:
return
time
.
Hour
,
true
case
"6h"
:
return
6
*
time
.
Hour
,
true
case
"24h"
:
return
24
*
time
.
Hour
,
true
case
"7d"
:
return
7
*
24
*
time
.
Hour
,
true
case
"30d"
:
return
30
*
24
*
time
.
Hour
,
true
default
:
return
0
,
false
}
}
backend/internal/handler/admin/ops_realtime_handler.go
0 → 100644
View file @
b9b4db3d
package
admin
import
(
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// GetConcurrencyStats returns real-time concurrency usage aggregated by platform/group/account.
// GET /api/v1/admin/ops/concurrency
func
(
h
*
OpsHandler
)
GetConcurrencyStats
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
if
!
h
.
opsService
.
IsRealtimeMonitoringEnabled
(
c
.
Request
.
Context
())
{
response
.
Success
(
c
,
gin
.
H
{
"enabled"
:
false
,
"platform"
:
map
[
string
]
*
service
.
PlatformConcurrencyInfo
{},
"group"
:
map
[
int64
]
*
service
.
GroupConcurrencyInfo
{},
"account"
:
map
[
int64
]
*
service
.
AccountConcurrencyInfo
{},
"timestamp"
:
time
.
Now
()
.
UTC
(),
})
return
}
platformFilter
:=
strings
.
TrimSpace
(
c
.
Query
(
"platform"
))
var
groupID
*
int64
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"group_id"
));
v
!=
""
{
id
,
err
:=
strconv
.
ParseInt
(
v
,
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid group_id"
)
return
}
groupID
=
&
id
}
platform
,
group
,
account
,
collectedAt
,
err
:=
h
.
opsService
.
GetConcurrencyStats
(
c
.
Request
.
Context
(),
platformFilter
,
groupID
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
payload
:=
gin
.
H
{
"enabled"
:
true
,
"platform"
:
platform
,
"group"
:
group
,
"account"
:
account
,
}
if
collectedAt
!=
nil
{
payload
[
"timestamp"
]
=
collectedAt
.
UTC
()
}
response
.
Success
(
c
,
payload
)
}
// GetAccountAvailability returns account availability statistics.
// GET /api/v1/admin/ops/account-availability
//
// Query params:
// - platform: optional
// - group_id: optional
func
(
h
*
OpsHandler
)
GetAccountAvailability
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
if
!
h
.
opsService
.
IsRealtimeMonitoringEnabled
(
c
.
Request
.
Context
())
{
response
.
Success
(
c
,
gin
.
H
{
"enabled"
:
false
,
"platform"
:
map
[
string
]
*
service
.
PlatformAvailability
{},
"group"
:
map
[
int64
]
*
service
.
GroupAvailability
{},
"account"
:
map
[
int64
]
*
service
.
AccountAvailability
{},
"timestamp"
:
time
.
Now
()
.
UTC
(),
})
return
}
platform
:=
strings
.
TrimSpace
(
c
.
Query
(
"platform"
))
var
groupID
*
int64
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"group_id"
));
v
!=
""
{
id
,
err
:=
strconv
.
ParseInt
(
v
,
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid group_id"
)
return
}
groupID
=
&
id
}
platformStats
,
groupStats
,
accountStats
,
collectedAt
,
err
:=
h
.
opsService
.
GetAccountAvailabilityStats
(
c
.
Request
.
Context
(),
platform
,
groupID
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
payload
:=
gin
.
H
{
"enabled"
:
true
,
"platform"
:
platformStats
,
"group"
:
groupStats
,
"account"
:
accountStats
,
}
if
collectedAt
!=
nil
{
payload
[
"timestamp"
]
=
collectedAt
.
UTC
()
}
response
.
Success
(
c
,
payload
)
}
func
parseOpsRealtimeWindow
(
v
string
)
(
time
.
Duration
,
string
,
bool
)
{
switch
strings
.
ToLower
(
strings
.
TrimSpace
(
v
))
{
case
""
,
"1min"
,
"1m"
:
return
1
*
time
.
Minute
,
"1min"
,
true
case
"5min"
,
"5m"
:
return
5
*
time
.
Minute
,
"5min"
,
true
case
"30min"
,
"30m"
:
return
30
*
time
.
Minute
,
"30min"
,
true
case
"1h"
,
"60m"
,
"60min"
:
return
1
*
time
.
Hour
,
"1h"
,
true
default
:
return
0
,
""
,
false
}
}
// GetRealtimeTrafficSummary returns QPS/TPS current/peak/avg for the selected window.
// GET /api/v1/admin/ops/realtime-traffic
//
// Query params:
// - window: 1min|5min|30min|1h (default: 1min)
// - platform: optional
// - group_id: optional
func
(
h
*
OpsHandler
)
GetRealtimeTrafficSummary
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
windowDur
,
windowLabel
,
ok
:=
parseOpsRealtimeWindow
(
c
.
Query
(
"window"
))
if
!
ok
{
response
.
BadRequest
(
c
,
"Invalid window"
)
return
}
platform
:=
strings
.
TrimSpace
(
c
.
Query
(
"platform"
))
var
groupID
*
int64
if
v
:=
strings
.
TrimSpace
(
c
.
Query
(
"group_id"
));
v
!=
""
{
id
,
err
:=
strconv
.
ParseInt
(
v
,
10
,
64
)
if
err
!=
nil
||
id
<=
0
{
response
.
BadRequest
(
c
,
"Invalid group_id"
)
return
}
groupID
=
&
id
}
endTime
:=
time
.
Now
()
.
UTC
()
startTime
:=
endTime
.
Add
(
-
windowDur
)
if
!
h
.
opsService
.
IsRealtimeMonitoringEnabled
(
c
.
Request
.
Context
())
{
disabledSummary
:=
&
service
.
OpsRealtimeTrafficSummary
{
Window
:
windowLabel
,
StartTime
:
startTime
,
EndTime
:
endTime
,
Platform
:
platform
,
GroupID
:
groupID
,
QPS
:
service
.
OpsRateSummary
{},
TPS
:
service
.
OpsRateSummary
{},
}
response
.
Success
(
c
,
gin
.
H
{
"enabled"
:
false
,
"summary"
:
disabledSummary
,
"timestamp"
:
endTime
,
})
return
}
filter
:=
&
service
.
OpsDashboardFilter
{
StartTime
:
startTime
,
EndTime
:
endTime
,
Platform
:
platform
,
GroupID
:
groupID
,
QueryMode
:
service
.
OpsQueryModeRaw
,
}
summary
,
err
:=
h
.
opsService
.
GetRealtimeTrafficSummary
(
c
.
Request
.
Context
(),
filter
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
if
summary
!=
nil
{
summary
.
Window
=
windowLabel
}
response
.
Success
(
c
,
gin
.
H
{
"enabled"
:
true
,
"summary"
:
summary
,
"timestamp"
:
endTime
,
})
}
backend/internal/handler/admin/ops_settings_handler.go
0 → 100644
View file @
b9b4db3d
package
admin
import
(
"net/http"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// GetEmailNotificationConfig returns Ops email notification config (DB-backed).
// GET /api/v1/admin/ops/email-notification/config
func
(
h
*
OpsHandler
)
GetEmailNotificationConfig
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
cfg
,
err
:=
h
.
opsService
.
GetEmailNotificationConfig
(
c
.
Request
.
Context
())
if
err
!=
nil
{
response
.
Error
(
c
,
http
.
StatusInternalServerError
,
"Failed to get email notification config"
)
return
}
response
.
Success
(
c
,
cfg
)
}
// UpdateEmailNotificationConfig updates Ops email notification config (DB-backed).
// PUT /api/v1/admin/ops/email-notification/config
func
(
h
*
OpsHandler
)
UpdateEmailNotificationConfig
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
var
req
service
.
OpsEmailNotificationConfigUpdateRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request body"
)
return
}
updated
,
err
:=
h
.
opsService
.
UpdateEmailNotificationConfig
(
c
.
Request
.
Context
(),
&
req
)
if
err
!=
nil
{
// Most failures here are validation errors from request payload; treat as 400.
response
.
Error
(
c
,
http
.
StatusBadRequest
,
err
.
Error
())
return
}
response
.
Success
(
c
,
updated
)
}
// GetAlertRuntimeSettings returns Ops alert evaluator runtime settings (DB-backed).
// GET /api/v1/admin/ops/runtime/alert
func
(
h
*
OpsHandler
)
GetAlertRuntimeSettings
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
cfg
,
err
:=
h
.
opsService
.
GetOpsAlertRuntimeSettings
(
c
.
Request
.
Context
())
if
err
!=
nil
{
response
.
Error
(
c
,
http
.
StatusInternalServerError
,
"Failed to get alert runtime settings"
)
return
}
response
.
Success
(
c
,
cfg
)
}
// UpdateAlertRuntimeSettings updates Ops alert evaluator runtime settings (DB-backed).
// PUT /api/v1/admin/ops/runtime/alert
func
(
h
*
OpsHandler
)
UpdateAlertRuntimeSettings
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
var
req
service
.
OpsAlertRuntimeSettings
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request body"
)
return
}
updated
,
err
:=
h
.
opsService
.
UpdateOpsAlertRuntimeSettings
(
c
.
Request
.
Context
(),
&
req
)
if
err
!=
nil
{
response
.
Error
(
c
,
http
.
StatusBadRequest
,
err
.
Error
())
return
}
response
.
Success
(
c
,
updated
)
}
// GetAdvancedSettings returns Ops advanced settings (DB-backed).
// GET /api/v1/admin/ops/advanced-settings
func
(
h
*
OpsHandler
)
GetAdvancedSettings
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
cfg
,
err
:=
h
.
opsService
.
GetOpsAdvancedSettings
(
c
.
Request
.
Context
())
if
err
!=
nil
{
response
.
Error
(
c
,
http
.
StatusInternalServerError
,
"Failed to get advanced settings"
)
return
}
response
.
Success
(
c
,
cfg
)
}
// UpdateAdvancedSettings updates Ops advanced settings (DB-backed).
// PUT /api/v1/admin/ops/advanced-settings
func
(
h
*
OpsHandler
)
UpdateAdvancedSettings
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
var
req
service
.
OpsAdvancedSettings
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request body"
)
return
}
updated
,
err
:=
h
.
opsService
.
UpdateOpsAdvancedSettings
(
c
.
Request
.
Context
(),
&
req
)
if
err
!=
nil
{
response
.
Error
(
c
,
http
.
StatusBadRequest
,
err
.
Error
())
return
}
response
.
Success
(
c
,
updated
)
}
// GetMetricThresholds returns Ops metric thresholds (DB-backed).
// GET /api/v1/admin/ops/settings/metric-thresholds
func
(
h
*
OpsHandler
)
GetMetricThresholds
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
cfg
,
err
:=
h
.
opsService
.
GetMetricThresholds
(
c
.
Request
.
Context
())
if
err
!=
nil
{
response
.
Error
(
c
,
http
.
StatusInternalServerError
,
"Failed to get metric thresholds"
)
return
}
response
.
Success
(
c
,
cfg
)
}
// UpdateMetricThresholds updates Ops metric thresholds (DB-backed).
// PUT /api/v1/admin/ops/settings/metric-thresholds
func
(
h
*
OpsHandler
)
UpdateMetricThresholds
(
c
*
gin
.
Context
)
{
if
h
.
opsService
==
nil
{
response
.
Error
(
c
,
http
.
StatusServiceUnavailable
,
"Ops service not available"
)
return
}
if
err
:=
h
.
opsService
.
RequireMonitoringEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
var
req
service
.
OpsMetricThresholds
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request body"
)
return
}
updated
,
err
:=
h
.
opsService
.
UpdateMetricThresholds
(
c
.
Request
.
Context
(),
&
req
)
if
err
!=
nil
{
response
.
Error
(
c
,
http
.
StatusBadRequest
,
err
.
Error
())
return
}
response
.
Success
(
c
,
updated
)
}
backend/internal/handler/admin/ops_ws_handler.go
0 → 100644
View file @
b9b4db3d
package
admin
import
(
"context"
"encoding/json"
"log"
"math"
"net"
"net/http"
"net/netip"
"net/url"
"os"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
type
OpsWSProxyConfig
struct
{
TrustProxy
bool
TrustedProxies
[]
netip
.
Prefix
OriginPolicy
string
}
const
(
envOpsWSTrustProxy
=
"OPS_WS_TRUST_PROXY"
envOpsWSTrustedProxies
=
"OPS_WS_TRUSTED_PROXIES"
envOpsWSOriginPolicy
=
"OPS_WS_ORIGIN_POLICY"
envOpsWSMaxConns
=
"OPS_WS_MAX_CONNS"
envOpsWSMaxConnsPerIP
=
"OPS_WS_MAX_CONNS_PER_IP"
)
const
(
OriginPolicyStrict
=
"strict"
OriginPolicyPermissive
=
"permissive"
)
var
opsWSProxyConfig
=
loadOpsWSProxyConfigFromEnv
()
var
upgrader
=
websocket
.
Upgrader
{
CheckOrigin
:
func
(
r
*
http
.
Request
)
bool
{
return
isAllowedOpsWSOrigin
(
r
)
},
// Subprotocol negotiation:
// - The frontend passes ["sub2api-admin", "jwt.<token>"].
// - We always select "sub2api-admin" so the token is never echoed back in the handshake response.
Subprotocols
:
[]
string
{
"sub2api-admin"
},
}
const
(
qpsWSPushInterval
=
2
*
time
.
Second
qpsWSRefreshInterval
=
5
*
time
.
Second
qpsWSRequestCountWindow
=
1
*
time
.
Minute
defaultMaxWSConns
=
100
defaultMaxWSConnsPerIP
=
20
)
var
wsConnCount
atomic
.
Int32
var
wsConnCountByIP
sync
.
Map
// map[string]*atomic.Int32
const
qpsWSIdleStopDelay
=
30
*
time
.
Second
const
(
opsWSCloseRealtimeDisabled
=
4001
)
var
qpsWSIdleStopMu
sync
.
Mutex
var
qpsWSIdleStopTimer
*
time
.
Timer
func
cancelQPSWSIdleStop
()
{
qpsWSIdleStopMu
.
Lock
()
if
qpsWSIdleStopTimer
!=
nil
{
qpsWSIdleStopTimer
.
Stop
()
qpsWSIdleStopTimer
=
nil
}
qpsWSIdleStopMu
.
Unlock
()
}
func
scheduleQPSWSIdleStop
()
{
qpsWSIdleStopMu
.
Lock
()
if
qpsWSIdleStopTimer
!=
nil
{
qpsWSIdleStopMu
.
Unlock
()
return
}
qpsWSIdleStopTimer
=
time
.
AfterFunc
(
qpsWSIdleStopDelay
,
func
()
{
// Only stop if truly idle at fire time.
if
wsConnCount
.
Load
()
==
0
{
qpsWSCache
.
Stop
()
}
qpsWSIdleStopMu
.
Lock
()
qpsWSIdleStopTimer
=
nil
qpsWSIdleStopMu
.
Unlock
()
})
qpsWSIdleStopMu
.
Unlock
()
}
type
opsWSRuntimeLimits
struct
{
MaxConns
int32
MaxConnsPerIP
int32
}
var
opsWSLimits
=
loadOpsWSRuntimeLimitsFromEnv
()
const
(
qpsWSWriteTimeout
=
10
*
time
.
Second
qpsWSPongWait
=
60
*
time
.
Second
qpsWSPingInterval
=
30
*
time
.
Second
// We don't expect clients to send application messages; we only read to process control frames (Pong/Close).
qpsWSMaxReadBytes
=
1024
)
type
opsWSQPSCache
struct
{
refreshInterval
time
.
Duration
requestCountWindow
time
.
Duration
lastUpdatedUnixNano
atomic
.
Int64
payload
atomic
.
Value
// []byte
opsService
*
service
.
OpsService
cancel
context
.
CancelFunc
done
chan
struct
{}
mu
sync
.
Mutex
running
bool
}
var
qpsWSCache
=
&
opsWSQPSCache
{
refreshInterval
:
qpsWSRefreshInterval
,
requestCountWindow
:
qpsWSRequestCountWindow
,
}
func
(
c
*
opsWSQPSCache
)
start
(
opsService
*
service
.
OpsService
)
{
if
c
==
nil
||
opsService
==
nil
{
return
}
for
{
c
.
mu
.
Lock
()
if
c
.
running
{
c
.
mu
.
Unlock
()
return
}
// If a previous refresh loop is currently stopping, wait for it to fully exit.
done
:=
c
.
done
if
done
!=
nil
{
c
.
mu
.
Unlock
()
<-
done
c
.
mu
.
Lock
()
if
c
.
done
==
done
&&
!
c
.
running
{
c
.
done
=
nil
}
c
.
mu
.
Unlock
()
continue
}
c
.
opsService
=
opsService
ctx
,
cancel
:=
context
.
WithCancel
(
context
.
Background
())
c
.
cancel
=
cancel
c
.
done
=
make
(
chan
struct
{})
done
=
c
.
done
c
.
running
=
true
c
.
mu
.
Unlock
()
go
func
()
{
defer
close
(
done
)
c
.
refreshLoop
(
ctx
)
}()
return
}
}
// Stop stops the background refresh loop.
// It is safe to call multiple times.
func
(
c
*
opsWSQPSCache
)
Stop
()
{
if
c
==
nil
{
return
}
c
.
mu
.
Lock
()
if
!
c
.
running
{
done
:=
c
.
done
c
.
mu
.
Unlock
()
if
done
!=
nil
{
<-
done
}
return
}
cancel
:=
c
.
cancel
c
.
cancel
=
nil
c
.
running
=
false
c
.
opsService
=
nil
done
:=
c
.
done
c
.
mu
.
Unlock
()
if
cancel
!=
nil
{
cancel
()
}
if
done
!=
nil
{
<-
done
}
c
.
mu
.
Lock
()
if
c
.
done
==
done
&&
!
c
.
running
{
c
.
done
=
nil
}
c
.
mu
.
Unlock
()
}
func
(
c
*
opsWSQPSCache
)
refreshLoop
(
ctx
context
.
Context
)
{
ticker
:=
time
.
NewTicker
(
c
.
refreshInterval
)
defer
ticker
.
Stop
()
c
.
refresh
(
ctx
)
for
{
select
{
case
<-
ticker
.
C
:
c
.
refresh
(
ctx
)
case
<-
ctx
.
Done
()
:
return
}
}
}
func
(
c
*
opsWSQPSCache
)
refresh
(
parentCtx
context
.
Context
)
{
if
c
==
nil
{
return
}
c
.
mu
.
Lock
()
opsService
:=
c
.
opsService
c
.
mu
.
Unlock
()
if
opsService
==
nil
{
return
}
if
parentCtx
==
nil
{
parentCtx
=
context
.
Background
()
}
ctx
,
cancel
:=
context
.
WithTimeout
(
parentCtx
,
10
*
time
.
Second
)
defer
cancel
()
now
:=
time
.
Now
()
.
UTC
()
stats
,
err
:=
opsService
.
GetWindowStats
(
ctx
,
now
.
Add
(
-
c
.
requestCountWindow
),
now
)
if
err
!=
nil
||
stats
==
nil
{
if
err
!=
nil
{
log
.
Printf
(
"[OpsWS] refresh: get window stats failed: %v"
,
err
)
}
return
}
requestCount
:=
stats
.
SuccessCount
+
stats
.
ErrorCountTotal
qps
:=
0.0
tps
:=
0.0
if
c
.
requestCountWindow
>
0
{
seconds
:=
c
.
requestCountWindow
.
Seconds
()
qps
=
roundTo1DP
(
float64
(
requestCount
)
/
seconds
)
tps
=
roundTo1DP
(
float64
(
stats
.
TokenConsumed
)
/
seconds
)
}
payload
:=
gin
.
H
{
"type"
:
"qps_update"
,
"timestamp"
:
now
.
Format
(
time
.
RFC3339
),
"data"
:
gin
.
H
{
"qps"
:
qps
,
"tps"
:
tps
,
"request_count"
:
requestCount
,
},
}
msg
,
err
:=
json
.
Marshal
(
payload
)
if
err
!=
nil
{
log
.
Printf
(
"[OpsWS] refresh: marshal payload failed: %v"
,
err
)
return
}
c
.
payload
.
Store
(
msg
)
c
.
lastUpdatedUnixNano
.
Store
(
now
.
UnixNano
())
}
func
roundTo1DP
(
v
float64
)
float64
{
return
math
.
Round
(
v
*
10
)
/
10
}
func
(
c
*
opsWSQPSCache
)
getPayload
()
[]
byte
{
if
c
==
nil
{
return
nil
}
if
cached
,
ok
:=
c
.
payload
.
Load
()
.
([]
byte
);
ok
&&
cached
!=
nil
{
return
cached
}
return
nil
}
func
closeWS
(
conn
*
websocket
.
Conn
,
code
int
,
reason
string
)
{
if
conn
==
nil
{
return
}
msg
:=
websocket
.
FormatCloseMessage
(
code
,
reason
)
_
=
conn
.
WriteControl
(
websocket
.
CloseMessage
,
msg
,
time
.
Now
()
.
Add
(
qpsWSWriteTimeout
))
_
=
conn
.
Close
()
}
// QPSWSHandler handles realtime QPS push via WebSocket.
// GET /api/v1/admin/ops/ws/qps
func
(
h
*
OpsHandler
)
QPSWSHandler
(
c
*
gin
.
Context
)
{
clientIP
:=
requestClientIP
(
c
.
Request
)
if
h
==
nil
||
h
.
opsService
==
nil
{
c
.
JSON
(
http
.
StatusServiceUnavailable
,
gin
.
H
{
"error"
:
"ops service not initialized"
})
return
}
// If realtime monitoring is disabled, prefer a successful WS upgrade followed by a clean close
// with a deterministic close code. This prevents clients from spinning on 404/1006 reconnect loops.
if
!
h
.
opsService
.
IsRealtimeMonitoringEnabled
(
c
.
Request
.
Context
())
{
conn
,
err
:=
upgrader
.
Upgrade
(
c
.
Writer
,
c
.
Request
,
nil
)
if
err
!=
nil
{
c
.
JSON
(
http
.
StatusNotFound
,
gin
.
H
{
"error"
:
"ops realtime monitoring is disabled"
})
return
}
closeWS
(
conn
,
opsWSCloseRealtimeDisabled
,
"realtime_disabled"
)
return
}
cancelQPSWSIdleStop
()
// Lazily start the background refresh loop so unit tests that never hit the
// websocket route don't spawn goroutines that depend on DB/Redis stubs.
qpsWSCache
.
start
(
h
.
opsService
)
// Reserve a global slot before upgrading the connection to keep the limit strict.
if
!
tryAcquireOpsWSTotalSlot
(
opsWSLimits
.
MaxConns
)
{
log
.
Printf
(
"[OpsWS] connection limit reached: %d/%d"
,
wsConnCount
.
Load
(),
opsWSLimits
.
MaxConns
)
c
.
JSON
(
http
.
StatusServiceUnavailable
,
gin
.
H
{
"error"
:
"too many connections"
})
return
}
defer
func
()
{
if
wsConnCount
.
Add
(
-
1
)
==
0
{
scheduleQPSWSIdleStop
()
}
}()
if
opsWSLimits
.
MaxConnsPerIP
>
0
&&
clientIP
!=
""
{
if
!
tryAcquireOpsWSIPSlot
(
clientIP
,
opsWSLimits
.
MaxConnsPerIP
)
{
log
.
Printf
(
"[OpsWS] per-ip connection limit reached: ip=%s limit=%d"
,
clientIP
,
opsWSLimits
.
MaxConnsPerIP
)
c
.
JSON
(
http
.
StatusServiceUnavailable
,
gin
.
H
{
"error"
:
"too many connections"
})
return
}
defer
releaseOpsWSIPSlot
(
clientIP
)
}
conn
,
err
:=
upgrader
.
Upgrade
(
c
.
Writer
,
c
.
Request
,
nil
)
if
err
!=
nil
{
log
.
Printf
(
"[OpsWS] upgrade failed: %v"
,
err
)
return
}
defer
func
()
{
_
=
conn
.
Close
()
}()
handleQPSWebSocket
(
c
.
Request
.
Context
(),
conn
)
}
func
tryAcquireOpsWSTotalSlot
(
limit
int32
)
bool
{
if
limit
<=
0
{
return
true
}
for
{
current
:=
wsConnCount
.
Load
()
if
current
>=
limit
{
return
false
}
if
wsConnCount
.
CompareAndSwap
(
current
,
current
+
1
)
{
return
true
}
}
}
func
tryAcquireOpsWSIPSlot
(
clientIP
string
,
limit
int32
)
bool
{
if
strings
.
TrimSpace
(
clientIP
)
==
""
||
limit
<=
0
{
return
true
}
v
,
_
:=
wsConnCountByIP
.
LoadOrStore
(
clientIP
,
&
atomic
.
Int32
{})
counter
,
ok
:=
v
.
(
*
atomic
.
Int32
)
if
!
ok
{
return
false
}
for
{
current
:=
counter
.
Load
()
if
current
>=
limit
{
return
false
}
if
counter
.
CompareAndSwap
(
current
,
current
+
1
)
{
return
true
}
}
}
func
releaseOpsWSIPSlot
(
clientIP
string
)
{
if
strings
.
TrimSpace
(
clientIP
)
==
""
{
return
}
v
,
ok
:=
wsConnCountByIP
.
Load
(
clientIP
)
if
!
ok
{
return
}
counter
,
ok
:=
v
.
(
*
atomic
.
Int32
)
if
!
ok
{
return
}
next
:=
counter
.
Add
(
-
1
)
if
next
<=
0
{
// Best-effort cleanup; safe even if a new slot was acquired concurrently.
wsConnCountByIP
.
Delete
(
clientIP
)
}
}
func
handleQPSWebSocket
(
parentCtx
context
.
Context
,
conn
*
websocket
.
Conn
)
{
if
conn
==
nil
{
return
}
ctx
,
cancel
:=
context
.
WithCancel
(
parentCtx
)
defer
cancel
()
var
closeOnce
sync
.
Once
closeConn
:=
func
()
{
closeOnce
.
Do
(
func
()
{
_
=
conn
.
Close
()
})
}
closeFrameCh
:=
make
(
chan
[]
byte
,
1
)
var
wg
sync
.
WaitGroup
wg
.
Add
(
1
)
go
func
()
{
defer
wg
.
Done
()
defer
cancel
()
conn
.
SetReadLimit
(
qpsWSMaxReadBytes
)
if
err
:=
conn
.
SetReadDeadline
(
time
.
Now
()
.
Add
(
qpsWSPongWait
));
err
!=
nil
{
log
.
Printf
(
"[OpsWS] set read deadline failed: %v"
,
err
)
return
}
conn
.
SetPongHandler
(
func
(
string
)
error
{
return
conn
.
SetReadDeadline
(
time
.
Now
()
.
Add
(
qpsWSPongWait
))
})
conn
.
SetCloseHandler
(
func
(
code
int
,
text
string
)
error
{
select
{
case
closeFrameCh
<-
websocket
.
FormatCloseMessage
(
code
,
text
)
:
default
:
}
cancel
()
return
nil
})
for
{
_
,
_
,
err
:=
conn
.
ReadMessage
()
if
err
!=
nil
{
if
websocket
.
IsUnexpectedCloseError
(
err
,
websocket
.
CloseNormalClosure
,
websocket
.
CloseGoingAway
,
websocket
.
CloseNoStatusReceived
)
{
log
.
Printf
(
"[OpsWS] read failed: %v"
,
err
)
}
return
}
}
}()
// Push QPS data every 2 seconds (values are globally cached and refreshed at most once per qpsWSRefreshInterval).
pushTicker
:=
time
.
NewTicker
(
qpsWSPushInterval
)
defer
pushTicker
.
Stop
()
// Heartbeat ping every 30 seconds.
pingTicker
:=
time
.
NewTicker
(
qpsWSPingInterval
)
defer
pingTicker
.
Stop
()
writeWithTimeout
:=
func
(
messageType
int
,
data
[]
byte
)
error
{
if
err
:=
conn
.
SetWriteDeadline
(
time
.
Now
()
.
Add
(
qpsWSWriteTimeout
));
err
!=
nil
{
return
err
}
return
conn
.
WriteMessage
(
messageType
,
data
)
}
sendClose
:=
func
(
closeFrame
[]
byte
)
{
if
closeFrame
==
nil
{
closeFrame
=
websocket
.
FormatCloseMessage
(
websocket
.
CloseNormalClosure
,
""
)
}
_
=
writeWithTimeout
(
websocket
.
CloseMessage
,
closeFrame
)
}
for
{
select
{
case
<-
pushTicker
.
C
:
msg
:=
qpsWSCache
.
getPayload
()
if
msg
==
nil
{
continue
}
if
err
:=
writeWithTimeout
(
websocket
.
TextMessage
,
msg
);
err
!=
nil
{
log
.
Printf
(
"[OpsWS] write failed: %v"
,
err
)
cancel
()
closeConn
()
wg
.
Wait
()
return
}
case
<-
pingTicker
.
C
:
if
err
:=
writeWithTimeout
(
websocket
.
PingMessage
,
nil
);
err
!=
nil
{
log
.
Printf
(
"[OpsWS] ping failed: %v"
,
err
)
cancel
()
closeConn
()
wg
.
Wait
()
return
}
case
closeFrame
:=
<-
closeFrameCh
:
sendClose
(
closeFrame
)
closeConn
()
wg
.
Wait
()
return
case
<-
ctx
.
Done
()
:
var
closeFrame
[]
byte
select
{
case
closeFrame
=
<-
closeFrameCh
:
default
:
}
sendClose
(
closeFrame
)
closeConn
()
wg
.
Wait
()
return
}
}
}
func
isAllowedOpsWSOrigin
(
r
*
http
.
Request
)
bool
{
if
r
==
nil
{
return
false
}
origin
:=
strings
.
TrimSpace
(
r
.
Header
.
Get
(
"Origin"
))
if
origin
==
""
{
switch
strings
.
ToLower
(
strings
.
TrimSpace
(
opsWSProxyConfig
.
OriginPolicy
))
{
case
OriginPolicyStrict
:
return
false
case
OriginPolicyPermissive
,
""
:
return
true
default
:
return
true
}
}
parsed
,
err
:=
url
.
Parse
(
origin
)
if
err
!=
nil
||
parsed
.
Hostname
()
==
""
{
return
false
}
originHost
:=
strings
.
ToLower
(
parsed
.
Hostname
())
trustProxyHeaders
:=
shouldTrustOpsWSProxyHeaders
(
r
)
reqHost
:=
hostWithoutPort
(
r
.
Host
)
if
trustProxyHeaders
{
xfHost
:=
strings
.
TrimSpace
(
r
.
Header
.
Get
(
"X-Forwarded-Host"
))
if
xfHost
!=
""
{
xfHost
=
strings
.
TrimSpace
(
strings
.
Split
(
xfHost
,
","
)[
0
])
if
xfHost
!=
""
{
reqHost
=
hostWithoutPort
(
xfHost
)
}
}
}
reqHost
=
strings
.
ToLower
(
reqHost
)
if
reqHost
==
""
{
return
false
}
return
originHost
==
reqHost
}
func
shouldTrustOpsWSProxyHeaders
(
r
*
http
.
Request
)
bool
{
if
r
==
nil
{
return
false
}
if
!
opsWSProxyConfig
.
TrustProxy
{
return
false
}
peerIP
,
ok
:=
requestPeerIP
(
r
)
if
!
ok
{
return
false
}
return
isAddrInTrustedProxies
(
peerIP
,
opsWSProxyConfig
.
TrustedProxies
)
}
func
requestPeerIP
(
r
*
http
.
Request
)
(
netip
.
Addr
,
bool
)
{
if
r
==
nil
{
return
netip
.
Addr
{},
false
}
host
,
_
,
err
:=
net
.
SplitHostPort
(
strings
.
TrimSpace
(
r
.
RemoteAddr
))
if
err
!=
nil
{
host
=
strings
.
TrimSpace
(
r
.
RemoteAddr
)
}
host
=
strings
.
TrimPrefix
(
host
,
"["
)
host
=
strings
.
TrimSuffix
(
host
,
"]"
)
if
host
==
""
{
return
netip
.
Addr
{},
false
}
addr
,
err
:=
netip
.
ParseAddr
(
host
)
if
err
!=
nil
{
return
netip
.
Addr
{},
false
}
return
addr
.
Unmap
(),
true
}
func
requestClientIP
(
r
*
http
.
Request
)
string
{
if
r
==
nil
{
return
""
}
trustProxyHeaders
:=
shouldTrustOpsWSProxyHeaders
(
r
)
if
trustProxyHeaders
{
xff
:=
strings
.
TrimSpace
(
r
.
Header
.
Get
(
"X-Forwarded-For"
))
if
xff
!=
""
{
// Use the left-most entry (original client). If multiple proxies add values, they are comma-separated.
xff
=
strings
.
TrimSpace
(
strings
.
Split
(
xff
,
","
)[
0
])
xff
=
strings
.
TrimPrefix
(
xff
,
"["
)
xff
=
strings
.
TrimSuffix
(
xff
,
"]"
)
if
addr
,
err
:=
netip
.
ParseAddr
(
xff
);
err
==
nil
&&
addr
.
IsValid
()
{
return
addr
.
Unmap
()
.
String
()
}
}
}
if
peer
,
ok
:=
requestPeerIP
(
r
);
ok
&&
peer
.
IsValid
()
{
return
peer
.
String
()
}
return
""
}
func
isAddrInTrustedProxies
(
addr
netip
.
Addr
,
trusted
[]
netip
.
Prefix
)
bool
{
if
!
addr
.
IsValid
()
{
return
false
}
for
_
,
p
:=
range
trusted
{
if
p
.
Contains
(
addr
)
{
return
true
}
}
return
false
}
func
loadOpsWSProxyConfigFromEnv
()
OpsWSProxyConfig
{
cfg
:=
OpsWSProxyConfig
{
TrustProxy
:
true
,
TrustedProxies
:
defaultTrustedProxies
(),
OriginPolicy
:
OriginPolicyPermissive
,
}
if
v
:=
strings
.
TrimSpace
(
os
.
Getenv
(
envOpsWSTrustProxy
));
v
!=
""
{
if
parsed
,
err
:=
strconv
.
ParseBool
(
v
);
err
==
nil
{
cfg
.
TrustProxy
=
parsed
}
else
{
log
.
Printf
(
"[OpsWS] invalid %s=%q (expected bool); using default=%v"
,
envOpsWSTrustProxy
,
v
,
cfg
.
TrustProxy
)
}
}
if
raw
:=
strings
.
TrimSpace
(
os
.
Getenv
(
envOpsWSTrustedProxies
));
raw
!=
""
{
prefixes
,
invalid
:=
parseTrustedProxyList
(
raw
)
if
len
(
invalid
)
>
0
{
log
.
Printf
(
"[OpsWS] invalid %s entries ignored: %s"
,
envOpsWSTrustedProxies
,
strings
.
Join
(
invalid
,
", "
))
}
cfg
.
TrustedProxies
=
prefixes
}
if
v
:=
strings
.
TrimSpace
(
os
.
Getenv
(
envOpsWSOriginPolicy
));
v
!=
""
{
normalized
:=
strings
.
ToLower
(
v
)
switch
normalized
{
case
OriginPolicyStrict
,
OriginPolicyPermissive
:
cfg
.
OriginPolicy
=
normalized
default
:
log
.
Printf
(
"[OpsWS] invalid %s=%q (expected %q or %q); using default=%q"
,
envOpsWSOriginPolicy
,
v
,
OriginPolicyStrict
,
OriginPolicyPermissive
,
cfg
.
OriginPolicy
)
}
}
return
cfg
}
func
loadOpsWSRuntimeLimitsFromEnv
()
opsWSRuntimeLimits
{
cfg
:=
opsWSRuntimeLimits
{
MaxConns
:
defaultMaxWSConns
,
MaxConnsPerIP
:
defaultMaxWSConnsPerIP
,
}
if
v
:=
strings
.
TrimSpace
(
os
.
Getenv
(
envOpsWSMaxConns
));
v
!=
""
{
if
parsed
,
err
:=
strconv
.
Atoi
(
v
);
err
==
nil
&&
parsed
>
0
{
cfg
.
MaxConns
=
int32
(
parsed
)
}
else
{
log
.
Printf
(
"[OpsWS] invalid %s=%q (expected int>0); using default=%d"
,
envOpsWSMaxConns
,
v
,
cfg
.
MaxConns
)
}
}
if
v
:=
strings
.
TrimSpace
(
os
.
Getenv
(
envOpsWSMaxConnsPerIP
));
v
!=
""
{
if
parsed
,
err
:=
strconv
.
Atoi
(
v
);
err
==
nil
&&
parsed
>=
0
{
cfg
.
MaxConnsPerIP
=
int32
(
parsed
)
}
else
{
log
.
Printf
(
"[OpsWS] invalid %s=%q (expected int>=0); using default=%d"
,
envOpsWSMaxConnsPerIP
,
v
,
cfg
.
MaxConnsPerIP
)
}
}
return
cfg
}
func
defaultTrustedProxies
()
[]
netip
.
Prefix
{
prefixes
,
_
:=
parseTrustedProxyList
(
"127.0.0.0/8,::1/128"
)
return
prefixes
}
func
parseTrustedProxyList
(
raw
string
)
(
prefixes
[]
netip
.
Prefix
,
invalid
[]
string
)
{
for
_
,
token
:=
range
strings
.
Split
(
raw
,
","
)
{
item
:=
strings
.
TrimSpace
(
token
)
if
item
==
""
{
continue
}
var
(
p
netip
.
Prefix
err
error
)
if
strings
.
Contains
(
item
,
"/"
)
{
p
,
err
=
netip
.
ParsePrefix
(
item
)
}
else
{
var
addr
netip
.
Addr
addr
,
err
=
netip
.
ParseAddr
(
item
)
if
err
==
nil
{
addr
=
addr
.
Unmap
()
bits
:=
128
if
addr
.
Is4
()
{
bits
=
32
}
p
=
netip
.
PrefixFrom
(
addr
,
bits
)
}
}
if
err
!=
nil
||
!
p
.
IsValid
()
{
invalid
=
append
(
invalid
,
item
)
continue
}
prefixes
=
append
(
prefixes
,
p
.
Masked
())
}
return
prefixes
,
invalid
}
func
hostWithoutPort
(
hostport
string
)
string
{
hostport
=
strings
.
TrimSpace
(
hostport
)
if
hostport
==
""
{
return
""
}
if
host
,
_
,
err
:=
net
.
SplitHostPort
(
hostport
);
err
==
nil
{
return
host
}
if
strings
.
HasPrefix
(
hostport
,
"["
)
&&
strings
.
HasSuffix
(
hostport
,
"]"
)
{
return
strings
.
Trim
(
hostport
,
"[]"
)
}
parts
:=
strings
.
Split
(
hostport
,
":"
)
return
parts
[
0
]
}
backend/internal/handler/admin/promo_handler.go
0 → 100644
View file @
b9b4db3d
package
admin
import
(
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// PromoHandler handles admin promo code management
type
PromoHandler
struct
{
promoService
*
service
.
PromoService
}
// NewPromoHandler creates a new admin promo handler
func
NewPromoHandler
(
promoService
*
service
.
PromoService
)
*
PromoHandler
{
return
&
PromoHandler
{
promoService
:
promoService
,
}
}
// CreatePromoCodeRequest represents create promo code request
type
CreatePromoCodeRequest
struct
{
Code
string
`json:"code"`
// 可选,为空则自动生成
BonusAmount
float64
`json:"bonus_amount" binding:"required,min=0"`
// 赠送余额
MaxUses
int
`json:"max_uses" binding:"min=0"`
// 最大使用次数,0=无限
ExpiresAt
*
int64
`json:"expires_at"`
// 过期时间戳(秒)
Notes
string
`json:"notes"`
// 备注
}
// UpdatePromoCodeRequest represents update promo code request
type
UpdatePromoCodeRequest
struct
{
Code
*
string
`json:"code"`
BonusAmount
*
float64
`json:"bonus_amount" binding:"omitempty,min=0"`
MaxUses
*
int
`json:"max_uses" binding:"omitempty,min=0"`
Status
*
string
`json:"status" binding:"omitempty,oneof=active disabled"`
ExpiresAt
*
int64
`json:"expires_at"`
Notes
*
string
`json:"notes"`
}
// List handles listing all promo codes with pagination
// GET /api/v1/admin/promo-codes
func
(
h
*
PromoHandler
)
List
(
c
*
gin
.
Context
)
{
page
,
pageSize
:=
response
.
ParsePagination
(
c
)
status
:=
c
.
Query
(
"status"
)
search
:=
strings
.
TrimSpace
(
c
.
Query
(
"search"
))
if
len
(
search
)
>
100
{
search
=
search
[
:
100
]
}
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
,
}
codes
,
paginationResult
,
err
:=
h
.
promoService
.
List
(
c
.
Request
.
Context
(),
params
,
status
,
search
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
out
:=
make
([]
dto
.
PromoCode
,
0
,
len
(
codes
))
for
i
:=
range
codes
{
out
=
append
(
out
,
*
dto
.
PromoCodeFromService
(
&
codes
[
i
]))
}
response
.
Paginated
(
c
,
out
,
paginationResult
.
Total
,
page
,
pageSize
)
}
// GetByID handles getting a promo code by ID
// GET /api/v1/admin/promo-codes/:id
func
(
h
*
PromoHandler
)
GetByID
(
c
*
gin
.
Context
)
{
codeID
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid promo code ID"
)
return
}
code
,
err
:=
h
.
promoService
.
GetByID
(
c
.
Request
.
Context
(),
codeID
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
dto
.
PromoCodeFromService
(
code
))
}
// Create handles creating a new promo code
// POST /api/v1/admin/promo-codes
func
(
h
*
PromoHandler
)
Create
(
c
*
gin
.
Context
)
{
var
req
CreatePromoCodeRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
input
:=
&
service
.
CreatePromoCodeInput
{
Code
:
req
.
Code
,
BonusAmount
:
req
.
BonusAmount
,
MaxUses
:
req
.
MaxUses
,
Notes
:
req
.
Notes
,
}
if
req
.
ExpiresAt
!=
nil
{
t
:=
time
.
Unix
(
*
req
.
ExpiresAt
,
0
)
input
.
ExpiresAt
=
&
t
}
code
,
err
:=
h
.
promoService
.
Create
(
c
.
Request
.
Context
(),
input
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
dto
.
PromoCodeFromService
(
code
))
}
// Update handles updating a promo code
// PUT /api/v1/admin/promo-codes/:id
func
(
h
*
PromoHandler
)
Update
(
c
*
gin
.
Context
)
{
codeID
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid promo code ID"
)
return
}
var
req
UpdatePromoCodeRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
input
:=
&
service
.
UpdatePromoCodeInput
{
Code
:
req
.
Code
,
BonusAmount
:
req
.
BonusAmount
,
MaxUses
:
req
.
MaxUses
,
Status
:
req
.
Status
,
Notes
:
req
.
Notes
,
}
if
req
.
ExpiresAt
!=
nil
{
if
*
req
.
ExpiresAt
==
0
{
// 0 表示清除过期时间
input
.
ExpiresAt
=
nil
}
else
{
t
:=
time
.
Unix
(
*
req
.
ExpiresAt
,
0
)
input
.
ExpiresAt
=
&
t
}
}
code
,
err
:=
h
.
promoService
.
Update
(
c
.
Request
.
Context
(),
codeID
,
input
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
dto
.
PromoCodeFromService
(
code
))
}
// Delete handles deleting a promo code
// DELETE /api/v1/admin/promo-codes/:id
func
(
h
*
PromoHandler
)
Delete
(
c
*
gin
.
Context
)
{
codeID
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid promo code ID"
)
return
}
err
=
h
.
promoService
.
Delete
(
c
.
Request
.
Context
(),
codeID
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
gin
.
H
{
"message"
:
"Promo code deleted successfully"
})
}
// GetUsages handles getting usage records for a promo code
// GET /api/v1/admin/promo-codes/:id/usages
func
(
h
*
PromoHandler
)
GetUsages
(
c
*
gin
.
Context
)
{
codeID
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid promo code ID"
)
return
}
page
,
pageSize
:=
response
.
ParsePagination
(
c
)
params
:=
pagination
.
PaginationParams
{
Page
:
page
,
PageSize
:
pageSize
,
}
usages
,
paginationResult
,
err
:=
h
.
promoService
.
ListUsages
(
c
.
Request
.
Context
(),
codeID
,
params
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
out
:=
make
([]
dto
.
PromoCodeUsage
,
0
,
len
(
usages
))
for
i
:=
range
usages
{
out
=
append
(
out
,
*
dto
.
PromoCodeUsageFromService
(
&
usages
[
i
]))
}
response
.
Paginated
(
c
,
out
,
paginationResult
.
Total
,
page
,
pageSize
)
}
backend/internal/handler/admin/proxy_handler.go
View file @
b9b4db3d
...
...
@@ -196,6 +196,28 @@ func (h *ProxyHandler) Delete(c *gin.Context) {
response
.
Success
(
c
,
gin
.
H
{
"message"
:
"Proxy deleted successfully"
})
}
// BatchDelete handles batch deleting proxies
// POST /api/v1/admin/proxies/batch-delete
func
(
h
*
ProxyHandler
)
BatchDelete
(
c
*
gin
.
Context
)
{
type
BatchDeleteRequest
struct
{
IDs
[]
int64
`json:"ids" binding:"required,min=1"`
}
var
req
BatchDeleteRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
result
,
err
:=
h
.
adminService
.
BatchDeleteProxies
(
c
.
Request
.
Context
(),
req
.
IDs
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
result
)
}
// Test handles testing proxy connectivity
// POST /api/v1/admin/proxies/:id/test
func
(
h
*
ProxyHandler
)
Test
(
c
*
gin
.
Context
)
{
...
...
@@ -243,19 +265,17 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
return
}
page
,
pageSize
:=
response
.
ParsePagination
(
c
)
accounts
,
total
,
err
:=
h
.
adminService
.
GetProxyAccounts
(
c
.
Request
.
Context
(),
proxyID
,
page
,
pageSize
)
accounts
,
err
:=
h
.
adminService
.
GetProxyAccounts
(
c
.
Request
.
Context
(),
proxyID
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
out
:=
make
([]
dto
.
Account
,
0
,
len
(
accounts
))
out
:=
make
([]
dto
.
Proxy
Account
Summary
,
0
,
len
(
accounts
))
for
i
:=
range
accounts
{
out
=
append
(
out
,
*
dto
.
AccountFromService
(
&
accounts
[
i
]))
out
=
append
(
out
,
*
dto
.
Proxy
Account
Summary
FromService
(
&
accounts
[
i
]))
}
response
.
Paginated
(
c
,
out
,
total
,
page
,
pageSize
)
response
.
Success
(
c
,
out
)
}
// BatchCreateProxyItem represents a single proxy in batch create request
...
...
backend/internal/handler/admin/setting_handler.go
View file @
b9b4db3d
...
...
@@ -19,14 +19,16 @@ type SettingHandler struct {
settingService
*
service
.
SettingService
emailService
*
service
.
EmailService
turnstileService
*
service
.
TurnstileService
opsService
*
service
.
OpsService
}
// NewSettingHandler 创建系统设置处理器
func
NewSettingHandler
(
settingService
*
service
.
SettingService
,
emailService
*
service
.
EmailService
,
turnstileService
*
service
.
TurnstileService
)
*
SettingHandler
{
func
NewSettingHandler
(
settingService
*
service
.
SettingService
,
emailService
*
service
.
EmailService
,
turnstileService
*
service
.
TurnstileService
,
opsService
*
service
.
OpsService
)
*
SettingHandler
{
return
&
SettingHandler
{
settingService
:
settingService
,
emailService
:
emailService
,
turnstileService
:
turnstileService
,
opsService
:
opsService
,
}
}
...
...
@@ -39,6 +41,9 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
return
}
// Check if ops monitoring is enabled (respects config.ops.enabled)
opsEnabled
:=
h
.
opsService
!=
nil
&&
h
.
opsService
.
IsMonitoringEnabled
(
c
.
Request
.
Context
())
response
.
Success
(
c
,
dto
.
SystemSettings
{
RegistrationEnabled
:
settings
.
RegistrationEnabled
,
EmailVerifyEnabled
:
settings
.
EmailVerifyEnabled
,
...
...
@@ -62,6 +67,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
APIBaseURL
:
settings
.
APIBaseURL
,
ContactInfo
:
settings
.
ContactInfo
,
DocURL
:
settings
.
DocURL
,
HomeContent
:
settings
.
HomeContent
,
DefaultConcurrency
:
settings
.
DefaultConcurrency
,
DefaultBalance
:
settings
.
DefaultBalance
,
EnableModelFallback
:
settings
.
EnableModelFallback
,
...
...
@@ -71,6 +77,10 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
FallbackModelAntigravity
:
settings
.
FallbackModelAntigravity
,
EnableIdentityPatch
:
settings
.
EnableIdentityPatch
,
IdentityPatchPrompt
:
settings
.
IdentityPatchPrompt
,
OpsMonitoringEnabled
:
opsEnabled
&&
settings
.
OpsMonitoringEnabled
,
OpsRealtimeMonitoringEnabled
:
settings
.
OpsRealtimeMonitoringEnabled
,
OpsQueryModeDefault
:
settings
.
OpsQueryModeDefault
,
OpsMetricsIntervalSeconds
:
settings
.
OpsMetricsIntervalSeconds
,
})
}
...
...
@@ -94,7 +104,7 @@ type UpdateSettingsRequest struct {
TurnstileSiteKey
string
`json:"turnstile_site_key"`
TurnstileSecretKey
string
`json:"turnstile_secret_key"`
// LinuxDo Connect OAuth 登录
(终端用户 SSO)
// LinuxDo Connect OAuth 登录
LinuxDoConnectEnabled
bool
`json:"linuxdo_connect_enabled"`
LinuxDoConnectClientID
string
`json:"linuxdo_connect_client_id"`
LinuxDoConnectClientSecret
string
`json:"linuxdo_connect_client_secret"`
...
...
@@ -107,6 +117,7 @@ type UpdateSettingsRequest struct {
APIBaseURL
string
`json:"api_base_url"`
ContactInfo
string
`json:"contact_info"`
DocURL
string
`json:"doc_url"`
HomeContent
string
`json:"home_content"`
// 默认配置
DefaultConcurrency
int
`json:"default_concurrency"`
...
...
@@ -122,6 +133,12 @@ type UpdateSettingsRequest struct {
// Identity patch configuration (Claude -> Gemini)
EnableIdentityPatch
bool
`json:"enable_identity_patch"`
IdentityPatchPrompt
string
`json:"identity_patch_prompt"`
// Ops monitoring (vNext)
OpsMonitoringEnabled
*
bool
`json:"ops_monitoring_enabled"`
OpsRealtimeMonitoringEnabled
*
bool
`json:"ops_realtime_monitoring_enabled"`
OpsQueryModeDefault
*
string
`json:"ops_query_mode_default"`
OpsMetricsIntervalSeconds
*
int
`json:"ops_metrics_interval_seconds"`
}
// UpdateSettings 更新系统设置
...
...
@@ -206,6 +223,18 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
// Ops metrics collector interval validation (seconds).
if
req
.
OpsMetricsIntervalSeconds
!=
nil
{
v
:=
*
req
.
OpsMetricsIntervalSeconds
if
v
<
60
{
v
=
60
}
if
v
>
3600
{
v
=
3600
}
req
.
OpsMetricsIntervalSeconds
=
&
v
}
settings
:=
&
service
.
SystemSettings
{
RegistrationEnabled
:
req
.
RegistrationEnabled
,
EmailVerifyEnabled
:
req
.
EmailVerifyEnabled
,
...
...
@@ -229,6 +258,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
APIBaseURL
:
req
.
APIBaseURL
,
ContactInfo
:
req
.
ContactInfo
,
DocURL
:
req
.
DocURL
,
HomeContent
:
req
.
HomeContent
,
DefaultConcurrency
:
req
.
DefaultConcurrency
,
DefaultBalance
:
req
.
DefaultBalance
,
EnableModelFallback
:
req
.
EnableModelFallback
,
...
...
@@ -238,6 +268,30 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
FallbackModelAntigravity
:
req
.
FallbackModelAntigravity
,
EnableIdentityPatch
:
req
.
EnableIdentityPatch
,
IdentityPatchPrompt
:
req
.
IdentityPatchPrompt
,
OpsMonitoringEnabled
:
func
()
bool
{
if
req
.
OpsMonitoringEnabled
!=
nil
{
return
*
req
.
OpsMonitoringEnabled
}
return
previousSettings
.
OpsMonitoringEnabled
}(),
OpsRealtimeMonitoringEnabled
:
func
()
bool
{
if
req
.
OpsRealtimeMonitoringEnabled
!=
nil
{
return
*
req
.
OpsRealtimeMonitoringEnabled
}
return
previousSettings
.
OpsRealtimeMonitoringEnabled
}(),
OpsQueryModeDefault
:
func
()
string
{
if
req
.
OpsQueryModeDefault
!=
nil
{
return
*
req
.
OpsQueryModeDefault
}
return
previousSettings
.
OpsQueryModeDefault
}(),
OpsMetricsIntervalSeconds
:
func
()
int
{
if
req
.
OpsMetricsIntervalSeconds
!=
nil
{
return
*
req
.
OpsMetricsIntervalSeconds
}
return
previousSettings
.
OpsMetricsIntervalSeconds
}(),
}
if
err
:=
h
.
settingService
.
UpdateSettings
(
c
.
Request
.
Context
(),
settings
);
err
!=
nil
{
...
...
@@ -277,6 +331,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
APIBaseURL
:
updatedSettings
.
APIBaseURL
,
ContactInfo
:
updatedSettings
.
ContactInfo
,
DocURL
:
updatedSettings
.
DocURL
,
HomeContent
:
updatedSettings
.
HomeContent
,
DefaultConcurrency
:
updatedSettings
.
DefaultConcurrency
,
DefaultBalance
:
updatedSettings
.
DefaultBalance
,
EnableModelFallback
:
updatedSettings
.
EnableModelFallback
,
...
...
@@ -286,6 +341,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
FallbackModelAntigravity
:
updatedSettings
.
FallbackModelAntigravity
,
EnableIdentityPatch
:
updatedSettings
.
EnableIdentityPatch
,
IdentityPatchPrompt
:
updatedSettings
.
IdentityPatchPrompt
,
OpsMonitoringEnabled
:
updatedSettings
.
OpsMonitoringEnabled
,
OpsRealtimeMonitoringEnabled
:
updatedSettings
.
OpsRealtimeMonitoringEnabled
,
OpsQueryModeDefault
:
updatedSettings
.
OpsQueryModeDefault
,
OpsMetricsIntervalSeconds
:
updatedSettings
.
OpsMetricsIntervalSeconds
,
})
}
...
...
@@ -377,6 +436,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if
before
.
DocURL
!=
after
.
DocURL
{
changed
=
append
(
changed
,
"doc_url"
)
}
if
before
.
HomeContent
!=
after
.
HomeContent
{
changed
=
append
(
changed
,
"home_content"
)
}
if
before
.
DefaultConcurrency
!=
after
.
DefaultConcurrency
{
changed
=
append
(
changed
,
"default_concurrency"
)
}
...
...
@@ -404,6 +466,18 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if
before
.
IdentityPatchPrompt
!=
after
.
IdentityPatchPrompt
{
changed
=
append
(
changed
,
"identity_patch_prompt"
)
}
if
before
.
OpsMonitoringEnabled
!=
after
.
OpsMonitoringEnabled
{
changed
=
append
(
changed
,
"ops_monitoring_enabled"
)
}
if
before
.
OpsRealtimeMonitoringEnabled
!=
after
.
OpsRealtimeMonitoringEnabled
{
changed
=
append
(
changed
,
"ops_realtime_monitoring_enabled"
)
}
if
before
.
OpsQueryModeDefault
!=
after
.
OpsQueryModeDefault
{
changed
=
append
(
changed
,
"ops_query_mode_default"
)
}
if
before
.
OpsMetricsIntervalSeconds
!=
after
.
OpsMetricsIntervalSeconds
{
changed
=
append
(
changed
,
"ops_metrics_interval_seconds"
)
}
return
changed
}
...
...
@@ -580,3 +654,68 @@ func (h *SettingHandler) DeleteAdminAPIKey(c *gin.Context) {
response
.
Success
(
c
,
gin
.
H
{
"message"
:
"Admin API key deleted"
})
}
// GetStreamTimeoutSettings 获取流超时处理配置
// GET /api/v1/admin/settings/stream-timeout
func
(
h
*
SettingHandler
)
GetStreamTimeoutSettings
(
c
*
gin
.
Context
)
{
settings
,
err
:=
h
.
settingService
.
GetStreamTimeoutSettings
(
c
.
Request
.
Context
())
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
dto
.
StreamTimeoutSettings
{
Enabled
:
settings
.
Enabled
,
Action
:
settings
.
Action
,
TempUnschedMinutes
:
settings
.
TempUnschedMinutes
,
ThresholdCount
:
settings
.
ThresholdCount
,
ThresholdWindowMinutes
:
settings
.
ThresholdWindowMinutes
,
})
}
// UpdateStreamTimeoutSettingsRequest 更新流超时配置请求
type
UpdateStreamTimeoutSettingsRequest
struct
{
Enabled
bool
`json:"enabled"`
Action
string
`json:"action"`
TempUnschedMinutes
int
`json:"temp_unsched_minutes"`
ThresholdCount
int
`json:"threshold_count"`
ThresholdWindowMinutes
int
`json:"threshold_window_minutes"`
}
// UpdateStreamTimeoutSettings 更新流超时处理配置
// PUT /api/v1/admin/settings/stream-timeout
func
(
h
*
SettingHandler
)
UpdateStreamTimeoutSettings
(
c
*
gin
.
Context
)
{
var
req
UpdateStreamTimeoutSettingsRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
settings
:=
&
service
.
StreamTimeoutSettings
{
Enabled
:
req
.
Enabled
,
Action
:
req
.
Action
,
TempUnschedMinutes
:
req
.
TempUnschedMinutes
,
ThresholdCount
:
req
.
ThresholdCount
,
ThresholdWindowMinutes
:
req
.
ThresholdWindowMinutes
,
}
if
err
:=
h
.
settingService
.
SetStreamTimeoutSettings
(
c
.
Request
.
Context
(),
settings
);
err
!=
nil
{
response
.
BadRequest
(
c
,
err
.
Error
())
return
}
// 重新获取设置返回
updatedSettings
,
err
:=
h
.
settingService
.
GetStreamTimeoutSettings
(
c
.
Request
.
Context
())
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
dto
.
StreamTimeoutSettings
{
Enabled
:
updatedSettings
.
Enabled
,
Action
:
updatedSettings
.
Action
,
TempUnschedMinutes
:
updatedSettings
.
TempUnschedMinutes
,
ThresholdCount
:
updatedSettings
.
ThresholdCount
,
ThresholdWindowMinutes
:
updatedSettings
.
ThresholdWindowMinutes
,
})
}
backend/internal/handler/api_key_handler.go
View file @
b9b4db3d
...
...
@@ -30,6 +30,8 @@ type CreateAPIKeyRequest struct {
Name
string
`json:"name" binding:"required"`
GroupID
*
int64
`json:"group_id"`
// nullable
CustomKey
*
string
`json:"custom_key"`
// 可选的自定义key
IPWhitelist
[]
string
`json:"ip_whitelist"`
// IP 白名单
IPBlacklist
[]
string
`json:"ip_blacklist"`
// IP 黑名单
}
// UpdateAPIKeyRequest represents the update API key request payload
...
...
@@ -37,6 +39,8 @@ type UpdateAPIKeyRequest struct {
Name
string
`json:"name"`
GroupID
*
int64
`json:"group_id"`
Status
string
`json:"status" binding:"omitempty,oneof=active inactive"`
IPWhitelist
[]
string
`json:"ip_whitelist"`
// IP 白名单
IPBlacklist
[]
string
`json:"ip_blacklist"`
// IP 黑名单
}
// List handles listing user's API keys with pagination
...
...
@@ -113,6 +117,8 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
Name
:
req
.
Name
,
GroupID
:
req
.
GroupID
,
CustomKey
:
req
.
CustomKey
,
IPWhitelist
:
req
.
IPWhitelist
,
IPBlacklist
:
req
.
IPBlacklist
,
}
key
,
err
:=
h
.
apiKeyService
.
Create
(
c
.
Request
.
Context
(),
subject
.
UserID
,
svcReq
)
if
err
!=
nil
{
...
...
@@ -144,7 +150,10 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
return
}
svcReq
:=
service
.
UpdateAPIKeyRequest
{}
svcReq
:=
service
.
UpdateAPIKeyRequest
{
IPWhitelist
:
req
.
IPWhitelist
,
IPBlacklist
:
req
.
IPBlacklist
,
}
if
req
.
Name
!=
""
{
svcReq
.
Name
=
&
req
.
Name
}
...
...
backend/internal/handler/auth_handler.go
View file @
b9b4db3d
...
...
@@ -3,6 +3,7 @@ package handler
import
(
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
...
...
@@ -16,15 +17,17 @@ type AuthHandler struct {
authService
*
service
.
AuthService
userService
*
service
.
UserService
settingSvc
*
service
.
SettingService
promoService
*
service
.
PromoService
}
// NewAuthHandler creates a new AuthHandler
func
NewAuthHandler
(
cfg
*
config
.
Config
,
authService
*
service
.
AuthService
,
userService
*
service
.
UserService
,
settingService
*
service
.
SettingService
)
*
AuthHandler
{
func
NewAuthHandler
(
cfg
*
config
.
Config
,
authService
*
service
.
AuthService
,
userService
*
service
.
UserService
,
settingService
*
service
.
SettingService
,
promoService
*
service
.
PromoService
)
*
AuthHandler
{
return
&
AuthHandler
{
cfg
:
cfg
,
authService
:
authService
,
userService
:
userService
,
settingSvc
:
settingService
,
promoService
:
promoService
,
}
}
...
...
@@ -34,6 +37,7 @@ type RegisterRequest struct {
Password
string
`json:"password" binding:"required,min=6"`
VerifyCode
string
`json:"verify_code"`
TurnstileToken
string
`json:"turnstile_token"`
PromoCode
string
`json:"promo_code"`
// 注册优惠码
}
// SendVerifyCodeRequest 发送验证码请求
...
...
@@ -73,13 +77,13 @@ func (h *AuthHandler) Register(c *gin.Context) {
// Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过)
if
req
.
VerifyCode
==
""
{
if
err
:=
h
.
authService
.
VerifyTurnstile
(
c
.
Request
.
Context
(),
req
.
TurnstileToken
,
c
.
ClientIP
());
err
!=
nil
{
if
err
:=
h
.
authService
.
VerifyTurnstile
(
c
.
Request
.
Context
(),
req
.
TurnstileToken
,
ip
.
Get
ClientIP
(
c
));
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
}
token
,
user
,
err
:=
h
.
authService
.
RegisterWithVerification
(
c
.
Request
.
Context
(),
req
.
Email
,
req
.
Password
,
req
.
VerifyCode
)
token
,
user
,
err
:=
h
.
authService
.
RegisterWithVerification
(
c
.
Request
.
Context
(),
req
.
Email
,
req
.
Password
,
req
.
VerifyCode
,
req
.
PromoCode
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
...
...
@@ -102,7 +106,7 @@ func (h *AuthHandler) SendVerifyCode(c *gin.Context) {
}
// Turnstile 验证
if
err
:=
h
.
authService
.
VerifyTurnstile
(
c
.
Request
.
Context
(),
req
.
TurnstileToken
,
c
.
ClientIP
());
err
!=
nil
{
if
err
:=
h
.
authService
.
VerifyTurnstile
(
c
.
Request
.
Context
(),
req
.
TurnstileToken
,
ip
.
Get
ClientIP
(
c
));
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
...
...
@@ -129,7 +133,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
}
// Turnstile 验证
if
err
:=
h
.
authService
.
VerifyTurnstile
(
c
.
Request
.
Context
(),
req
.
TurnstileToken
,
c
.
ClientIP
());
err
!=
nil
{
if
err
:=
h
.
authService
.
VerifyTurnstile
(
c
.
Request
.
Context
(),
req
.
TurnstileToken
,
ip
.
Get
ClientIP
(
c
));
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
...
...
@@ -174,3 +178,63 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
response
.
Success
(
c
,
UserResponse
{
User
:
dto
.
UserFromService
(
user
),
RunMode
:
runMode
})
}
// ValidatePromoCodeRequest 验证优惠码请求
type
ValidatePromoCodeRequest
struct
{
Code
string
`json:"code" binding:"required"`
}
// ValidatePromoCodeResponse 验证优惠码响应
type
ValidatePromoCodeResponse
struct
{
Valid
bool
`json:"valid"`
BonusAmount
float64
`json:"bonus_amount,omitempty"`
ErrorCode
string
`json:"error_code,omitempty"`
Message
string
`json:"message,omitempty"`
}
// ValidatePromoCode 验证优惠码(公开接口,注册前调用)
// POST /api/v1/auth/validate-promo-code
func
(
h
*
AuthHandler
)
ValidatePromoCode
(
c
*
gin
.
Context
)
{
var
req
ValidatePromoCodeRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
promoCode
,
err
:=
h
.
promoService
.
ValidatePromoCode
(
c
.
Request
.
Context
(),
req
.
Code
)
if
err
!=
nil
{
// 根据错误类型返回对应的错误码
errorCode
:=
"PROMO_CODE_INVALID"
switch
err
{
case
service
.
ErrPromoCodeNotFound
:
errorCode
=
"PROMO_CODE_NOT_FOUND"
case
service
.
ErrPromoCodeExpired
:
errorCode
=
"PROMO_CODE_EXPIRED"
case
service
.
ErrPromoCodeDisabled
:
errorCode
=
"PROMO_CODE_DISABLED"
case
service
.
ErrPromoCodeMaxUsed
:
errorCode
=
"PROMO_CODE_MAX_USED"
case
service
.
ErrPromoCodeAlreadyUsed
:
errorCode
=
"PROMO_CODE_ALREADY_USED"
}
response
.
Success
(
c
,
ValidatePromoCodeResponse
{
Valid
:
false
,
ErrorCode
:
errorCode
,
})
return
}
if
promoCode
==
nil
{
response
.
Success
(
c
,
ValidatePromoCodeResponse
{
Valid
:
false
,
ErrorCode
:
"PROMO_CODE_INVALID"
,
})
return
}
response
.
Success
(
c
,
ValidatePromoCodeResponse
{
Valid
:
true
,
BonusAmount
:
promoCode
.
BonusAmount
,
})
}
backend/internal/handler/dto/mappers.go
View file @
b9b4db3d
...
...
@@ -59,6 +59,8 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
Name
:
k
.
Name
,
GroupID
:
k
.
GroupID
,
Status
:
k
.
Status
,
IPWhitelist
:
k
.
IPWhitelist
,
IPBlacklist
:
k
.
IPBlacklist
,
CreatedAt
:
k
.
CreatedAt
,
UpdatedAt
:
k
.
UpdatedAt
,
User
:
UserFromServiceShallow
(
k
.
User
),
...
...
@@ -87,6 +89,8 @@ func GroupFromServiceShallow(g *service.Group) *Group {
ImagePrice4K
:
g
.
ImagePrice4K
,
ClaudeCodeOnly
:
g
.
ClaudeCodeOnly
,
FallbackGroupID
:
g
.
FallbackGroupID
,
ModelRouting
:
g
.
ModelRouting
,
ModelRoutingEnabled
:
g
.
ModelRoutingEnabled
,
CreatedAt
:
g
.
CreatedAt
,
UpdatedAt
:
g
.
UpdatedAt
,
AccountCount
:
g
.
AccountCount
,
...
...
@@ -112,7 +116,7 @@ func AccountFromServiceShallow(a *service.Account) *Account {
if
a
==
nil
{
return
nil
}
return
&
Account
{
out
:=
&
Account
{
ID
:
a
.
ID
,
Name
:
a
.
Name
,
Notes
:
a
.
Notes
,
...
...
@@ -123,6 +127,7 @@ func AccountFromServiceShallow(a *service.Account) *Account {
ProxyID
:
a
.
ProxyID
,
Concurrency
:
a
.
Concurrency
,
Priority
:
a
.
Priority
,
RateMultiplier
:
a
.
BillingRateMultiplier
(),
Status
:
a
.
Status
,
ErrorMessage
:
a
.
ErrorMessage
,
LastUsedAt
:
a
.
LastUsedAt
,
...
...
@@ -141,6 +146,24 @@ func AccountFromServiceShallow(a *service.Account) *Account {
SessionWindowStatus
:
a
.
SessionWindowStatus
,
GroupIDs
:
a
.
GroupIDs
,
}
// 提取 5h 窗口费用控制和会话数量控制配置(仅 Anthropic OAuth/SetupToken 账号有效)
if
a
.
IsAnthropicOAuthOrSetupToken
()
{
if
limit
:=
a
.
GetWindowCostLimit
();
limit
>
0
{
out
.
WindowCostLimit
=
&
limit
}
if
reserve
:=
a
.
GetWindowCostStickyReserve
();
reserve
>
0
{
out
.
WindowCostStickyReserve
=
&
reserve
}
if
maxSessions
:=
a
.
GetMaxSessions
();
maxSessions
>
0
{
out
.
MaxSessions
=
&
maxSessions
}
if
idleTimeout
:=
a
.
GetSessionIdleTimeoutMinutes
();
idleTimeout
>
0
{
out
.
SessionIdleTimeoutMin
=
&
idleTimeout
}
}
return
out
}
func
AccountFromService
(
a
*
service
.
Account
)
*
Account
{
...
...
@@ -212,6 +235,27 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi
return
&
ProxyWithAccountCount
{
Proxy
:
*
ProxyFromService
(
&
p
.
Proxy
),
AccountCount
:
p
.
AccountCount
,
LatencyMs
:
p
.
LatencyMs
,
LatencyStatus
:
p
.
LatencyStatus
,
LatencyMessage
:
p
.
LatencyMessage
,
IPAddress
:
p
.
IPAddress
,
Country
:
p
.
Country
,
CountryCode
:
p
.
CountryCode
,
Region
:
p
.
Region
,
City
:
p
.
City
,
}
}
func
ProxyAccountSummaryFromService
(
a
*
service
.
ProxyAccountSummary
)
*
ProxyAccountSummary
{
if
a
==
nil
{
return
nil
}
return
&
ProxyAccountSummary
{
ID
:
a
.
ID
,
Name
:
a
.
Name
,
Platform
:
a
.
Platform
,
Type
:
a
.
Type
,
Notes
:
a
.
Notes
,
}
}
...
...
@@ -250,11 +294,12 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary {
// usageLogFromServiceBase is a helper that converts service UsageLog to DTO.
// The account parameter allows caller to control what Account info is included.
func
usageLogFromServiceBase
(
l
*
service
.
UsageLog
,
account
*
AccountSummary
)
*
UsageLog
{
// The includeIPAddress parameter controls whether to include the IP address (admin-only).
func
usageLogFromServiceBase
(
l
*
service
.
UsageLog
,
account
*
AccountSummary
,
includeIPAddress
bool
)
*
UsageLog
{
if
l
==
nil
{
return
nil
}
re
turn
&
UsageLog
{
re
sult
:=
&
UsageLog
{
ID
:
l
.
ID
,
UserID
:
l
.
UserID
,
APIKeyID
:
l
.
APIKeyID
,
...
...
@@ -276,6 +321,7 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary) *Usag
TotalCost
:
l
.
TotalCost
,
ActualCost
:
l
.
ActualCost
,
RateMultiplier
:
l
.
RateMultiplier
,
AccountRateMultiplier
:
l
.
AccountRateMultiplier
,
BillingType
:
l
.
BillingType
,
Stream
:
l
.
Stream
,
DurationMs
:
l
.
DurationMs
,
...
...
@@ -290,21 +336,26 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary) *Usag
Group
:
GroupFromServiceShallow
(
l
.
Group
),
Subscription
:
UserSubscriptionFromService
(
l
.
Subscription
),
}
// IP 地址仅对管理员可见
if
includeIPAddress
{
result
.
IPAddress
=
l
.
IPAddress
}
return
result
}
// UsageLogFromService converts a service UsageLog to DTO for regular users.
// It excludes Account details - users should not see
account information
.
// It excludes Account details
and IP address
- users should not see
these
.
func
UsageLogFromService
(
l
*
service
.
UsageLog
)
*
UsageLog
{
return
usageLogFromServiceBase
(
l
,
nil
)
return
usageLogFromServiceBase
(
l
,
nil
,
false
)
}
// UsageLogFromServiceAdmin converts a service UsageLog to DTO for admin users.
// It includes minimal Account info (ID, Name only).
// It includes minimal Account info (ID, Name only)
and IP address
.
func
UsageLogFromServiceAdmin
(
l
*
service
.
UsageLog
)
*
UsageLog
{
if
l
==
nil
{
return
nil
}
return
usageLogFromServiceBase
(
l
,
AccountSummaryFromService
(
l
.
Account
))
return
usageLogFromServiceBase
(
l
,
AccountSummaryFromService
(
l
.
Account
)
,
true
)
}
func
SettingFromService
(
s
*
service
.
Setting
)
*
Setting
{
...
...
@@ -362,3 +413,35 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult
Errors
:
r
.
Errors
,
}
}
func
PromoCodeFromService
(
pc
*
service
.
PromoCode
)
*
PromoCode
{
if
pc
==
nil
{
return
nil
}
return
&
PromoCode
{
ID
:
pc
.
ID
,
Code
:
pc
.
Code
,
BonusAmount
:
pc
.
BonusAmount
,
MaxUses
:
pc
.
MaxUses
,
UsedCount
:
pc
.
UsedCount
,
Status
:
pc
.
Status
,
ExpiresAt
:
pc
.
ExpiresAt
,
Notes
:
pc
.
Notes
,
CreatedAt
:
pc
.
CreatedAt
,
UpdatedAt
:
pc
.
UpdatedAt
,
}
}
func
PromoCodeUsageFromService
(
u
*
service
.
PromoCodeUsage
)
*
PromoCodeUsage
{
if
u
==
nil
{
return
nil
}
return
&
PromoCodeUsage
{
ID
:
u
.
ID
,
PromoCodeID
:
u
.
PromoCodeID
,
UserID
:
u
.
UserID
,
BonusAmount
:
u
.
BonusAmount
,
UsedAt
:
u
.
UsedAt
,
User
:
UserFromServiceShallow
(
u
.
User
),
}
}
backend/internal/handler/dto/settings.go
View file @
b9b4db3d
...
...
@@ -28,6 +28,7 @@ type SystemSettings struct {
APIBaseURL
string
`json:"api_base_url"`
ContactInfo
string
`json:"contact_info"`
DocURL
string
`json:"doc_url"`
HomeContent
string
`json:"home_content"`
DefaultConcurrency
int
`json:"default_concurrency"`
DefaultBalance
float64
`json:"default_balance"`
...
...
@@ -42,6 +43,12 @@ type SystemSettings struct {
// Identity patch configuration (Claude -> Gemini)
EnableIdentityPatch
bool
`json:"enable_identity_patch"`
IdentityPatchPrompt
string
`json:"identity_patch_prompt"`
// Ops monitoring (vNext)
OpsMonitoringEnabled
bool
`json:"ops_monitoring_enabled"`
OpsRealtimeMonitoringEnabled
bool
`json:"ops_realtime_monitoring_enabled"`
OpsQueryModeDefault
string
`json:"ops_query_mode_default"`
OpsMetricsIntervalSeconds
int
`json:"ops_metrics_interval_seconds"`
}
type
PublicSettings
struct
{
...
...
@@ -55,6 +62,16 @@ type PublicSettings struct {
APIBaseURL
string
`json:"api_base_url"`
ContactInfo
string
`json:"contact_info"`
DocURL
string
`json:"doc_url"`
HomeContent
string
`json:"home_content"`
LinuxDoOAuthEnabled
bool
`json:"linuxdo_oauth_enabled"`
Version
string
`json:"version"`
}
// StreamTimeoutSettings 流超时处理配置 DTO
type
StreamTimeoutSettings
struct
{
Enabled
bool
`json:"enabled"`
Action
string
`json:"action"`
TempUnschedMinutes
int
`json:"temp_unsched_minutes"`
ThresholdCount
int
`json:"threshold_count"`
ThresholdWindowMinutes
int
`json:"threshold_window_minutes"`
}
backend/internal/handler/dto/types.go
View file @
b9b4db3d
...
...
@@ -26,6 +26,8 @@ type APIKey struct {
Name
string
`json:"name"`
GroupID
*
int64
`json:"group_id"`
Status
string
`json:"status"`
IPWhitelist
[]
string
`json:"ip_whitelist"`
IPBlacklist
[]
string
`json:"ip_blacklist"`
CreatedAt
time
.
Time
`json:"created_at"`
UpdatedAt
time
.
Time
`json:"updated_at"`
...
...
@@ -56,6 +58,10 @@ type Group struct {
ClaudeCodeOnly
bool
`json:"claude_code_only"`
FallbackGroupID
*
int64
`json:"fallback_group_id"`
// 模型路由配置(仅 anthropic 平台使用)
ModelRouting
map
[
string
][]
int64
`json:"model_routing"`
ModelRoutingEnabled
bool
`json:"model_routing_enabled"`
CreatedAt
time
.
Time
`json:"created_at"`
UpdatedAt
time
.
Time
`json:"updated_at"`
...
...
@@ -74,6 +80,7 @@ type Account struct {
ProxyID
*
int64
`json:"proxy_id"`
Concurrency
int
`json:"concurrency"`
Priority
int
`json:"priority"`
RateMultiplier
float64
`json:"rate_multiplier"`
Status
string
`json:"status"`
ErrorMessage
string
`json:"error_message"`
LastUsedAt
*
time
.
Time
`json:"last_used_at"`
...
...
@@ -95,6 +102,16 @@ type Account struct {
SessionWindowEnd
*
time
.
Time
`json:"session_window_end"`
SessionWindowStatus
string
`json:"session_window_status"`
// 5h窗口费用控制(仅 Anthropic OAuth/SetupToken 账号有效)
// 从 extra 字段提取,方便前端显示和编辑
WindowCostLimit
*
float64
`json:"window_cost_limit,omitempty"`
WindowCostStickyReserve
*
float64
`json:"window_cost_sticky_reserve,omitempty"`
// 会话数量控制(仅 Anthropic OAuth/SetupToken 账号有效)
// 从 extra 字段提取,方便前端显示和编辑
MaxSessions
*
int
`json:"max_sessions,omitempty"`
SessionIdleTimeoutMin
*
int
`json:"session_idle_timeout_minutes,omitempty"`
Proxy
*
Proxy
`json:"proxy,omitempty"`
AccountGroups
[]
AccountGroup
`json:"account_groups,omitempty"`
...
...
@@ -128,6 +145,22 @@ type Proxy struct {
type
ProxyWithAccountCount
struct
{
Proxy
AccountCount
int64
`json:"account_count"`
LatencyMs
*
int64
`json:"latency_ms,omitempty"`
LatencyStatus
string
`json:"latency_status,omitempty"`
LatencyMessage
string
`json:"latency_message,omitempty"`
IPAddress
string
`json:"ip_address,omitempty"`
Country
string
`json:"country,omitempty"`
CountryCode
string
`json:"country_code,omitempty"`
Region
string
`json:"region,omitempty"`
City
string
`json:"city,omitempty"`
}
type
ProxyAccountSummary
struct
{
ID
int64
`json:"id"`
Name
string
`json:"name"`
Platform
string
`json:"platform"`
Type
string
`json:"type"`
Notes
*
string
`json:"notes,omitempty"`
}
type
RedeemCode
struct
{
...
...
@@ -174,6 +207,7 @@ type UsageLog struct {
TotalCost
float64
`json:"total_cost"`
ActualCost
float64
`json:"actual_cost"`
RateMultiplier
float64
`json:"rate_multiplier"`
AccountRateMultiplier
*
float64
`json:"account_rate_multiplier"`
BillingType
int8
`json:"billing_type"`
Stream
bool
`json:"stream"`
...
...
@@ -187,6 +221,9 @@ type UsageLog struct {
// User-Agent
UserAgent
*
string
`json:"user_agent"`
// IP 地址(仅管理员可见)
IPAddress
*
string
`json:"ip_address,omitempty"`
CreatedAt
time
.
Time
`json:"created_at"`
User
*
User
`json:"user,omitempty"`
...
...
@@ -245,3 +282,28 @@ type BulkAssignResult struct {
Subscriptions
[]
UserSubscription
`json:"subscriptions"`
Errors
[]
string
`json:"errors"`
}
// PromoCode 注册优惠码
type
PromoCode
struct
{
ID
int64
`json:"id"`
Code
string
`json:"code"`
BonusAmount
float64
`json:"bonus_amount"`
MaxUses
int
`json:"max_uses"`
UsedCount
int
`json:"used_count"`
Status
string
`json:"status"`
ExpiresAt
*
time
.
Time
`json:"expires_at"`
Notes
string
`json:"notes"`
CreatedAt
time
.
Time
`json:"created_at"`
UpdatedAt
time
.
Time
`json:"updated_at"`
}
// PromoCodeUsage 优惠码使用记录
type
PromoCodeUsage
struct
{
ID
int64
`json:"id"`
PromoCodeID
int64
`json:"promo_code_id"`
UserID
int64
`json:"user_id"`
BonusAmount
float64
`json:"bonus_amount"`
UsedAt
time
.
Time
`json:"used_at"`
User
*
User
`json:"user,omitempty"`
}
backend/internal/handler/gateway_handler.go
View file @
b9b4db3d
...
...
@@ -15,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
pkgerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
...
...
@@ -100,6 +101,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 检查是否为 Claude Code 客户端,设置到 context 中
SetClaudeCodeClientContext
(
c
,
body
)
setOpsRequestContext
(
c
,
""
,
false
,
body
)
parsedReq
,
err
:=
service
.
ParseGatewayRequest
(
body
)
if
err
!=
nil
{
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"Failed to parse request body"
)
...
...
@@ -108,8 +114,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
reqModel
:=
parsedReq
.
Model
reqStream
:=
parsedReq
.
Stream
// 设置 Claude Code 客户端标识到 context(用于分组限制检查)
SetClaudeCodeClientContext
(
c
,
body
)
setOpsRequestContext
(
c
,
reqModel
,
reqStream
,
body
)
// 验证 model 必填
if
reqModel
==
""
{
...
...
@@ -123,12 +128,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 获取订阅信息(可能为nil)- 提前获取用于后续检查
subscription
,
_
:=
middleware2
.
GetSubscriptionFromContext
(
c
)
// 获取 User-Agent
userAgent
:=
c
.
Request
.
UserAgent
()
// 0. 检查wait队列是否已满
maxWait
:=
service
.
CalculateMaxWait
(
subject
.
Concurrency
)
canWait
,
err
:=
h
.
concurrencyHelper
.
IncrementWaitCount
(
c
.
Request
.
Context
(),
subject
.
UserID
,
maxWait
)
waitCounted
:=
false
if
err
!=
nil
{
log
.
Printf
(
"Increment wait count failed: %v"
,
err
)
// On error, allow request to proceed
...
...
@@ -136,8 +139,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h
.
errorResponse
(
c
,
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
"Too many pending requests, please retry later"
)
return
}
// 确保在函数退出时减少wait计数
defer
h
.
concurrencyHelper
.
DecrementWaitCount
(
c
.
Request
.
Context
(),
subject
.
UserID
)
if
err
==
nil
&&
canWait
{
waitCounted
=
true
}
// Ensure we decrement if we exit before acquiring the user slot.
defer
func
()
{
if
waitCounted
{
h
.
concurrencyHelper
.
DecrementWaitCount
(
c
.
Request
.
Context
(),
subject
.
UserID
)
}
}()
// 1. 首先获取用户并发槽位
userReleaseFunc
,
err
:=
h
.
concurrencyHelper
.
AcquireUserSlotWithWait
(
c
,
subject
.
UserID
,
subject
.
Concurrency
,
reqStream
,
&
streamStarted
)
...
...
@@ -146,6 +156,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h
.
handleConcurrencyError
(
c
,
err
,
"user"
,
streamStarted
)
return
}
// User slot acquired: no longer waiting in the queue.
if
waitCounted
{
h
.
concurrencyHelper
.
DecrementWaitCount
(
c
.
Request
.
Context
(),
subject
.
UserID
)
waitCounted
=
false
}
// 在请求结束或 Context 取消时确保释放槽位,避免客户端断开造成泄漏
userReleaseFunc
=
wrapReleaseOnDone
(
c
.
Request
.
Context
(),
userReleaseFunc
)
if
userReleaseFunc
!=
nil
{
...
...
@@ -182,7 +197,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
lastFailoverStatus
:=
0
for
{
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
reqModel
,
failedAccountIDs
)
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
reqModel
,
failedAccountIDs
,
""
)
// Gemini 不使用会话限制
if
err
!=
nil
{
if
len
(
failedAccountIDs
)
==
0
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
...
...
@@ -192,6 +207,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
account
:=
selection
.
Account
setOpsSelectedAccount
(
c
,
account
.
ID
)
// 检查预热请求拦截(在账号选择后、转发前检查)
if
account
.
IsInterceptWarmupEnabled
()
&&
isWarmupRequest
(
body
)
{
...
...
@@ -208,12 +224,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 3. 获取账号并发槽位
accountReleaseFunc
:=
selection
.
ReleaseFunc
var
accountWaitRelease
func
()
if
!
selection
.
Acquired
{
if
selection
.
WaitPlan
==
nil
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts"
,
streamStarted
)
return
}
accountWaitCounted
:=
false
canWait
,
err
:=
h
.
concurrencyHelper
.
IncrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
,
selection
.
WaitPlan
.
MaxWaiting
)
if
err
!=
nil
{
log
.
Printf
(
"Increment account wait count failed: %v"
,
err
)
...
...
@@ -221,12 +237,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
log
.
Printf
(
"Account wait queue full: account=%d"
,
account
.
ID
)
h
.
handleStreamingAwareError
(
c
,
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
"Too many pending requests, please retry later"
,
streamStarted
)
return
}
else
{
// Only set release function if increment succeeded
accountWaitRelease
=
func
()
{
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
}
if
err
==
nil
&&
canWait
{
accountWaitCounted
=
true
}
// Ensure the wait counter is decremented if we exit before acquiring the slot.
defer
func
()
{
if
accountWaitCounted
{
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
}
}()
accountReleaseFunc
,
err
=
h
.
concurrencyHelper
.
AcquireAccountSlotWithWaitTimeout
(
c
,
...
...
@@ -237,20 +257,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
&
streamStarted
,
)
if
err
!=
nil
{
if
accountWaitRelease
!=
nil
{
accountWaitRelease
()
}
log
.
Printf
(
"Account concurrency acquire failed: %v"
,
err
)
h
.
handleConcurrencyError
(
c
,
err
,
"account"
,
streamStarted
)
return
}
// Slot acquired: no longer waiting in queue.
if
accountWaitCounted
{
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
accountWaitCounted
=
false
}
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"Bind sticky session failed: %v"
,
err
)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc
=
wrapReleaseOnDone
(
c
.
Request
.
Context
(),
accountReleaseFunc
)
accountWaitRelease
=
wrapReleaseOnDone
(
c
.
Request
.
Context
(),
accountWaitRelease
)
// 转发请求 - 根据账号平台分流
var
result
*
service
.
ForwardResult
...
...
@@ -262,19 +283,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if
accountReleaseFunc
!=
nil
{
accountReleaseFunc
()
}
if
accountWaitRelease
!=
nil
{
accountWaitRelease
()
}
if
err
!=
nil
{
var
failoverErr
*
service
.
UpstreamFailoverError
if
errors
.
As
(
err
,
&
failoverErr
)
{
failedAccountIDs
[
account
.
ID
]
=
struct
{}{}
if
switchCount
>=
maxAccountSwitches
{
lastFailoverStatus
=
failoverErr
.
StatusCode
if
switchCount
>=
maxAccountSwitches
{
h
.
handleFailoverExhausted
(
c
,
lastFailoverStatus
,
streamStarted
)
return
}
lastFailoverStatus
=
failoverErr
.
StatusCode
switchCount
++
log
.
Printf
(
"Account %d: upstream error %d, switching account %d/%d"
,
account
.
ID
,
failoverErr
.
StatusCode
,
switchCount
,
maxAccountSwitches
)
continue
...
...
@@ -284,8 +301,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
userAgent
:=
c
.
GetHeader
(
"User-Agent"
)
clientIP
:=
ip
.
GetClientIP
(
c
)
// 异步记录使用量(subscription已在函数开头获取)
go
func
(
result
*
service
.
ForwardResult
,
usedAccount
*
service
.
Account
,
ua
string
)
{
go
func
(
result
*
service
.
ForwardResult
,
usedAccount
*
service
.
Account
,
ua
,
clientIP
string
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
10
*
time
.
Second
)
defer
cancel
()
if
err
:=
h
.
gatewayService
.
RecordUsage
(
ctx
,
&
service
.
RecordUsageInput
{
...
...
@@ -295,10 +316,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
Account
:
usedAccount
,
Subscription
:
subscription
,
UserAgent
:
ua
,
IPAddress
:
clientIP
,
});
err
!=
nil
{
log
.
Printf
(
"Record usage failed: %v"
,
err
)
}
}(
result
,
account
,
userAgent
)
}(
result
,
account
,
userAgent
,
clientIP
)
return
}
}
...
...
@@ -310,7 +332,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for
{
// 选择支持该模型的账号
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
reqModel
,
failedAccountIDs
)
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
reqModel
,
failedAccountIDs
,
parsedReq
.
MetadataUserID
)
if
err
!=
nil
{
if
len
(
failedAccountIDs
)
==
0
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
(),
streamStarted
)
...
...
@@ -320,6 +342,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
account
:=
selection
.
Account
setOpsSelectedAccount
(
c
,
account
.
ID
)
// 检查预热请求拦截(在账号选择后、转发前检查)
if
account
.
IsInterceptWarmupEnabled
()
&&
isWarmupRequest
(
body
)
{
...
...
@@ -336,12 +359,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 3. 获取账号并发槽位
accountReleaseFunc
:=
selection
.
ReleaseFunc
var
accountWaitRelease
func
()
if
!
selection
.
Acquired
{
if
selection
.
WaitPlan
==
nil
{
h
.
handleStreamingAwareError
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts"
,
streamStarted
)
return
}
accountWaitCounted
:=
false
canWait
,
err
:=
h
.
concurrencyHelper
.
IncrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
,
selection
.
WaitPlan
.
MaxWaiting
)
if
err
!=
nil
{
log
.
Printf
(
"Increment account wait count failed: %v"
,
err
)
...
...
@@ -349,12 +372,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
log
.
Printf
(
"Account wait queue full: account=%d"
,
account
.
ID
)
h
.
handleStreamingAwareError
(
c
,
http
.
StatusTooManyRequests
,
"rate_limit_error"
,
"Too many pending requests, please retry later"
,
streamStarted
)
return
}
else
{
// Only set release function if increment succeeded
accountWaitRelease
=
func
()
{
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
}
if
err
==
nil
&&
canWait
{
accountWaitCounted
=
true
}
defer
func
()
{
if
accountWaitCounted
{
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
}
}()
accountReleaseFunc
,
err
=
h
.
concurrencyHelper
.
AcquireAccountSlotWithWaitTimeout
(
c
,
...
...
@@ -365,20 +391,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
&
streamStarted
,
)
if
err
!=
nil
{
if
accountWaitRelease
!=
nil
{
accountWaitRelease
()
}
log
.
Printf
(
"Account concurrency acquire failed: %v"
,
err
)
h
.
handleConcurrencyError
(
c
,
err
,
"account"
,
streamStarted
)
return
}
if
accountWaitCounted
{
h
.
concurrencyHelper
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
accountWaitCounted
=
false
}
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"Bind sticky session failed: %v"
,
err
)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc
=
wrapReleaseOnDone
(
c
.
Request
.
Context
(),
accountReleaseFunc
)
accountWaitRelease
=
wrapReleaseOnDone
(
c
.
Request
.
Context
(),
accountWaitRelease
)
// 转发请求 - 根据账号平台分流
var
result
*
service
.
ForwardResult
...
...
@@ -390,19 +416,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if
accountReleaseFunc
!=
nil
{
accountReleaseFunc
()
}
if
accountWaitRelease
!=
nil
{
accountWaitRelease
()
}
if
err
!=
nil
{
var
failoverErr
*
service
.
UpstreamFailoverError
if
errors
.
As
(
err
,
&
failoverErr
)
{
failedAccountIDs
[
account
.
ID
]
=
struct
{}{}
if
switchCount
>=
maxAccountSwitches
{
lastFailoverStatus
=
failoverErr
.
StatusCode
if
switchCount
>=
maxAccountSwitches
{
h
.
handleFailoverExhausted
(
c
,
lastFailoverStatus
,
streamStarted
)
return
}
lastFailoverStatus
=
failoverErr
.
StatusCode
switchCount
++
log
.
Printf
(
"Account %d: upstream error %d, switching account %d/%d"
,
account
.
ID
,
failoverErr
.
StatusCode
,
switchCount
,
maxAccountSwitches
)
continue
...
...
@@ -412,8 +434,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
userAgent
:=
c
.
GetHeader
(
"User-Agent"
)
clientIP
:=
ip
.
GetClientIP
(
c
)
// 异步记录使用量(subscription已在函数开头获取)
go
func
(
result
*
service
.
ForwardResult
,
usedAccount
*
service
.
Account
,
ua
string
)
{
go
func
(
result
*
service
.
ForwardResult
,
usedAccount
*
service
.
Account
,
ua
,
clientIP
string
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
10
*
time
.
Second
)
defer
cancel
()
if
err
:=
h
.
gatewayService
.
RecordUsage
(
ctx
,
&
service
.
RecordUsageInput
{
...
...
@@ -423,10 +449,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
Account
:
usedAccount
,
Subscription
:
subscription
,
UserAgent
:
ua
,
IPAddress
:
clientIP
,
});
err
!=
nil
{
log
.
Printf
(
"Record usage failed: %v"
,
err
)
}
}(
result
,
account
,
userAgent
)
}(
result
,
account
,
userAgent
,
clientIP
)
return
}
}
...
...
@@ -692,21 +719,22 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return
}
setOpsRequestContext
(
c
,
""
,
false
,
body
)
parsedReq
,
err
:=
service
.
ParseGatewayRequest
(
body
)
if
err
!=
nil
{
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"Failed to parse request body"
)
return
}
// 设置 Claude Code 客户端标识到 context(用于分组限制检查)
SetClaudeCodeClientContext
(
c
,
body
)
// 验证 model 必填
if
parsedReq
.
Model
==
""
{
h
.
errorResponse
(
c
,
http
.
StatusBadRequest
,
"invalid_request_error"
,
"model is required"
)
return
}
setOpsRequestContext
(
c
,
parsedReq
.
Model
,
parsedReq
.
Stream
,
body
)
// 获取订阅信息(可能为nil)
subscription
,
_
:=
middleware2
.
GetSubscriptionFromContext
(
c
)
...
...
@@ -727,6 +755,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
h
.
errorResponse
(
c
,
http
.
StatusServiceUnavailable
,
"api_error"
,
"No available accounts: "
+
err
.
Error
())
return
}
setOpsSelectedAccount
(
c
,
account
.
ID
)
// 转发请求(不记录使用量)
if
err
:=
h
.
gatewayService
.
ForwardCountTokens
(
c
.
Request
.
Context
(),
c
,
account
,
parsedReq
);
err
!=
nil
{
...
...
backend/internal/handler/gemini_v1beta_handler.go
View file @
b9b4db3d
...
...
@@ -12,6 +12,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
...
...
@@ -161,25 +162,32 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
return
}
setOpsRequestContext
(
c
,
modelName
,
stream
,
body
)
// Get subscription (may be nil)
subscription
,
_
:=
middleware
.
GetSubscriptionFromContext
(
c
)
// 获取 User-Agent
userAgent
:=
c
.
Request
.
UserAgent
()
// For Gemini native API, do not send Claude-style ping frames.
geminiConcurrency
:=
NewConcurrencyHelper
(
h
.
concurrencyHelper
.
concurrencyService
,
SSEPingFormatNone
,
0
)
// 0) wait queue check
maxWait
:=
service
.
CalculateMaxWait
(
authSubject
.
Concurrency
)
canWait
,
err
:=
geminiConcurrency
.
IncrementWaitCount
(
c
.
Request
.
Context
(),
authSubject
.
UserID
,
maxWait
)
waitCounted
:=
false
if
err
!=
nil
{
log
.
Printf
(
"Increment wait count failed: %v"
,
err
)
}
else
if
!
canWait
{
googleError
(
c
,
http
.
StatusTooManyRequests
,
"Too many pending requests, please retry later"
)
return
}
defer
geminiConcurrency
.
DecrementWaitCount
(
c
.
Request
.
Context
(),
authSubject
.
UserID
)
if
err
==
nil
&&
canWait
{
waitCounted
=
true
}
defer
func
()
{
if
waitCounted
{
geminiConcurrency
.
DecrementWaitCount
(
c
.
Request
.
Context
(),
authSubject
.
UserID
)
}
}()
// 1) user concurrency slot
streamStarted
:=
false
...
...
@@ -188,6 +196,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
googleError
(
c
,
http
.
StatusTooManyRequests
,
err
.
Error
())
return
}
if
waitCounted
{
geminiConcurrency
.
DecrementWaitCount
(
c
.
Request
.
Context
(),
authSubject
.
UserID
)
waitCounted
=
false
}
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
userReleaseFunc
=
wrapReleaseOnDone
(
c
.
Request
.
Context
(),
userReleaseFunc
)
if
userReleaseFunc
!=
nil
{
...
...
@@ -203,10 +215,6 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 3) select account (sticky session based on request body)
parsedReq
,
_
:=
service
.
ParseGatewayRequest
(
body
)
// 设置 Claude Code 客户端标识到 context(用于分组限制检查)
SetClaudeCodeClientContext
(
c
,
body
)
sessionHash
:=
h
.
gatewayService
.
GenerateSessionHash
(
parsedReq
)
sessionKey
:=
sessionHash
if
sessionHash
!=
""
{
...
...
@@ -218,7 +226,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
lastFailoverStatus
:=
0
for
{
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
modelName
,
failedAccountIDs
)
selection
,
err
:=
h
.
gatewayService
.
SelectAccountWithLoadAwareness
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
modelName
,
failedAccountIDs
,
""
)
// Gemini 不使用会话限制
if
err
!=
nil
{
if
len
(
failedAccountIDs
)
==
0
{
googleError
(
c
,
http
.
StatusServiceUnavailable
,
"No available Gemini accounts: "
+
err
.
Error
())
...
...
@@ -228,15 +236,16 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
return
}
account
:=
selection
.
Account
setOpsSelectedAccount
(
c
,
account
.
ID
)
// 4) account concurrency slot
accountReleaseFunc
:=
selection
.
ReleaseFunc
var
accountWaitRelease
func
()
if
!
selection
.
Acquired
{
if
selection
.
WaitPlan
==
nil
{
googleError
(
c
,
http
.
StatusServiceUnavailable
,
"No available Gemini accounts"
)
return
}
accountWaitCounted
:=
false
canWait
,
err
:=
geminiConcurrency
.
IncrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
,
selection
.
WaitPlan
.
MaxWaiting
)
if
err
!=
nil
{
log
.
Printf
(
"Increment account wait count failed: %v"
,
err
)
...
...
@@ -244,12 +253,15 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
log
.
Printf
(
"Account wait queue full: account=%d"
,
account
.
ID
)
googleError
(
c
,
http
.
StatusTooManyRequests
,
"Too many pending requests, please retry later"
)
return
}
else
{
// Only set release function if increment succeeded
accountWaitRelease
=
func
()
{
geminiConcurrency
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
}
if
err
==
nil
&&
canWait
{
accountWaitCounted
=
true
}
defer
func
()
{
if
accountWaitCounted
{
geminiConcurrency
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
}
}()
accountReleaseFunc
,
err
=
geminiConcurrency
.
AcquireAccountSlotWithWaitTimeout
(
c
,
...
...
@@ -260,19 +272,19 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
&
streamStarted
,
)
if
err
!=
nil
{
if
accountWaitRelease
!=
nil
{
accountWaitRelease
()
}
googleError
(
c
,
http
.
StatusTooManyRequests
,
err
.
Error
())
return
}
if
accountWaitCounted
{
geminiConcurrency
.
DecrementAccountWaitCount
(
c
.
Request
.
Context
(),
account
.
ID
)
accountWaitCounted
=
false
}
if
err
:=
h
.
gatewayService
.
BindStickySession
(
c
.
Request
.
Context
(),
apiKey
.
GroupID
,
sessionKey
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"Bind sticky session failed: %v"
,
err
)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc
=
wrapReleaseOnDone
(
c
.
Request
.
Context
(),
accountReleaseFunc
)
accountWaitRelease
=
wrapReleaseOnDone
(
c
.
Request
.
Context
(),
accountWaitRelease
)
// 5) forward (根据平台分流)
var
result
*
service
.
ForwardResult
...
...
@@ -284,9 +296,6 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if
accountReleaseFunc
!=
nil
{
accountReleaseFunc
()
}
if
accountWaitRelease
!=
nil
{
accountWaitRelease
()
}
if
err
!=
nil
{
var
failoverErr
*
service
.
UpstreamFailoverError
if
errors
.
As
(
err
,
&
failoverErr
)
{
...
...
@@ -306,8 +315,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
return
}
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
userAgent
:=
c
.
GetHeader
(
"User-Agent"
)
clientIP
:=
ip
.
GetClientIP
(
c
)
// 6) record usage async
go
func
(
result
*
service
.
ForwardResult
,
usedAccount
*
service
.
Account
,
ua
string
)
{
go
func
(
result
*
service
.
ForwardResult
,
usedAccount
*
service
.
Account
,
ua
,
ip
string
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
10
*
time
.
Second
)
defer
cancel
()
if
err
:=
h
.
gatewayService
.
RecordUsage
(
ctx
,
&
service
.
RecordUsageInput
{
...
...
@@ -317,10 +330,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
Account
:
usedAccount
,
Subscription
:
subscription
,
UserAgent
:
ua
,
IPAddress
:
ip
,
});
err
!=
nil
{
log
.
Printf
(
"Record usage failed: %v"
,
err
)
}
}(
result
,
account
,
userAgent
)
}(
result
,
account
,
userAgent
,
clientIP
)
return
}
}
...
...
Prev
1
2
3
4
5
6
7
8
9
…
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment