Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
3d79773b
Commit
3d79773b
authored
Mar 04, 2026
by
kyx236
Browse files
Merge branch 'main' of
https://github.com/james-6-23/sub2api
parents
6aa8cbbf
742e73c9
Changes
253
Show whitespace changes
Inline
Side-by-side
Too many changes to show.
To preserve performance only
253 of 253+
files are displayed.
Plain diff
Email patch
backend/internal/handler/admin/account_data.go
View file @
3d79773b
...
...
@@ -175,22 +175,28 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
return
}
dataPayload
:=
req
.
Data
if
err
:=
validateDataHeader
(
dataPayload
);
err
!=
nil
{
if
err
:=
validateDataHeader
(
req
.
Data
);
err
!=
nil
{
response
.
BadRequest
(
c
,
err
.
Error
())
return
}
executeAdminIdempotentJSON
(
c
,
"admin.accounts.import_data"
,
req
,
service
.
DefaultWriteIdempotencyTTL
(),
func
(
ctx
context
.
Context
)
(
any
,
error
)
{
return
h
.
importData
(
ctx
,
req
)
})
}
func
(
h
*
AccountHandler
)
importData
(
ctx
context
.
Context
,
req
DataImportRequest
)
(
DataImportResult
,
error
)
{
skipDefaultGroupBind
:=
true
if
req
.
SkipDefaultGroupBind
!=
nil
{
skipDefaultGroupBind
=
*
req
.
SkipDefaultGroupBind
}
dataPayload
:=
req
.
Data
result
:=
DataImportResult
{}
existingProxies
,
err
:=
h
.
listAllProxies
(
c
.
Request
.
Context
())
existingProxies
,
err
:=
h
.
listAllProxies
(
ctx
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
return
result
,
err
}
proxyKeyToID
:=
make
(
map
[
string
]
int64
,
len
(
existingProxies
))
...
...
@@ -221,8 +227,8 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
proxyKeyToID
[
key
]
=
existingID
result
.
ProxyReused
++
if
normalizedStatus
!=
""
{
if
proxy
,
e
rr
:=
h
.
adminService
.
GetProxy
(
c
.
Request
.
Context
()
,
existingID
);
e
rr
==
nil
&&
proxy
!=
nil
&&
proxy
.
Status
!=
normalizedStatus
{
_
,
_
=
h
.
adminService
.
UpdateProxy
(
c
.
Request
.
Context
()
,
existingID
,
&
service
.
UpdateProxyInput
{
if
proxy
,
getE
rr
:=
h
.
adminService
.
GetProxy
(
c
tx
,
existingID
);
getE
rr
==
nil
&&
proxy
!=
nil
&&
proxy
.
Status
!=
normalizedStatus
{
_
,
_
=
h
.
adminService
.
UpdateProxy
(
c
tx
,
existingID
,
&
service
.
UpdateProxyInput
{
Status
:
normalizedStatus
,
})
}
...
...
@@ -230,7 +236,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
continue
}
created
,
e
rr
:=
h
.
adminService
.
CreateProxy
(
c
.
Request
.
Context
()
,
&
service
.
CreateProxyInput
{
created
,
createE
rr
:=
h
.
adminService
.
CreateProxy
(
c
tx
,
&
service
.
CreateProxyInput
{
Name
:
defaultProxyName
(
item
.
Name
),
Protocol
:
item
.
Protocol
,
Host
:
item
.
Host
,
...
...
@@ -238,13 +244,13 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
Username
:
item
.
Username
,
Password
:
item
.
Password
,
})
if
e
rr
!=
nil
{
if
createE
rr
!=
nil
{
result
.
ProxyFailed
++
result
.
Errors
=
append
(
result
.
Errors
,
DataImportError
{
Kind
:
"proxy"
,
Name
:
item
.
Name
,
ProxyKey
:
key
,
Message
:
e
rr
.
Error
(),
Message
:
createE
rr
.
Error
(),
})
continue
}
...
...
@@ -252,7 +258,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
result
.
ProxyCreated
++
if
normalizedStatus
!=
""
&&
normalizedStatus
!=
created
.
Status
{
_
,
_
=
h
.
adminService
.
UpdateProxy
(
c
.
Request
.
Context
()
,
created
.
ID
,
&
service
.
UpdateProxyInput
{
_
,
_
=
h
.
adminService
.
UpdateProxy
(
c
tx
,
created
.
ID
,
&
service
.
UpdateProxyInput
{
Status
:
normalizedStatus
,
})
}
...
...
@@ -303,7 +309,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
SkipDefaultGroupBind
:
skipDefaultGroupBind
,
}
if
_
,
err
:=
h
.
adminService
.
CreateAccount
(
c
.
Request
.
Context
()
,
accountInput
);
err
!=
nil
{
if
_
,
err
:=
h
.
adminService
.
CreateAccount
(
c
tx
,
accountInput
);
err
!=
nil
{
result
.
AccountFailed
++
result
.
Errors
=
append
(
result
.
Errors
,
DataImportError
{
Kind
:
"account"
,
...
...
@@ -315,7 +321,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
result
.
AccountCreated
++
}
re
sponse
.
Success
(
c
,
result
)
re
turn
result
,
nil
}
func
(
h
*
AccountHandler
)
listAllProxies
(
ctx
context
.
Context
)
([]
service
.
Proxy
,
error
)
{
...
...
backend/internal/handler/admin/account_data_handler_test.go
View file @
3d79773b
...
...
@@ -64,6 +64,7 @@ func setupAccountDataRouter() (*gin.Engine, *stubAdminService) {
nil
,
nil
,
nil
,
nil
,
)
router
.
GET
(
"/api/v1/admin/accounts/data"
,
h
.
ExportData
)
...
...
backend/internal/handler/admin/account_handler.go
View file @
3d79773b
...
...
@@ -2,7 +2,13 @@
package
admin
import
(
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
...
...
@@ -10,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
...
...
@@ -46,6 +53,7 @@ type AccountHandler struct {
concurrencyService
*
service
.
ConcurrencyService
crsSyncService
*
service
.
CRSSyncService
sessionLimitCache
service
.
SessionLimitCache
rpmCache
service
.
RPMCache
tokenCacheInvalidator
service
.
TokenCacheInvalidator
}
...
...
@@ -62,6 +70,7 @@ func NewAccountHandler(
concurrencyService
*
service
.
ConcurrencyService
,
crsSyncService
*
service
.
CRSSyncService
,
sessionLimitCache
service
.
SessionLimitCache
,
rpmCache
service
.
RPMCache
,
tokenCacheInvalidator
service
.
TokenCacheInvalidator
,
)
*
AccountHandler
{
return
&
AccountHandler
{
...
...
@@ -76,6 +85,7 @@ func NewAccountHandler(
concurrencyService
:
concurrencyService
,
crsSyncService
:
crsSyncService
,
sessionLimitCache
:
sessionLimitCache
,
rpmCache
:
rpmCache
,
tokenCacheInvalidator
:
tokenCacheInvalidator
,
}
}
...
...
@@ -133,6 +143,13 @@ type BulkUpdateAccountsRequest struct {
ConfirmMixedChannelRisk
*
bool
`json:"confirm_mixed_channel_risk"`
// 用户确认混合渠道风险
}
// CheckMixedChannelRequest represents check mixed channel risk request
type
CheckMixedChannelRequest
struct
{
Platform
string
`json:"platform" binding:"required"`
GroupIDs
[]
int64
`json:"group_ids"`
AccountID
*
int64
`json:"account_id"`
}
// AccountWithConcurrency extends Account with real-time concurrency info
type
AccountWithConcurrency
struct
{
*
dto
.
Account
...
...
@@ -140,6 +157,51 @@ type AccountWithConcurrency struct {
// 以下字段仅对 Anthropic OAuth/SetupToken 账号有效,且仅在启用相应功能时返回
CurrentWindowCost
*
float64
`json:"current_window_cost,omitempty"`
// 当前窗口费用
ActiveSessions
*
int
`json:"active_sessions,omitempty"`
// 当前活跃会话数
CurrentRPM
*
int
`json:"current_rpm,omitempty"`
// 当前分钟 RPM 计数
}
func
(
h
*
AccountHandler
)
buildAccountResponseWithRuntime
(
ctx
context
.
Context
,
account
*
service
.
Account
)
AccountWithConcurrency
{
item
:=
AccountWithConcurrency
{
Account
:
dto
.
AccountFromService
(
account
),
CurrentConcurrency
:
0
,
}
if
account
==
nil
{
return
item
}
if
h
.
concurrencyService
!=
nil
{
if
counts
,
err
:=
h
.
concurrencyService
.
GetAccountConcurrencyBatch
(
ctx
,
[]
int64
{
account
.
ID
});
err
==
nil
{
item
.
CurrentConcurrency
=
counts
[
account
.
ID
]
}
}
if
account
.
IsAnthropicOAuthOrSetupToken
()
{
if
h
.
accountUsageService
!=
nil
&&
account
.
GetWindowCostLimit
()
>
0
{
startTime
:=
account
.
GetCurrentWindowStartTime
()
if
stats
,
err
:=
h
.
accountUsageService
.
GetAccountWindowStats
(
ctx
,
account
.
ID
,
startTime
);
err
==
nil
&&
stats
!=
nil
{
cost
:=
stats
.
StandardCost
item
.
CurrentWindowCost
=
&
cost
}
}
if
h
.
sessionLimitCache
!=
nil
&&
account
.
GetMaxSessions
()
>
0
{
idleTimeout
:=
time
.
Duration
(
account
.
GetSessionIdleTimeoutMinutes
())
*
time
.
Minute
idleTimeouts
:=
map
[
int64
]
time
.
Duration
{
account
.
ID
:
idleTimeout
}
if
sessions
,
err
:=
h
.
sessionLimitCache
.
GetActiveSessionCountBatch
(
ctx
,
[]
int64
{
account
.
ID
},
idleTimeouts
);
err
==
nil
{
if
count
,
ok
:=
sessions
[
account
.
ID
];
ok
{
item
.
ActiveSessions
=
&
count
}
}
}
if
h
.
rpmCache
!=
nil
&&
account
.
GetBaseRPM
()
>
0
{
if
rpm
,
err
:=
h
.
rpmCache
.
GetRPM
(
ctx
,
account
.
ID
);
err
==
nil
{
item
.
CurrentRPM
=
&
rpm
}
}
}
return
item
}
// List handles listing all accounts with pagination
...
...
@@ -155,6 +217,7 @@ func (h *AccountHandler) List(c *gin.Context) {
if
len
(
search
)
>
100
{
search
=
search
[
:
100
]
}
lite
:=
parseBoolQueryWithDefault
(
c
.
Query
(
"lite"
),
false
)
var
groupID
int64
if
groupIDStr
:=
c
.
Query
(
"group"
);
groupIDStr
!=
""
{
...
...
@@ -173,15 +236,21 @@ func (h *AccountHandler) List(c *gin.Context) {
accountIDs
[
i
]
=
acc
.
ID
}
concurrencyCounts
,
err
:=
h
.
concurrencyService
.
GetAccountConcurrencyBatch
(
c
.
Request
.
Context
(),
accountIDs
)
if
err
!=
nil
{
// Log error but don't fail the request, just use 0 for all
concurrencyCounts
=
make
(
map
[
int64
]
int
)
concurrencyCounts
:=
make
(
map
[
int64
]
int
)
var
windowCosts
map
[
int64
]
float64
var
activeSessions
map
[
int64
]
int
var
rpmCounts
map
[
int64
]
int
if
!
lite
{
// Get current concurrency counts for all accounts
if
h
.
concurrencyService
!=
nil
{
if
cc
,
ccErr
:=
h
.
concurrencyService
.
GetAccountConcurrencyBatch
(
c
.
Request
.
Context
(),
accountIDs
);
ccErr
==
nil
&&
cc
!=
nil
{
concurrencyCounts
=
cc
}
// 识别需要查询窗口费用
和
会话数的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
}
// 识别需要查询窗口费用
、
会话数
和 RPM
的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
windowCostAccountIDs
:=
make
([]
int64
,
0
)
sessionLimitAccountIDs
:=
make
([]
int64
,
0
)
rpmAccountIDs
:=
make
([]
int64
,
0
)
sessionIdleTimeouts
:=
make
(
map
[
int64
]
time
.
Duration
)
// 各账号的会话空闲超时配置
for
i
:=
range
accounts
{
acc
:=
&
accounts
[
i
]
...
...
@@ -193,12 +262,19 @@ func (h *AccountHandler) List(c *gin.Context) {
sessionLimitAccountIDs
=
append
(
sessionLimitAccountIDs
,
acc
.
ID
)
sessionIdleTimeouts
[
acc
.
ID
]
=
time
.
Duration
(
acc
.
GetSessionIdleTimeoutMinutes
())
*
time
.
Minute
}
if
acc
.
GetBaseRPM
()
>
0
{
rpmAccountIDs
=
append
(
rpmAccountIDs
,
acc
.
ID
)
}
}
}
// 并行获取窗口费用和活跃会话数
var
windowCosts
map
[
int64
]
float64
var
activeSessions
map
[
int64
]
int
// 获取 RPM 计数(批量查询)
if
len
(
rpmAccountIDs
)
>
0
&&
h
.
rpmCache
!=
nil
{
rpmCounts
,
_
=
h
.
rpmCache
.
GetRPMBatch
(
c
.
Request
.
Context
(),
rpmAccountIDs
)
if
rpmCounts
==
nil
{
rpmCounts
=
make
(
map
[
int64
]
int
)
}
}
// 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置)
if
len
(
sessionLimitAccountIDs
)
>
0
&&
h
.
sessionLimitCache
!=
nil
{
...
...
@@ -235,6 +311,7 @@ func (h *AccountHandler) List(c *gin.Context) {
}
_
=
g
.
Wait
()
}
}
// Build response with concurrency info
result
:=
make
([]
AccountWithConcurrency
,
len
(
accounts
))
...
...
@@ -259,12 +336,84 @@ func (h *AccountHandler) List(c *gin.Context) {
}
}
// 添加 RPM 计数(仅当启用时)
if
rpmCounts
!=
nil
{
if
rpm
,
ok
:=
rpmCounts
[
acc
.
ID
];
ok
{
item
.
CurrentRPM
=
&
rpm
}
}
result
[
i
]
=
item
}
etag
:=
buildAccountsListETag
(
result
,
total
,
page
,
pageSize
,
platform
,
accountType
,
status
,
search
,
lite
)
if
etag
!=
""
{
c
.
Header
(
"ETag"
,
etag
)
c
.
Header
(
"Vary"
,
"If-None-Match"
)
if
ifNoneMatchMatched
(
c
.
GetHeader
(
"If-None-Match"
),
etag
)
{
c
.
Status
(
http
.
StatusNotModified
)
return
}
}
response
.
Paginated
(
c
,
result
,
total
,
page
,
pageSize
)
}
func
buildAccountsListETag
(
items
[]
AccountWithConcurrency
,
total
int64
,
page
,
pageSize
int
,
platform
,
accountType
,
status
,
search
string
,
lite
bool
,
)
string
{
payload
:=
struct
{
Total
int64
`json:"total"`
Page
int
`json:"page"`
PageSize
int
`json:"page_size"`
Platform
string
`json:"platform"`
AccountType
string
`json:"type"`
Status
string
`json:"status"`
Search
string
`json:"search"`
Lite
bool
`json:"lite"`
Items
[]
AccountWithConcurrency
`json:"items"`
}{
Total
:
total
,
Page
:
page
,
PageSize
:
pageSize
,
Platform
:
platform
,
AccountType
:
accountType
,
Status
:
status
,
Search
:
search
,
Lite
:
lite
,
Items
:
items
,
}
raw
,
err
:=
json
.
Marshal
(
payload
)
if
err
!=
nil
{
return
""
}
sum
:=
sha256
.
Sum256
(
raw
)
return
"
\"
"
+
hex
.
EncodeToString
(
sum
[
:
])
+
"
\"
"
}
func
ifNoneMatchMatched
(
ifNoneMatch
,
etag
string
)
bool
{
if
etag
==
""
||
ifNoneMatch
==
""
{
return
false
}
for
_
,
token
:=
range
strings
.
Split
(
ifNoneMatch
,
","
)
{
candidate
:=
strings
.
TrimSpace
(
token
)
if
candidate
==
"*"
{
return
true
}
if
candidate
==
etag
{
return
true
}
if
strings
.
HasPrefix
(
candidate
,
"W/"
)
&&
strings
.
TrimPrefix
(
candidate
,
"W/"
)
==
etag
{
return
true
}
}
return
false
}
// GetByID handles getting an account by ID
// GET /api/v1/admin/accounts/:id
func
(
h
*
AccountHandler
)
GetByID
(
c
*
gin
.
Context
)
{
...
...
@@ -280,7 +429,51 @@ func (h *AccountHandler) GetByID(c *gin.Context) {
return
}
response
.
Success
(
c
,
dto
.
AccountFromService
(
account
))
response
.
Success
(
c
,
h
.
buildAccountResponseWithRuntime
(
c
.
Request
.
Context
(),
account
))
}
// CheckMixedChannel handles checking mixed channel risk for account-group binding.
// POST /api/v1/admin/accounts/check-mixed-channel
func
(
h
*
AccountHandler
)
CheckMixedChannel
(
c
*
gin
.
Context
)
{
var
req
CheckMixedChannelRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
if
len
(
req
.
GroupIDs
)
==
0
{
response
.
Success
(
c
,
gin
.
H
{
"has_risk"
:
false
})
return
}
accountID
:=
int64
(
0
)
if
req
.
AccountID
!=
nil
{
accountID
=
*
req
.
AccountID
}
err
:=
h
.
adminService
.
CheckMixedChannelRisk
(
c
.
Request
.
Context
(),
accountID
,
req
.
Platform
,
req
.
GroupIDs
)
if
err
!=
nil
{
var
mixedErr
*
service
.
MixedChannelError
if
errors
.
As
(
err
,
&
mixedErr
)
{
response
.
Success
(
c
,
gin
.
H
{
"has_risk"
:
true
,
"error"
:
"mixed_channel_warning"
,
"message"
:
mixedErr
.
Error
(),
"details"
:
gin
.
H
{
"group_id"
:
mixedErr
.
GroupID
,
"group_name"
:
mixedErr
.
GroupName
,
"current_platform"
:
mixedErr
.
CurrentPlatform
,
"other_platform"
:
mixedErr
.
OtherPlatform
,
},
})
return
}
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
gin
.
H
{
"has_risk"
:
false
})
}
// Create handles creating a new account
...
...
@@ -295,11 +488,14 @@ func (h *AccountHandler) Create(c *gin.Context) {
response
.
BadRequest
(
c
,
"rate_multiplier must be >= 0"
)
return
}
// base_rpm 输入校验:负值归零,超过 10000 截断
sanitizeExtraBaseRPM
(
req
.
Extra
)
// 确定是否跳过混合渠道检查
skipCheck
:=
req
.
ConfirmMixedChannelRisk
!=
nil
&&
*
req
.
ConfirmMixedChannelRisk
account
,
err
:=
h
.
adminService
.
CreateAccount
(
c
.
Request
.
Context
(),
&
service
.
CreateAccountInput
{
result
,
err
:=
executeAdminIdempotent
(
c
,
"admin.accounts.create"
,
req
,
service
.
DefaultWriteIdempotencyTTL
(),
func
(
ctx
context
.
Context
)
(
any
,
error
)
{
account
,
execErr
:=
h
.
adminService
.
CreateAccount
(
ctx
,
&
service
.
CreateAccountInput
{
Name
:
req
.
Name
,
Notes
:
req
.
Notes
,
Platform
:
req
.
Platform
,
...
...
@@ -315,30 +511,34 @@ func (h *AccountHandler) Create(c *gin.Context) {
AutoPauseOnExpired
:
req
.
AutoPauseOnExpired
,
SkipMixedChannelCheck
:
skipCheck
,
})
if
execErr
!=
nil
{
return
nil
,
execErr
}
return
h
.
buildAccountResponseWithRuntime
(
ctx
,
account
),
nil
})
if
err
!=
nil
{
// 检查是否为混合渠道错误
var
mixedErr
*
service
.
MixedChannelError
if
errors
.
As
(
err
,
&
mixedErr
)
{
//
返回特殊错误码要求确认
//
创建接口仅返回最小必要字段,详细信息由专门检查接口提供
c
.
JSON
(
409
,
gin
.
H
{
"error"
:
"mixed_channel_warning"
,
"message"
:
mixedErr
.
Error
(),
"details"
:
gin
.
H
{
"group_id"
:
mixedErr
.
GroupID
,
"group_name"
:
mixedErr
.
GroupName
,
"current_platform"
:
mixedErr
.
CurrentPlatform
,
"other_platform"
:
mixedErr
.
OtherPlatform
,
},
"require_confirmation"
:
true
,
})
return
}
if
retryAfter
:=
service
.
RetryAfterSecondsFromError
(
err
);
retryAfter
>
0
{
c
.
Header
(
"Retry-After"
,
strconv
.
Itoa
(
retryAfter
))
}
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
dto
.
AccountFromService
(
account
))
if
result
!=
nil
&&
result
.
Replayed
{
c
.
Header
(
"X-Idempotency-Replayed"
,
"true"
)
}
response
.
Success
(
c
,
result
.
Data
)
}
// Update handles updating an account
...
...
@@ -359,6 +559,8 @@ func (h *AccountHandler) Update(c *gin.Context) {
response
.
BadRequest
(
c
,
"rate_multiplier must be >= 0"
)
return
}
// base_rpm 输入校验:负值归零,超过 10000 截断
sanitizeExtraBaseRPM
(
req
.
Extra
)
// 确定是否跳过混合渠道检查
skipCheck
:=
req
.
ConfirmMixedChannelRisk
!=
nil
&&
*
req
.
ConfirmMixedChannelRisk
...
...
@@ -383,17 +585,10 @@ func (h *AccountHandler) Update(c *gin.Context) {
// 检查是否为混合渠道错误
var
mixedErr
*
service
.
MixedChannelError
if
errors
.
As
(
err
,
&
mixedErr
)
{
//
返回特殊错误码要求确认
//
更新接口仅返回最小必要字段,详细信息由专门检查接口提供
c
.
JSON
(
409
,
gin
.
H
{
"error"
:
"mixed_channel_warning"
,
"message"
:
mixedErr
.
Error
(),
"details"
:
gin
.
H
{
"group_id"
:
mixedErr
.
GroupID
,
"group_name"
:
mixedErr
.
GroupName
,
"current_platform"
:
mixedErr
.
CurrentPlatform
,
"other_platform"
:
mixedErr
.
OtherPlatform
,
},
"require_confirmation"
:
true
,
})
return
}
...
...
@@ -402,7 +597,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
return
}
response
.
Success
(
c
,
dto
.
AccountFromService
(
account
))
response
.
Success
(
c
,
h
.
buildAccountResponseWithRuntime
(
c
.
Request
.
Context
(),
account
))
}
// Delete handles deleting an account
...
...
@@ -660,7 +855,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
}
}
response
.
Success
(
c
,
dto
.
AccountFromService
(
updatedAccount
))
response
.
Success
(
c
,
h
.
buildAccountResponseWithRuntime
(
c
.
Request
.
Context
(),
updatedAccount
))
}
// GetStats handles getting account statistics
...
...
@@ -718,7 +913,7 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
}
}
response
.
Success
(
c
,
dto
.
AccountFromService
(
account
))
response
.
Success
(
c
,
h
.
buildAccountResponseWithRuntime
(
c
.
Request
.
Context
(),
account
))
}
// BatchCreate handles batch creating accounts
...
...
@@ -732,7 +927,7 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
return
}
ctx
:=
c
.
Request
.
Context
()
executeAdminIdempotentJSON
(
c
,
"admin.accounts.batch_create"
,
req
,
service
.
DefaultWriteIdempotencyTTL
(),
func
(
ctx
context
.
Context
)
(
any
,
error
)
{
success
:=
0
failed
:=
0
results
:=
make
([]
gin
.
H
,
0
,
len
(
req
.
Accounts
))
...
...
@@ -748,6 +943,9 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
continue
}
// base_rpm 输入校验:负值归零,超过 10000 截断
sanitizeExtraBaseRPM
(
item
.
Extra
)
skipCheck
:=
item
.
ConfirmMixedChannelRisk
!=
nil
&&
*
item
.
ConfirmMixedChannelRisk
account
,
err
:=
h
.
adminService
.
CreateAccount
(
ctx
,
&
service
.
CreateAccountInput
{
...
...
@@ -783,10 +981,11 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
})
}
re
sponse
.
Success
(
c
,
gin
.
H
{
re
turn
gin
.
H
{
"success"
:
success
,
"failed"
:
failed
,
"results"
:
results
,
},
nil
})
}
...
...
@@ -824,49 +1023,48 @@ func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) {
}
ctx
:=
c
.
Request
.
Context
()
success
:=
0
failed
:=
0
results
:=
[]
gin
.
H
{}
// 阶段一:预验证所有账号存在,收集 credentials
type
accountUpdate
struct
{
ID
int64
Credentials
map
[
string
]
any
}
updates
:=
make
([]
accountUpdate
,
0
,
len
(
req
.
AccountIDs
))
for
_
,
accountID
:=
range
req
.
AccountIDs
{
// Get account
account
,
err
:=
h
.
adminService
.
GetAccount
(
ctx
,
accountID
)
if
err
!=
nil
{
failed
++
results
=
append
(
results
,
gin
.
H
{
"account_id"
:
accountID
,
"success"
:
false
,
"error"
:
"Account not found"
,
})
continue
response
.
Error
(
c
,
404
,
fmt
.
Sprintf
(
"Account %d not found"
,
accountID
))
return
}
// Update credentials field
if
account
.
Credentials
==
nil
{
account
.
Credentials
=
make
(
map
[
string
]
any
)
}
account
.
Credentials
[
req
.
Field
]
=
req
.
Value
// Update account
updateInput
:=
&
service
.
UpdateAccountInput
{
Credentials
:
account
.
Credentials
,
updates
=
append
(
updates
,
accountUpdate
{
ID
:
accountID
,
Credentials
:
account
.
Credentials
})
}
_
,
err
=
h
.
adminService
.
UpdateAccount
(
ctx
,
accountID
,
updateInput
)
if
err
!=
nil
{
// 阶段二:依次更新,返回每个账号的成功/失败明细,便于调用方重试
success
:=
0
failed
:=
0
successIDs
:=
make
([]
int64
,
0
,
len
(
updates
))
failedIDs
:=
make
([]
int64
,
0
,
len
(
updates
))
results
:=
make
([]
gin
.
H
,
0
,
len
(
updates
))
for
_
,
u
:=
range
updates
{
updateInput
:=
&
service
.
UpdateAccountInput
{
Credentials
:
u
.
Credentials
}
if
_
,
err
:=
h
.
adminService
.
UpdateAccount
(
ctx
,
u
.
ID
,
updateInput
);
err
!=
nil
{
failed
++
failedIDs
=
append
(
failedIDs
,
u
.
ID
)
results
=
append
(
results
,
gin
.
H
{
"account_id"
:
account
ID
,
"account_id"
:
u
.
ID
,
"success"
:
false
,
"error"
:
err
.
Error
(),
})
continue
}
success
++
successIDs
=
append
(
successIDs
,
u
.
ID
)
results
=
append
(
results
,
gin
.
H
{
"account_id"
:
account
ID
,
"account_id"
:
u
.
ID
,
"success"
:
true
,
})
}
...
...
@@ -874,6 +1072,8 @@ func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) {
response
.
Success
(
c
,
gin
.
H
{
"success"
:
success
,
"failed"
:
failed
,
"success_ids"
:
successIDs
,
"failed_ids"
:
failedIDs
,
"results"
:
results
,
})
}
...
...
@@ -890,6 +1090,8 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
response
.
BadRequest
(
c
,
"rate_multiplier must be >= 0"
)
return
}
// base_rpm 输入校验:负值归零,超过 10000 截断
sanitizeExtraBaseRPM
(
req
.
Extra
)
// 确定是否跳过混合渠道检查
skipCheck
:=
req
.
ConfirmMixedChannelRisk
!=
nil
&&
*
req
.
ConfirmMixedChannelRisk
...
...
@@ -925,6 +1127,14 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
SkipMixedChannelCheck
:
skipCheck
,
})
if
err
!=
nil
{
var
mixedErr
*
service
.
MixedChannelError
if
errors
.
As
(
err
,
&
mixedErr
)
{
c
.
JSON
(
409
,
gin
.
H
{
"error"
:
"mixed_channel_warning"
,
"message"
:
mixedErr
.
Error
(),
})
return
}
response
.
ErrorFrom
(
c
,
err
)
return
}
...
...
@@ -1109,7 +1319,13 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
return
}
response
.
Success
(
c
,
gin
.
H
{
"message"
:
"Rate limit cleared successfully"
})
account
,
err
:=
h
.
adminService
.
GetAccount
(
c
.
Request
.
Context
(),
accountID
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
h
.
buildAccountResponseWithRuntime
(
c
.
Request
.
Context
(),
account
))
}
// GetTempUnschedulable handles getting temporary unschedulable status
...
...
@@ -1173,6 +1389,57 @@ func (h *AccountHandler) GetTodayStats(c *gin.Context) {
response
.
Success
(
c
,
stats
)
}
// BatchTodayStatsRequest 批量今日统计请求体。
type
BatchTodayStatsRequest
struct
{
AccountIDs
[]
int64
`json:"account_ids" binding:"required"`
}
// GetBatchTodayStats 批量获取多个账号的今日统计。
// POST /api/v1/admin/accounts/today-stats/batch
func
(
h
*
AccountHandler
)
GetBatchTodayStats
(
c
*
gin
.
Context
)
{
var
req
BatchTodayStatsRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
accountIDs
:=
normalizeInt64IDList
(
req
.
AccountIDs
)
if
len
(
accountIDs
)
==
0
{
response
.
Success
(
c
,
gin
.
H
{
"stats"
:
map
[
string
]
any
{}})
return
}
cacheKey
:=
buildAccountTodayStatsBatchCacheKey
(
accountIDs
)
if
cached
,
ok
:=
accountTodayStatsBatchCache
.
Get
(
cacheKey
);
ok
{
if
cached
.
ETag
!=
""
{
c
.
Header
(
"ETag"
,
cached
.
ETag
)
c
.
Header
(
"Vary"
,
"If-None-Match"
)
if
ifNoneMatchMatched
(
c
.
GetHeader
(
"If-None-Match"
),
cached
.
ETag
)
{
c
.
Status
(
http
.
StatusNotModified
)
return
}
}
c
.
Header
(
"X-Snapshot-Cache"
,
"hit"
)
response
.
Success
(
c
,
cached
.
Payload
)
return
}
stats
,
err
:=
h
.
accountUsageService
.
GetTodayStatsBatch
(
c
.
Request
.
Context
(),
accountIDs
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
payload
:=
gin
.
H
{
"stats"
:
stats
}
cached
:=
accountTodayStatsBatchCache
.
Set
(
cacheKey
,
payload
)
if
cached
.
ETag
!=
""
{
c
.
Header
(
"ETag"
,
cached
.
ETag
)
c
.
Header
(
"Vary"
,
"If-None-Match"
)
}
c
.
Header
(
"X-Snapshot-Cache"
,
"miss"
)
response
.
Success
(
c
,
payload
)
}
// SetSchedulableRequest represents the request body for setting schedulable status
type
SetSchedulableRequest
struct
{
Schedulable
bool
`json:"schedulable"`
...
...
@@ -1199,7 +1466,7 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) {
return
}
response
.
Success
(
c
,
dto
.
AccountFromService
(
account
))
response
.
Success
(
c
,
h
.
buildAccountResponseWithRuntime
(
c
.
Request
.
Context
(),
account
))
}
// GetAvailableModels handles getting available models for an account
...
...
@@ -1296,32 +1563,14 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
// Handle Antigravity accounts: return Claude + Gemini models
if
account
.
Platform
==
service
.
PlatformAntigravity
{
// Antigravity 支持 Claude 和部分 Gemini 模型
type
UnifiedModel
struct
{
ID
string
`json:"id"`
Type
string
`json:"type"`
DisplayName
string
`json:"display_name"`
}
var
models
[]
UnifiedModel
// 添加 Claude 模型
for
_
,
m
:=
range
claude
.
DefaultModels
{
models
=
append
(
models
,
UnifiedModel
{
ID
:
m
.
ID
,
Type
:
m
.
Type
,
DisplayName
:
m
.
DisplayName
,
})
}
// 添加 Gemini 3 系列模型用于测试
geminiTestModels
:=
[]
UnifiedModel
{
{
ID
:
"gemini-3-flash"
,
Type
:
"model"
,
DisplayName
:
"Gemini 3 Flash"
},
{
ID
:
"gemini-3-pro-preview"
,
Type
:
"model"
,
DisplayName
:
"Gemini 3 Pro Preview"
},
// 直接复用 antigravity.DefaultModels(),与 /v1/models 端点保持同步
response
.
Success
(
c
,
antigravity
.
DefaultModels
())
return
}
models
=
append
(
models
,
geminiTestModels
...
)
response
.
Success
(
c
,
models
)
// Handle Sora accounts
if
account
.
Platform
==
service
.
PlatformSora
{
response
.
Success
(
c
,
service
.
DefaultSoraModels
(
nil
))
return
}
...
...
@@ -1532,3 +1781,22 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
func
(
h
*
AccountHandler
)
GetAntigravityDefaultModelMapping
(
c
*
gin
.
Context
)
{
response
.
Success
(
c
,
domain
.
DefaultAntigravityModelMapping
)
}
// sanitizeExtraBaseRPM 对 extra map 中的 base_rpm 值进行范围校验和归一化。
// 负值归零,超过 10000 截断为 10000。extra 为 nil 或不含 base_rpm 时无操作。
func
sanitizeExtraBaseRPM
(
extra
map
[
string
]
any
)
{
if
extra
==
nil
{
return
}
raw
,
ok
:=
extra
[
"base_rpm"
]
if
!
ok
{
return
}
v
:=
service
.
ParseExtraInt
(
raw
)
if
v
<
0
{
v
=
0
}
else
if
v
>
10000
{
v
=
10000
}
extra
[
"base_rpm"
]
=
v
}
backend/internal/handler/admin/account_handler_mixed_channel_test.go
0 → 100644
View file @
3d79773b
package
admin
import
(
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func
setupAccountMixedChannelRouter
(
adminSvc
*
stubAdminService
)
*
gin
.
Engine
{
gin
.
SetMode
(
gin
.
TestMode
)
router
:=
gin
.
New
()
accountHandler
:=
NewAccountHandler
(
adminSvc
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
router
.
POST
(
"/api/v1/admin/accounts/check-mixed-channel"
,
accountHandler
.
CheckMixedChannel
)
router
.
POST
(
"/api/v1/admin/accounts"
,
accountHandler
.
Create
)
router
.
PUT
(
"/api/v1/admin/accounts/:id"
,
accountHandler
.
Update
)
router
.
POST
(
"/api/v1/admin/accounts/bulk-update"
,
accountHandler
.
BulkUpdate
)
return
router
}
func
TestAccountHandlerCheckMixedChannelNoRisk
(
t
*
testing
.
T
)
{
adminSvc
:=
newStubAdminService
()
router
:=
setupAccountMixedChannelRouter
(
adminSvc
)
body
,
_
:=
json
.
Marshal
(
map
[
string
]
any
{
"platform"
:
"antigravity"
,
"group_ids"
:
[]
int64
{
27
},
})
rec
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/admin/accounts/check-mixed-channel"
,
bytes
.
NewReader
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
var
resp
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
float64
(
0
),
resp
[
"code"
])
data
,
ok
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
false
,
data
[
"has_risk"
])
require
.
Equal
(
t
,
int64
(
0
),
adminSvc
.
lastMixedCheck
.
accountID
)
require
.
Equal
(
t
,
"antigravity"
,
adminSvc
.
lastMixedCheck
.
platform
)
require
.
Equal
(
t
,
[]
int64
{
27
},
adminSvc
.
lastMixedCheck
.
groupIDs
)
}
func
TestAccountHandlerCheckMixedChannelWithRisk
(
t
*
testing
.
T
)
{
adminSvc
:=
newStubAdminService
()
adminSvc
.
checkMixedErr
=
&
service
.
MixedChannelError
{
GroupID
:
27
,
GroupName
:
"claude-max"
,
CurrentPlatform
:
"Antigravity"
,
OtherPlatform
:
"Anthropic"
,
}
router
:=
setupAccountMixedChannelRouter
(
adminSvc
)
body
,
_
:=
json
.
Marshal
(
map
[
string
]
any
{
"platform"
:
"antigravity"
,
"group_ids"
:
[]
int64
{
27
},
"account_id"
:
99
,
})
rec
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/admin/accounts/check-mixed-channel"
,
bytes
.
NewReader
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
var
resp
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
float64
(
0
),
resp
[
"code"
])
data
,
ok
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
true
,
data
[
"has_risk"
])
require
.
Equal
(
t
,
"mixed_channel_warning"
,
data
[
"error"
])
details
,
ok
:=
data
[
"details"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
float64
(
27
),
details
[
"group_id"
])
require
.
Equal
(
t
,
"claude-max"
,
details
[
"group_name"
])
require
.
Equal
(
t
,
"Antigravity"
,
details
[
"current_platform"
])
require
.
Equal
(
t
,
"Anthropic"
,
details
[
"other_platform"
])
require
.
Equal
(
t
,
int64
(
99
),
adminSvc
.
lastMixedCheck
.
accountID
)
}
func
TestAccountHandlerCreateMixedChannelConflictSimplifiedResponse
(
t
*
testing
.
T
)
{
adminSvc
:=
newStubAdminService
()
adminSvc
.
createAccountErr
=
&
service
.
MixedChannelError
{
GroupID
:
27
,
GroupName
:
"claude-max"
,
CurrentPlatform
:
"Antigravity"
,
OtherPlatform
:
"Anthropic"
,
}
router
:=
setupAccountMixedChannelRouter
(
adminSvc
)
body
,
_
:=
json
.
Marshal
(
map
[
string
]
any
{
"name"
:
"ag-oauth-1"
,
"platform"
:
"antigravity"
,
"type"
:
"oauth"
,
"credentials"
:
map
[
string
]
any
{
"refresh_token"
:
"rt"
},
"group_ids"
:
[]
int64
{
27
},
})
rec
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/admin/accounts"
,
bytes
.
NewReader
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusConflict
,
rec
.
Code
)
var
resp
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
"mixed_channel_warning"
,
resp
[
"error"
])
require
.
Contains
(
t
,
resp
[
"message"
],
"mixed_channel_warning"
)
_
,
hasDetails
:=
resp
[
"details"
]
_
,
hasRequireConfirmation
:=
resp
[
"require_confirmation"
]
require
.
False
(
t
,
hasDetails
)
require
.
False
(
t
,
hasRequireConfirmation
)
}
func
TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse
(
t
*
testing
.
T
)
{
adminSvc
:=
newStubAdminService
()
adminSvc
.
updateAccountErr
=
&
service
.
MixedChannelError
{
GroupID
:
27
,
GroupName
:
"claude-max"
,
CurrentPlatform
:
"Antigravity"
,
OtherPlatform
:
"Anthropic"
,
}
router
:=
setupAccountMixedChannelRouter
(
adminSvc
)
body
,
_
:=
json
.
Marshal
(
map
[
string
]
any
{
"group_ids"
:
[]
int64
{
27
},
})
rec
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodPut
,
"/api/v1/admin/accounts/3"
,
bytes
.
NewReader
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusConflict
,
rec
.
Code
)
var
resp
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
"mixed_channel_warning"
,
resp
[
"error"
])
require
.
Contains
(
t
,
resp
[
"message"
],
"mixed_channel_warning"
)
_
,
hasDetails
:=
resp
[
"details"
]
_
,
hasRequireConfirmation
:=
resp
[
"require_confirmation"
]
require
.
False
(
t
,
hasDetails
)
require
.
False
(
t
,
hasRequireConfirmation
)
}
func
TestAccountHandlerBulkUpdateMixedChannelConflict
(
t
*
testing
.
T
)
{
adminSvc
:=
newStubAdminService
()
adminSvc
.
bulkUpdateAccountErr
=
&
service
.
MixedChannelError
{
GroupID
:
27
,
GroupName
:
"claude-max"
,
CurrentPlatform
:
"Antigravity"
,
OtherPlatform
:
"Anthropic"
,
}
router
:=
setupAccountMixedChannelRouter
(
adminSvc
)
body
,
_
:=
json
.
Marshal
(
map
[
string
]
any
{
"account_ids"
:
[]
int64
{
1
,
2
,
3
},
"group_ids"
:
[]
int64
{
27
},
})
rec
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/admin/accounts/bulk-update"
,
bytes
.
NewReader
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusConflict
,
rec
.
Code
)
var
resp
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
"mixed_channel_warning"
,
resp
[
"error"
])
require
.
Contains
(
t
,
resp
[
"message"
],
"claude-max"
)
}
func
TestAccountHandlerBulkUpdateMixedChannelConfirmSkips
(
t
*
testing
.
T
)
{
adminSvc
:=
newStubAdminService
()
router
:=
setupAccountMixedChannelRouter
(
adminSvc
)
body
,
_
:=
json
.
Marshal
(
map
[
string
]
any
{
"account_ids"
:
[]
int64
{
1
,
2
},
"group_ids"
:
[]
int64
{
27
},
"confirm_mixed_channel_risk"
:
true
,
})
rec
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/admin/accounts/bulk-update"
,
bytes
.
NewReader
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
var
resp
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
float64
(
0
),
resp
[
"code"
])
data
,
ok
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
float64
(
2
),
data
[
"success"
])
require
.
Equal
(
t
,
float64
(
0
),
data
[
"failed"
])
}
backend/internal/handler/admin/account_handler_passthrough_test.go
0 → 100644
View file @
3d79773b
package
admin
import
(
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func
TestAccountHandler_Create_AnthropicAPIKeyPassthroughExtraForwarded
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
adminSvc
:=
newStubAdminService
()
handler
:=
NewAccountHandler
(
adminSvc
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
)
router
:=
gin
.
New
()
router
.
POST
(
"/api/v1/admin/accounts"
,
handler
.
Create
)
body
:=
map
[
string
]
any
{
"name"
:
"anthropic-key-1"
,
"platform"
:
"anthropic"
,
"type"
:
"apikey"
,
"credentials"
:
map
[
string
]
any
{
"api_key"
:
"sk-ant-xxx"
,
"base_url"
:
"https://api.anthropic.com"
,
},
"extra"
:
map
[
string
]
any
{
"anthropic_passthrough"
:
true
,
},
"concurrency"
:
1
,
"priority"
:
1
,
}
raw
,
err
:=
json
.
Marshal
(
body
)
require
.
NoError
(
t
,
err
)
rec
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/admin/accounts"
,
bytes
.
NewReader
(
raw
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
Len
(
t
,
adminSvc
.
createdAccounts
,
1
)
created
:=
adminSvc
.
createdAccounts
[
0
]
require
.
Equal
(
t
,
"anthropic"
,
created
.
Platform
)
require
.
Equal
(
t
,
"apikey"
,
created
.
Type
)
require
.
NotNil
(
t
,
created
.
Extra
)
require
.
Equal
(
t
,
true
,
created
.
Extra
[
"anthropic_passthrough"
])
}
backend/internal/handler/admin/account_today_stats_cache.go
0 → 100644
View file @
3d79773b
package
admin
import
(
"strconv"
"strings"
"time"
)
var
accountTodayStatsBatchCache
=
newSnapshotCache
(
30
*
time
.
Second
)
func
buildAccountTodayStatsBatchCacheKey
(
accountIDs
[]
int64
)
string
{
if
len
(
accountIDs
)
==
0
{
return
"accounts_today_stats_empty"
}
var
b
strings
.
Builder
b
.
Grow
(
len
(
accountIDs
)
*
6
)
_
,
_
=
b
.
WriteString
(
"accounts_today_stats:"
)
for
i
,
id
:=
range
accountIDs
{
if
i
>
0
{
_
=
b
.
WriteByte
(
','
)
}
_
,
_
=
b
.
WriteString
(
strconv
.
FormatInt
(
id
,
10
))
}
return
b
.
String
()
}
backend/internal/handler/admin/admin_basic_handlers_test.go
View file @
3d79773b
...
...
@@ -19,7 +19,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
userHandler
:=
NewUserHandler
(
adminSvc
,
nil
)
groupHandler
:=
NewGroupHandler
(
adminSvc
)
proxyHandler
:=
NewProxyHandler
(
adminSvc
)
redeemHandler
:=
NewRedeemHandler
(
adminSvc
)
redeemHandler
:=
NewRedeemHandler
(
adminSvc
,
nil
)
router
.
GET
(
"/api/v1/admin/users"
,
userHandler
.
List
)
router
.
GET
(
"/api/v1/admin/users/:id"
,
userHandler
.
GetByID
)
...
...
@@ -47,6 +47,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
router
.
DELETE
(
"/api/v1/admin/proxies/:id"
,
proxyHandler
.
Delete
)
router
.
POST
(
"/api/v1/admin/proxies/batch-delete"
,
proxyHandler
.
BatchDelete
)
router
.
POST
(
"/api/v1/admin/proxies/:id/test"
,
proxyHandler
.
Test
)
router
.
POST
(
"/api/v1/admin/proxies/:id/quality-check"
,
proxyHandler
.
CheckQuality
)
router
.
GET
(
"/api/v1/admin/proxies/:id/stats"
,
proxyHandler
.
GetStats
)
router
.
GET
(
"/api/v1/admin/proxies/:id/accounts"
,
proxyHandler
.
GetProxyAccounts
)
...
...
@@ -208,6 +209,11 @@ func TestProxyHandlerEndpoints(t *testing.T) {
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
rec
=
httptest
.
NewRecorder
()
req
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/admin/proxies/4/quality-check"
,
nil
)
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
rec
=
httptest
.
NewRecorder
()
req
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/api/v1/admin/proxies/4/stats"
,
nil
)
router
.
ServeHTTP
(
rec
,
req
)
...
...
backend/internal/handler/admin/admin_helpers_test.go
View file @
3d79773b
...
...
@@ -58,6 +58,96 @@ func TestParseOpsDuration(t *testing.T) {
require
.
False
(
t
,
ok
)
}
func
TestParseOpsOpenAITokenStatsDuration
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
input
string
want
time
.
Duration
ok
bool
}{
{
input
:
"30m"
,
want
:
30
*
time
.
Minute
,
ok
:
true
},
{
input
:
"1h"
,
want
:
time
.
Hour
,
ok
:
true
},
{
input
:
"1d"
,
want
:
24
*
time
.
Hour
,
ok
:
true
},
{
input
:
"15d"
,
want
:
15
*
24
*
time
.
Hour
,
ok
:
true
},
{
input
:
"30d"
,
want
:
30
*
24
*
time
.
Hour
,
ok
:
true
},
{
input
:
"7d"
,
want
:
0
,
ok
:
false
},
}
for
_
,
tt
:=
range
tests
{
got
,
ok
:=
parseOpsOpenAITokenStatsDuration
(
tt
.
input
)
require
.
Equal
(
t
,
tt
.
ok
,
ok
,
"input=%s"
,
tt
.
input
)
require
.
Equal
(
t
,
tt
.
want
,
got
,
"input=%s"
,
tt
.
input
)
}
}
func
TestParseOpsOpenAITokenStatsFilter_Defaults
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/"
,
nil
)
before
:=
time
.
Now
()
.
UTC
()
filter
,
err
:=
parseOpsOpenAITokenStatsFilter
(
c
)
after
:=
time
.
Now
()
.
UTC
()
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
filter
)
require
.
Equal
(
t
,
"30d"
,
filter
.
TimeRange
)
require
.
Equal
(
t
,
1
,
filter
.
Page
)
require
.
Equal
(
t
,
20
,
filter
.
PageSize
)
require
.
Equal
(
t
,
0
,
filter
.
TopN
)
require
.
Nil
(
t
,
filter
.
GroupID
)
require
.
Equal
(
t
,
""
,
filter
.
Platform
)
require
.
True
(
t
,
filter
.
StartTime
.
Before
(
filter
.
EndTime
))
require
.
WithinDuration
(
t
,
before
.
Add
(
-
30
*
24
*
time
.
Hour
),
filter
.
StartTime
,
2
*
time
.
Second
)
require
.
WithinDuration
(
t
,
after
,
filter
.
EndTime
,
2
*
time
.
Second
)
}
func
TestParseOpsOpenAITokenStatsFilter_WithTopN
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/?time_range=1h&platform=openai&group_id=12&top_n=50"
,
nil
,
)
filter
,
err
:=
parseOpsOpenAITokenStatsFilter
(
c
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"1h"
,
filter
.
TimeRange
)
require
.
Equal
(
t
,
"openai"
,
filter
.
Platform
)
require
.
NotNil
(
t
,
filter
.
GroupID
)
require
.
Equal
(
t
,
int64
(
12
),
*
filter
.
GroupID
)
require
.
Equal
(
t
,
50
,
filter
.
TopN
)
require
.
Equal
(
t
,
0
,
filter
.
Page
)
require
.
Equal
(
t
,
0
,
filter
.
PageSize
)
}
func
TestParseOpsOpenAITokenStatsFilter_InvalidParams
(
t
*
testing
.
T
)
{
tests
:=
[]
string
{
"/?time_range=7d"
,
"/?group_id=0"
,
"/?group_id=abc"
,
"/?top_n=0"
,
"/?top_n=101"
,
"/?top_n=10&page=1"
,
"/?top_n=10&page_size=20"
,
"/?page=0"
,
"/?page_size=0"
,
"/?page_size=101"
,
}
gin
.
SetMode
(
gin
.
TestMode
)
for
_
,
rawURL
:=
range
tests
{
w
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
w
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodGet
,
rawURL
,
nil
)
_
,
err
:=
parseOpsOpenAITokenStatsFilter
(
c
)
require
.
Error
(
t
,
err
,
"url=%s"
,
rawURL
)
}
}
func
TestParseOpsTimeRange
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
w
:=
httptest
.
NewRecorder
()
...
...
backend/internal/handler/admin/admin_service_stub_test.go
View file @
3d79773b
...
...
@@ -22,6 +22,15 @@ type stubAdminService struct {
updatedProxyIDs
[]
int64
updatedProxies
[]
*
service
.
UpdateProxyInput
testedProxyIDs
[]
int64
createAccountErr
error
updateAccountErr
error
bulkUpdateAccountErr
error
checkMixedErr
error
lastMixedCheck
struct
{
accountID
int64
platform
string
groupIDs
[]
int64
}
mu
sync
.
Mutex
}
...
...
@@ -188,11 +197,17 @@ func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.Cre
s
.
mu
.
Lock
()
s
.
createdAccounts
=
append
(
s
.
createdAccounts
,
input
)
s
.
mu
.
Unlock
()
if
s
.
createAccountErr
!=
nil
{
return
nil
,
s
.
createAccountErr
}
account
:=
service
.
Account
{
ID
:
300
,
Name
:
input
.
Name
,
Status
:
service
.
StatusActive
}
return
&
account
,
nil
}
func
(
s
*
stubAdminService
)
UpdateAccount
(
ctx
context
.
Context
,
id
int64
,
input
*
service
.
UpdateAccountInput
)
(
*
service
.
Account
,
error
)
{
if
s
.
updateAccountErr
!=
nil
{
return
nil
,
s
.
updateAccountErr
}
account
:=
service
.
Account
{
ID
:
id
,
Name
:
input
.
Name
,
Status
:
service
.
StatusActive
}
return
&
account
,
nil
}
...
...
@@ -221,7 +236,17 @@ func (s *stubAdminService) SetAccountSchedulable(ctx context.Context, id int64,
}
func
(
s
*
stubAdminService
)
BulkUpdateAccounts
(
ctx
context
.
Context
,
input
*
service
.
BulkUpdateAccountsInput
)
(
*
service
.
BulkUpdateAccountsResult
,
error
)
{
return
&
service
.
BulkUpdateAccountsResult
{
Success
:
1
,
Failed
:
0
,
SuccessIDs
:
[]
int64
{
1
}},
nil
if
s
.
bulkUpdateAccountErr
!=
nil
{
return
nil
,
s
.
bulkUpdateAccountErr
}
return
&
service
.
BulkUpdateAccountsResult
{
Success
:
len
(
input
.
AccountIDs
),
Failed
:
0
,
SuccessIDs
:
input
.
AccountIDs
},
nil
}
func
(
s
*
stubAdminService
)
CheckMixedChannelRisk
(
ctx
context
.
Context
,
currentAccountID
int64
,
currentAccountPlatform
string
,
groupIDs
[]
int64
)
error
{
s
.
lastMixedCheck
.
accountID
=
currentAccountID
s
.
lastMixedCheck
.
platform
=
currentAccountPlatform
s
.
lastMixedCheck
.
groupIDs
=
append
([]
int64
(
nil
),
groupIDs
...
)
return
s
.
checkMixedErr
}
func
(
s
*
stubAdminService
)
ListProxies
(
ctx
context
.
Context
,
page
,
pageSize
int
,
protocol
,
status
,
search
string
)
([]
service
.
Proxy
,
int64
,
error
)
{
...
...
@@ -327,6 +352,27 @@ func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.Pr
return
&
service
.
ProxyTestResult
{
Success
:
true
,
Message
:
"ok"
},
nil
}
func
(
s
*
stubAdminService
)
CheckProxyQuality
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
ProxyQualityCheckResult
,
error
)
{
return
&
service
.
ProxyQualityCheckResult
{
ProxyID
:
id
,
Score
:
95
,
Grade
:
"A"
,
Summary
:
"通过 5 项,告警 0 项,失败 0 项,挑战 0 项"
,
PassedCount
:
5
,
WarnCount
:
0
,
FailedCount
:
0
,
ChallengeCount
:
0
,
CheckedAt
:
time
.
Now
()
.
Unix
(),
Items
:
[]
service
.
ProxyQualityCheckItem
{
{
Target
:
"base_connectivity"
,
Status
:
"pass"
,
Message
:
"ok"
},
{
Target
:
"openai"
,
Status
:
"pass"
,
HTTPStatus
:
401
},
{
Target
:
"anthropic"
,
Status
:
"pass"
,
HTTPStatus
:
401
},
{
Target
:
"gemini"
,
Status
:
"pass"
,
HTTPStatus
:
200
},
{
Target
:
"sora"
,
Status
:
"pass"
,
HTTPStatus
:
401
},
},
},
nil
}
func
(
s
*
stubAdminService
)
ListRedeemCodes
(
ctx
context
.
Context
,
page
,
pageSize
int
,
codeType
,
status
,
search
string
)
([]
service
.
RedeemCode
,
int64
,
error
)
{
return
s
.
redeems
,
int64
(
len
(
s
.
redeems
)),
nil
}
...
...
@@ -361,5 +407,23 @@ func (s *stubAdminService) UpdateGroupSortOrders(ctx context.Context, updates []
return
nil
}
func
(
s
*
stubAdminService
)
AdminUpdateAPIKeyGroupID
(
ctx
context
.
Context
,
keyID
int64
,
groupID
*
int64
)
(
*
service
.
AdminUpdateAPIKeyGroupIDResult
,
error
)
{
for
i
:=
range
s
.
apiKeys
{
if
s
.
apiKeys
[
i
]
.
ID
==
keyID
{
k
:=
s
.
apiKeys
[
i
]
if
groupID
!=
nil
{
if
*
groupID
==
0
{
k
.
GroupID
=
nil
}
else
{
gid
:=
*
groupID
k
.
GroupID
=
&
gid
}
}
return
&
service
.
AdminUpdateAPIKeyGroupIDResult
{
APIKey
:
&
k
},
nil
}
}
return
nil
,
service
.
ErrAPIKeyNotFound
}
// Ensure stub implements interface.
var
_
service
.
AdminService
=
(
*
stubAdminService
)(
nil
)
backend/internal/handler/admin/apikey_handler.go
0 → 100644
View file @
3d79773b
package
admin
import
(
"strconv"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// AdminAPIKeyHandler handles admin API key management
type
AdminAPIKeyHandler
struct
{
adminService
service
.
AdminService
}
// NewAdminAPIKeyHandler creates a new admin API key handler
func
NewAdminAPIKeyHandler
(
adminService
service
.
AdminService
)
*
AdminAPIKeyHandler
{
return
&
AdminAPIKeyHandler
{
adminService
:
adminService
,
}
}
// AdminUpdateAPIKeyGroupRequest represents the request to update an API key's group
type
AdminUpdateAPIKeyGroupRequest
struct
{
GroupID
*
int64
`json:"group_id"`
// nil=不修改, 0=解绑, >0=绑定到目标分组
}
// UpdateGroup handles updating an API key's group binding
// PUT /api/v1/admin/api-keys/:id
func
(
h
*
AdminAPIKeyHandler
)
UpdateGroup
(
c
*
gin
.
Context
)
{
keyID
,
err
:=
strconv
.
ParseInt
(
c
.
Param
(
"id"
),
10
,
64
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid API key ID"
)
return
}
var
req
AdminUpdateAPIKeyGroupRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
result
,
err
:=
h
.
adminService
.
AdminUpdateAPIKeyGroupID
(
c
.
Request
.
Context
(),
keyID
,
req
.
GroupID
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
resp
:=
struct
{
APIKey
*
dto
.
APIKey
`json:"api_key"`
AutoGrantedGroupAccess
bool
`json:"auto_granted_group_access"`
GrantedGroupID
*
int64
`json:"granted_group_id,omitempty"`
GrantedGroupName
string
`json:"granted_group_name,omitempty"`
}{
APIKey
:
dto
.
APIKeyFromService
(
result
.
APIKey
),
AutoGrantedGroupAccess
:
result
.
AutoGrantedGroupAccess
,
GrantedGroupID
:
result
.
GrantedGroupID
,
GrantedGroupName
:
result
.
GrantedGroupName
,
}
response
.
Success
(
c
,
resp
)
}
backend/internal/handler/admin/apikey_handler_test.go
0 → 100644
View file @
3d79773b
package
admin
import
(
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func
setupAPIKeyHandler
(
adminSvc
service
.
AdminService
)
*
gin
.
Engine
{
gin
.
SetMode
(
gin
.
TestMode
)
router
:=
gin
.
New
()
h
:=
NewAdminAPIKeyHandler
(
adminSvc
)
router
.
PUT
(
"/api/v1/admin/api-keys/:id"
,
h
.
UpdateGroup
)
return
router
}
func
TestAdminAPIKeyHandler_UpdateGroup_InvalidID
(
t
*
testing
.
T
)
{
router
:=
setupAPIKeyHandler
(
newStubAdminService
())
body
:=
`{"group_id": 2}`
rec
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodPut
,
"/api/v1/admin/api-keys/abc"
,
bytes
.
NewBufferString
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
require
.
Contains
(
t
,
rec
.
Body
.
String
(),
"Invalid API key ID"
)
}
func
TestAdminAPIKeyHandler_UpdateGroup_InvalidJSON
(
t
*
testing
.
T
)
{
router
:=
setupAPIKeyHandler
(
newStubAdminService
())
rec
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodPut
,
"/api/v1/admin/api-keys/10"
,
bytes
.
NewBufferString
(
`{bad json`
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
require
.
Contains
(
t
,
rec
.
Body
.
String
(),
"Invalid request"
)
}
func
TestAdminAPIKeyHandler_UpdateGroup_KeyNotFound
(
t
*
testing
.
T
)
{
router
:=
setupAPIKeyHandler
(
newStubAdminService
())
body
:=
`{"group_id": 2}`
rec
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodPut
,
"/api/v1/admin/api-keys/999"
,
bytes
.
NewBufferString
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
rec
,
req
)
// ErrAPIKeyNotFound maps to 404
require
.
Equal
(
t
,
http
.
StatusNotFound
,
rec
.
Code
)
}
func
TestAdminAPIKeyHandler_UpdateGroup_BindGroup
(
t
*
testing
.
T
)
{
router
:=
setupAPIKeyHandler
(
newStubAdminService
())
body
:=
`{"group_id": 2}`
rec
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodPut
,
"/api/v1/admin/api-keys/10"
,
bytes
.
NewBufferString
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
var
resp
struct
{
Code
int
`json:"code"`
Data
json
.
RawMessage
`json:"data"`
}
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
0
,
resp
.
Code
)
var
data
struct
{
APIKey
struct
{
ID
int64
`json:"id"`
GroupID
*
int64
`json:"group_id"`
}
`json:"api_key"`
AutoGrantedGroupAccess
bool
`json:"auto_granted_group_access"`
}
require
.
NoError
(
t
,
json
.
Unmarshal
(
resp
.
Data
,
&
data
))
require
.
Equal
(
t
,
int64
(
10
),
data
.
APIKey
.
ID
)
require
.
NotNil
(
t
,
data
.
APIKey
.
GroupID
)
require
.
Equal
(
t
,
int64
(
2
),
*
data
.
APIKey
.
GroupID
)
}
func
TestAdminAPIKeyHandler_UpdateGroup_Unbind
(
t
*
testing
.
T
)
{
svc
:=
newStubAdminService
()
gid
:=
int64
(
2
)
svc
.
apiKeys
[
0
]
.
GroupID
=
&
gid
router
:=
setupAPIKeyHandler
(
svc
)
body
:=
`{"group_id": 0}`
rec
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodPut
,
"/api/v1/admin/api-keys/10"
,
bytes
.
NewBufferString
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
var
resp
struct
{
Data
struct
{
APIKey
struct
{
GroupID
*
int64
`json:"group_id"`
}
`json:"api_key"`
}
`json:"data"`
}
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
resp
))
require
.
Nil
(
t
,
resp
.
Data
.
APIKey
.
GroupID
)
}
func
TestAdminAPIKeyHandler_UpdateGroup_ServiceError
(
t
*
testing
.
T
)
{
svc
:=
&
failingUpdateGroupService
{
stubAdminService
:
newStubAdminService
(),
err
:
errors
.
New
(
"internal failure"
),
}
router
:=
setupAPIKeyHandler
(
svc
)
body
:=
`{"group_id": 2}`
rec
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodPut
,
"/api/v1/admin/api-keys/10"
,
bytes
.
NewBufferString
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusInternalServerError
,
rec
.
Code
)
}
// H2: empty body → group_id is nil → no-op, returns original key
func
TestAdminAPIKeyHandler_UpdateGroup_EmptyBody_NoChange
(
t
*
testing
.
T
)
{
router
:=
setupAPIKeyHandler
(
newStubAdminService
())
rec
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodPut
,
"/api/v1/admin/api-keys/10"
,
bytes
.
NewBufferString
(
`{}`
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
var
resp
struct
{
Code
int
`json:"code"`
Data
struct
{
APIKey
struct
{
ID
int64
`json:"id"`
}
`json:"api_key"`
}
`json:"data"`
}
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
resp
))
require
.
Equal
(
t
,
0
,
resp
.
Code
)
require
.
Equal
(
t
,
int64
(
10
),
resp
.
Data
.
APIKey
.
ID
)
}
// M2: service returns GROUP_NOT_ACTIVE → handler maps to 400
func
TestAdminAPIKeyHandler_UpdateGroup_GroupNotActive
(
t
*
testing
.
T
)
{
svc
:=
&
failingUpdateGroupService
{
stubAdminService
:
newStubAdminService
(),
err
:
infraerrors
.
BadRequest
(
"GROUP_NOT_ACTIVE"
,
"target group is not active"
),
}
router
:=
setupAPIKeyHandler
(
svc
)
rec
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodPut
,
"/api/v1/admin/api-keys/10"
,
bytes
.
NewBufferString
(
`{"group_id": 5}`
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
require
.
Contains
(
t
,
rec
.
Body
.
String
(),
"GROUP_NOT_ACTIVE"
)
}
// M2: service returns INVALID_GROUP_ID → handler maps to 400
func
TestAdminAPIKeyHandler_UpdateGroup_NegativeGroupID
(
t
*
testing
.
T
)
{
svc
:=
&
failingUpdateGroupService
{
stubAdminService
:
newStubAdminService
(),
err
:
infraerrors
.
BadRequest
(
"INVALID_GROUP_ID"
,
"group_id must be non-negative"
),
}
router
:=
setupAPIKeyHandler
(
svc
)
rec
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodPut
,
"/api/v1/admin/api-keys/10"
,
bytes
.
NewBufferString
(
`{"group_id": -5}`
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
require
.
Contains
(
t
,
rec
.
Body
.
String
(),
"INVALID_GROUP_ID"
)
}
// failingUpdateGroupService overrides AdminUpdateAPIKeyGroupID to return an error.
type
failingUpdateGroupService
struct
{
*
stubAdminService
err
error
}
func
(
f
*
failingUpdateGroupService
)
AdminUpdateAPIKeyGroupID
(
_
context
.
Context
,
_
int64
,
_
*
int64
)
(
*
service
.
AdminUpdateAPIKeyGroupIDResult
,
error
)
{
return
nil
,
f
.
err
}
backend/internal/handler/admin/batch_update_credentials_test.go
0 → 100644
View file @
3d79773b
//go:build unit
package
admin
import
(
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// failingAdminService 嵌入 stubAdminService,可配置 UpdateAccount 在指定 ID 时失败。
type
failingAdminService
struct
{
*
stubAdminService
failOnAccountID
int64
updateCallCount
atomic
.
Int64
}
func
(
f
*
failingAdminService
)
UpdateAccount
(
ctx
context
.
Context
,
id
int64
,
input
*
service
.
UpdateAccountInput
)
(
*
service
.
Account
,
error
)
{
f
.
updateCallCount
.
Add
(
1
)
if
id
==
f
.
failOnAccountID
{
return
nil
,
errors
.
New
(
"database error"
)
}
return
f
.
stubAdminService
.
UpdateAccount
(
ctx
,
id
,
input
)
}
func
setupAccountHandlerWithService
(
adminSvc
service
.
AdminService
)
(
*
gin
.
Engine
,
*
AccountHandler
)
{
gin
.
SetMode
(
gin
.
TestMode
)
router
:=
gin
.
New
()
handler
:=
NewAccountHandler
(
adminSvc
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
,
nil
)
router
.
POST
(
"/api/v1/admin/accounts/batch-update-credentials"
,
handler
.
BatchUpdateCredentials
)
return
router
,
handler
}
func
TestBatchUpdateCredentials_AllSuccess
(
t
*
testing
.
T
)
{
svc
:=
&
failingAdminService
{
stubAdminService
:
newStubAdminService
()}
router
,
_
:=
setupAccountHandlerWithService
(
svc
)
body
,
_
:=
json
.
Marshal
(
BatchUpdateCredentialsRequest
{
AccountIDs
:
[]
int64
{
1
,
2
,
3
},
Field
:
"account_uuid"
,
Value
:
"test-uuid"
,
})
w
:=
httptest
.
NewRecorder
()
req
,
_
:=
http
.
NewRequest
(
"POST"
,
"/api/v1/admin/accounts/batch-update-credentials"
,
bytes
.
NewReader
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
,
"全部成功时应返回 200"
)
require
.
Equal
(
t
,
int64
(
3
),
svc
.
updateCallCount
.
Load
(),
"应调用 3 次 UpdateAccount"
)
}
func
TestBatchUpdateCredentials_PartialFailure
(
t
*
testing
.
T
)
{
// 让第 2 个账号(ID=2)更新时失败
svc
:=
&
failingAdminService
{
stubAdminService
:
newStubAdminService
(),
failOnAccountID
:
2
,
}
router
,
_
:=
setupAccountHandlerWithService
(
svc
)
body
,
_
:=
json
.
Marshal
(
BatchUpdateCredentialsRequest
{
AccountIDs
:
[]
int64
{
1
,
2
,
3
},
Field
:
"org_uuid"
,
Value
:
"test-org"
,
})
w
:=
httptest
.
NewRecorder
()
req
,
_
:=
http
.
NewRequest
(
"POST"
,
"/api/v1/admin/accounts/batch-update-credentials"
,
bytes
.
NewReader
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
w
,
req
)
// 实现采用"部分成功"模式:总是返回 200 + 成功/失败明细
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
,
"批量更新返回 200 + 成功/失败明细"
)
var
resp
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
w
.
Body
.
Bytes
(),
&
resp
))
data
:=
resp
[
"data"
]
.
(
map
[
string
]
any
)
require
.
Equal
(
t
,
float64
(
2
),
data
[
"success"
],
"应有 2 个成功"
)
require
.
Equal
(
t
,
float64
(
1
),
data
[
"failed"
],
"应有 1 个失败"
)
// 所有 3 个账号都会被尝试更新(非 fail-fast)
require
.
Equal
(
t
,
int64
(
3
),
svc
.
updateCallCount
.
Load
(),
"应调用 3 次 UpdateAccount(逐个尝试,失败后继续)"
)
}
func
TestBatchUpdateCredentials_FirstAccountNotFound
(
t
*
testing
.
T
)
{
// GetAccount 在 stubAdminService 中总是成功的,需要创建一个 GetAccount 会失败的 stub
svc
:=
&
getAccountFailingService
{
stubAdminService
:
newStubAdminService
(),
failOnAccountID
:
1
,
}
router
,
_
:=
setupAccountHandlerWithService
(
svc
)
body
,
_
:=
json
.
Marshal
(
BatchUpdateCredentialsRequest
{
AccountIDs
:
[]
int64
{
1
,
2
,
3
},
Field
:
"account_uuid"
,
Value
:
"test"
,
})
w
:=
httptest
.
NewRecorder
()
req
,
_
:=
http
.
NewRequest
(
"POST"
,
"/api/v1/admin/accounts/batch-update-credentials"
,
bytes
.
NewReader
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusNotFound
,
w
.
Code
,
"第一阶段验证失败应返回 404"
)
}
// getAccountFailingService 模拟 GetAccount 在特定 ID 时返回 not found。
type
getAccountFailingService
struct
{
*
stubAdminService
failOnAccountID
int64
}
func
(
f
*
getAccountFailingService
)
GetAccount
(
ctx
context
.
Context
,
id
int64
)
(
*
service
.
Account
,
error
)
{
if
id
==
f
.
failOnAccountID
{
return
nil
,
errors
.
New
(
"not found"
)
}
return
f
.
stubAdminService
.
GetAccount
(
ctx
,
id
)
}
func
TestBatchUpdateCredentials_InterceptWarmupRequests_NonBool
(
t
*
testing
.
T
)
{
svc
:=
&
failingAdminService
{
stubAdminService
:
newStubAdminService
()}
router
,
_
:=
setupAccountHandlerWithService
(
svc
)
// intercept_warmup_requests 传入非 bool 类型(string),应返回 400
body
,
_
:=
json
.
Marshal
(
map
[
string
]
any
{
"account_ids"
:
[]
int64
{
1
},
"field"
:
"intercept_warmup_requests"
,
"value"
:
"not-a-bool"
,
})
w
:=
httptest
.
NewRecorder
()
req
,
_
:=
http
.
NewRequest
(
"POST"
,
"/api/v1/admin/accounts/batch-update-credentials"
,
bytes
.
NewReader
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
w
.
Code
,
"intercept_warmup_requests 传入非 bool 值应返回 400"
)
}
func
TestBatchUpdateCredentials_InterceptWarmupRequests_ValidBool
(
t
*
testing
.
T
)
{
svc
:=
&
failingAdminService
{
stubAdminService
:
newStubAdminService
()}
router
,
_
:=
setupAccountHandlerWithService
(
svc
)
body
,
_
:=
json
.
Marshal
(
map
[
string
]
any
{
"account_ids"
:
[]
int64
{
1
},
"field"
:
"intercept_warmup_requests"
,
"value"
:
true
,
})
w
:=
httptest
.
NewRecorder
()
req
,
_
:=
http
.
NewRequest
(
"POST"
,
"/api/v1/admin/accounts/batch-update-credentials"
,
bytes
.
NewReader
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
,
"intercept_warmup_requests 传入合法 bool 值应返回 200"
)
}
func
TestBatchUpdateCredentials_AccountUUID_NonString
(
t
*
testing
.
T
)
{
svc
:=
&
failingAdminService
{
stubAdminService
:
newStubAdminService
()}
router
,
_
:=
setupAccountHandlerWithService
(
svc
)
// account_uuid 传入非 string 类型(number),应返回 400
body
,
_
:=
json
.
Marshal
(
map
[
string
]
any
{
"account_ids"
:
[]
int64
{
1
},
"field"
:
"account_uuid"
,
"value"
:
12345
,
})
w
:=
httptest
.
NewRecorder
()
req
,
_
:=
http
.
NewRequest
(
"POST"
,
"/api/v1/admin/accounts/batch-update-credentials"
,
bytes
.
NewReader
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
w
.
Code
,
"account_uuid 传入非 string 值应返回 400"
)
}
func
TestBatchUpdateCredentials_AccountUUID_NullValue
(
t
*
testing
.
T
)
{
svc
:=
&
failingAdminService
{
stubAdminService
:
newStubAdminService
()}
router
,
_
:=
setupAccountHandlerWithService
(
svc
)
// account_uuid 传入 null(设置为空),应正常通过
body
,
_
:=
json
.
Marshal
(
map
[
string
]
any
{
"account_ids"
:
[]
int64
{
1
},
"field"
:
"account_uuid"
,
"value"
:
nil
,
})
w
:=
httptest
.
NewRecorder
()
req
,
_
:=
http
.
NewRequest
(
"POST"
,
"/api/v1/admin/accounts/batch-update-credentials"
,
bytes
.
NewReader
(
body
))
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
router
.
ServeHTTP
(
w
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
w
.
Code
,
"account_uuid 传入 null 应返回 200"
)
}
backend/internal/handler/admin/dashboard_handler.go
View file @
3d79773b
package
admin
import
(
"encoding/json"
"errors"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
...
...
@@ -186,7 +188,7 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
// GetUsageTrend handles getting usage trend data
// GET /api/v1/admin/dashboard/trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream, billing_type
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id,
request_type,
stream, billing_type
func
(
h
*
DashboardHandler
)
GetUsageTrend
(
c
*
gin
.
Context
)
{
startTime
,
endTime
:=
parseTimeRange
(
c
)
granularity
:=
c
.
DefaultQuery
(
"granularity"
,
"day"
)
...
...
@@ -194,6 +196,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
// Parse optional filter params
var
userID
,
apiKeyID
,
accountID
,
groupID
int64
var
model
string
var
requestType
*
int16
var
stream
*
bool
var
billingType
*
int8
...
...
@@ -220,9 +223,20 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
if
modelStr
:=
c
.
Query
(
"model"
);
modelStr
!=
""
{
model
=
modelStr
}
if
streamStr
:=
c
.
Query
(
"stream"
);
streamStr
!=
""
{
if
requestTypeStr
:=
strings
.
TrimSpace
(
c
.
Query
(
"request_type"
));
requestTypeStr
!=
""
{
parsed
,
err
:=
service
.
ParseUsageRequestType
(
requestTypeStr
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
err
.
Error
())
return
}
value
:=
int16
(
parsed
)
requestType
=
&
value
}
else
if
streamStr
:=
c
.
Query
(
"stream"
);
streamStr
!=
""
{
if
streamVal
,
err
:=
strconv
.
ParseBool
(
streamStr
);
err
==
nil
{
stream
=
&
streamVal
}
else
{
response
.
BadRequest
(
c
,
"Invalid stream value, use true or false"
)
return
}
}
if
billingTypeStr
:=
c
.
Query
(
"billing_type"
);
billingTypeStr
!=
""
{
...
...
@@ -235,7 +249,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
}
}
trend
,
err
:=
h
.
dashboardService
.
GetUsageTrendWithFilters
(
c
.
Request
.
Context
(),
startTime
,
endTime
,
granularity
,
userID
,
apiKeyID
,
accountID
,
groupID
,
model
,
stream
,
billingType
)
trend
,
err
:=
h
.
dashboardService
.
GetUsageTrendWithFilters
(
c
.
Request
.
Context
(),
startTime
,
endTime
,
granularity
,
userID
,
apiKeyID
,
accountID
,
groupID
,
model
,
requestType
,
stream
,
billingType
)
if
err
!=
nil
{
response
.
Error
(
c
,
500
,
"Failed to get usage trend"
)
return
...
...
@@ -251,12 +265,13 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
// GetModelStats handles getting model usage statistics
// GET /api/v1/admin/dashboard/models
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream, billing_type
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id,
request_type,
stream, billing_type
func
(
h
*
DashboardHandler
)
GetModelStats
(
c
*
gin
.
Context
)
{
startTime
,
endTime
:=
parseTimeRange
(
c
)
// Parse optional filter params
var
userID
,
apiKeyID
,
accountID
,
groupID
int64
var
requestType
*
int16
var
stream
*
bool
var
billingType
*
int8
...
...
@@ -280,9 +295,20 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
groupID
=
id
}
}
if
streamStr
:=
c
.
Query
(
"stream"
);
streamStr
!=
""
{
if
requestTypeStr
:=
strings
.
TrimSpace
(
c
.
Query
(
"request_type"
));
requestTypeStr
!=
""
{
parsed
,
err
:=
service
.
ParseUsageRequestType
(
requestTypeStr
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
err
.
Error
())
return
}
value
:=
int16
(
parsed
)
requestType
=
&
value
}
else
if
streamStr
:=
c
.
Query
(
"stream"
);
streamStr
!=
""
{
if
streamVal
,
err
:=
strconv
.
ParseBool
(
streamStr
);
err
==
nil
{
stream
=
&
streamVal
}
else
{
response
.
BadRequest
(
c
,
"Invalid stream value, use true or false"
)
return
}
}
if
billingTypeStr
:=
c
.
Query
(
"billing_type"
);
billingTypeStr
!=
""
{
...
...
@@ -295,7 +321,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
}
}
stats
,
err
:=
h
.
dashboardService
.
GetModelStatsWithFilters
(
c
.
Request
.
Context
(),
startTime
,
endTime
,
userID
,
apiKeyID
,
accountID
,
groupID
,
stream
,
billingType
)
stats
,
err
:=
h
.
dashboardService
.
GetModelStatsWithFilters
(
c
.
Request
.
Context
(),
startTime
,
endTime
,
userID
,
apiKeyID
,
accountID
,
groupID
,
requestType
,
stream
,
billingType
)
if
err
!=
nil
{
response
.
Error
(
c
,
500
,
"Failed to get model statistics"
)
return
...
...
@@ -308,6 +334,76 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
})
}
// GetGroupStats handles getting group usage statistics
// GET /api/v1/admin/dashboard/groups
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, request_type, stream, billing_type
func
(
h
*
DashboardHandler
)
GetGroupStats
(
c
*
gin
.
Context
)
{
startTime
,
endTime
:=
parseTimeRange
(
c
)
var
userID
,
apiKeyID
,
accountID
,
groupID
int64
var
requestType
*
int16
var
stream
*
bool
var
billingType
*
int8
if
userIDStr
:=
c
.
Query
(
"user_id"
);
userIDStr
!=
""
{
if
id
,
err
:=
strconv
.
ParseInt
(
userIDStr
,
10
,
64
);
err
==
nil
{
userID
=
id
}
}
if
apiKeyIDStr
:=
c
.
Query
(
"api_key_id"
);
apiKeyIDStr
!=
""
{
if
id
,
err
:=
strconv
.
ParseInt
(
apiKeyIDStr
,
10
,
64
);
err
==
nil
{
apiKeyID
=
id
}
}
if
accountIDStr
:=
c
.
Query
(
"account_id"
);
accountIDStr
!=
""
{
if
id
,
err
:=
strconv
.
ParseInt
(
accountIDStr
,
10
,
64
);
err
==
nil
{
accountID
=
id
}
}
if
groupIDStr
:=
c
.
Query
(
"group_id"
);
groupIDStr
!=
""
{
if
id
,
err
:=
strconv
.
ParseInt
(
groupIDStr
,
10
,
64
);
err
==
nil
{
groupID
=
id
}
}
if
requestTypeStr
:=
strings
.
TrimSpace
(
c
.
Query
(
"request_type"
));
requestTypeStr
!=
""
{
parsed
,
err
:=
service
.
ParseUsageRequestType
(
requestTypeStr
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
err
.
Error
())
return
}
value
:=
int16
(
parsed
)
requestType
=
&
value
}
else
if
streamStr
:=
c
.
Query
(
"stream"
);
streamStr
!=
""
{
if
streamVal
,
err
:=
strconv
.
ParseBool
(
streamStr
);
err
==
nil
{
stream
=
&
streamVal
}
else
{
response
.
BadRequest
(
c
,
"Invalid stream value, use true or false"
)
return
}
}
if
billingTypeStr
:=
c
.
Query
(
"billing_type"
);
billingTypeStr
!=
""
{
if
v
,
err
:=
strconv
.
ParseInt
(
billingTypeStr
,
10
,
8
);
err
==
nil
{
bt
:=
int8
(
v
)
billingType
=
&
bt
}
else
{
response
.
BadRequest
(
c
,
"Invalid billing_type"
)
return
}
}
stats
,
err
:=
h
.
dashboardService
.
GetGroupStatsWithFilters
(
c
.
Request
.
Context
(),
startTime
,
endTime
,
userID
,
apiKeyID
,
accountID
,
groupID
,
requestType
,
stream
,
billingType
)
if
err
!=
nil
{
response
.
Error
(
c
,
500
,
"Failed to get group statistics"
)
return
}
response
.
Success
(
c
,
gin
.
H
{
"groups"
:
stats
,
"start_date"
:
startTime
.
Format
(
"2006-01-02"
),
"end_date"
:
endTime
.
Add
(
-
24
*
time
.
Hour
)
.
Format
(
"2006-01-02"
),
})
}
// GetAPIKeyUsageTrend handles getting API key usage trend data
// GET /api/v1/admin/dashboard/api-keys-trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), limit (default 5)
...
...
@@ -365,6 +461,9 @@ type BatchUsersUsageRequest struct {
UserIDs
[]
int64
`json:"user_ids" binding:"required"`
}
var
dashboardBatchUsersUsageCache
=
newSnapshotCache
(
30
*
time
.
Second
)
var
dashboardBatchAPIKeysUsageCache
=
newSnapshotCache
(
30
*
time
.
Second
)
// GetBatchUsersUsage handles getting usage stats for multiple users
// POST /api/v1/admin/dashboard/users-usage
func
(
h
*
DashboardHandler
)
GetBatchUsersUsage
(
c
*
gin
.
Context
)
{
...
...
@@ -374,18 +473,34 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
return
}
if
len
(
req
.
UserIDs
)
==
0
{
userIDs
:=
normalizeInt64IDList
(
req
.
UserIDs
)
if
len
(
userIDs
)
==
0
{
response
.
Success
(
c
,
gin
.
H
{
"stats"
:
map
[
string
]
any
{}})
return
}
stats
,
err
:=
h
.
dashboardService
.
GetBatchUserUsageStats
(
c
.
Request
.
Context
(),
req
.
UserIDs
)
keyRaw
,
_
:=
json
.
Marshal
(
struct
{
UserIDs
[]
int64
`json:"user_ids"`
}{
UserIDs
:
userIDs
,
})
cacheKey
:=
string
(
keyRaw
)
if
cached
,
ok
:=
dashboardBatchUsersUsageCache
.
Get
(
cacheKey
);
ok
{
c
.
Header
(
"X-Snapshot-Cache"
,
"hit"
)
response
.
Success
(
c
,
cached
.
Payload
)
return
}
stats
,
err
:=
h
.
dashboardService
.
GetBatchUserUsageStats
(
c
.
Request
.
Context
(),
userIDs
,
time
.
Time
{},
time
.
Time
{})
if
err
!=
nil
{
response
.
Error
(
c
,
500
,
"Failed to get user usage stats"
)
return
}
response
.
Success
(
c
,
gin
.
H
{
"stats"
:
stats
})
payload
:=
gin
.
H
{
"stats"
:
stats
}
dashboardBatchUsersUsageCache
.
Set
(
cacheKey
,
payload
)
c
.
Header
(
"X-Snapshot-Cache"
,
"miss"
)
response
.
Success
(
c
,
payload
)
}
// BatchAPIKeysUsageRequest represents the request body for batch api key usage stats
...
...
@@ -402,16 +517,32 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
return
}
if
len
(
req
.
APIKeyIDs
)
==
0
{
apiKeyIDs
:=
normalizeInt64IDList
(
req
.
APIKeyIDs
)
if
len
(
apiKeyIDs
)
==
0
{
response
.
Success
(
c
,
gin
.
H
{
"stats"
:
map
[
string
]
any
{}})
return
}
stats
,
err
:=
h
.
dashboardService
.
GetBatchAPIKeyUsageStats
(
c
.
Request
.
Context
(),
req
.
APIKeyIDs
)
keyRaw
,
_
:=
json
.
Marshal
(
struct
{
APIKeyIDs
[]
int64
`json:"api_key_ids"`
}{
APIKeyIDs
:
apiKeyIDs
,
})
cacheKey
:=
string
(
keyRaw
)
if
cached
,
ok
:=
dashboardBatchAPIKeysUsageCache
.
Get
(
cacheKey
);
ok
{
c
.
Header
(
"X-Snapshot-Cache"
,
"hit"
)
response
.
Success
(
c
,
cached
.
Payload
)
return
}
stats
,
err
:=
h
.
dashboardService
.
GetBatchAPIKeyUsageStats
(
c
.
Request
.
Context
(),
apiKeyIDs
,
time
.
Time
{},
time
.
Time
{})
if
err
!=
nil
{
response
.
Error
(
c
,
500
,
"Failed to get API key usage stats"
)
return
}
response
.
Success
(
c
,
gin
.
H
{
"stats"
:
stats
})
payload
:=
gin
.
H
{
"stats"
:
stats
}
dashboardBatchAPIKeysUsageCache
.
Set
(
cacheKey
,
payload
)
c
.
Header
(
"X-Snapshot-Cache"
,
"miss"
)
response
.
Success
(
c
,
payload
)
}
backend/internal/handler/admin/dashboard_handler_request_type_test.go
0 → 100644
View file @
3d79773b
package
admin
import
(
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type
dashboardUsageRepoCapture
struct
{
service
.
UsageLogRepository
trendRequestType
*
int16
trendStream
*
bool
modelRequestType
*
int16
modelStream
*
bool
}
func
(
s
*
dashboardUsageRepoCapture
)
GetUsageTrendWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
granularity
string
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
model
string
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
,
)
([]
usagestats
.
TrendDataPoint
,
error
)
{
s
.
trendRequestType
=
requestType
s
.
trendStream
=
stream
return
[]
usagestats
.
TrendDataPoint
{},
nil
}
func
(
s
*
dashboardUsageRepoCapture
)
GetModelStatsWithFilters
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
,
userID
,
apiKeyID
,
accountID
,
groupID
int64
,
requestType
*
int16
,
stream
*
bool
,
billingType
*
int8
,
)
([]
usagestats
.
ModelStat
,
error
)
{
s
.
modelRequestType
=
requestType
s
.
modelStream
=
stream
return
[]
usagestats
.
ModelStat
{},
nil
}
func
newDashboardRequestTypeTestRouter
(
repo
*
dashboardUsageRepoCapture
)
*
gin
.
Engine
{
gin
.
SetMode
(
gin
.
TestMode
)
dashboardSvc
:=
service
.
NewDashboardService
(
repo
,
nil
,
nil
,
nil
)
handler
:=
NewDashboardHandler
(
dashboardSvc
,
nil
)
router
:=
gin
.
New
()
router
.
GET
(
"/admin/dashboard/trend"
,
handler
.
GetUsageTrend
)
router
.
GET
(
"/admin/dashboard/models"
,
handler
.
GetModelStats
)
return
router
}
func
TestDashboardTrendRequestTypePriority
(
t
*
testing
.
T
)
{
repo
:=
&
dashboardUsageRepoCapture
{}
router
:=
newDashboardRequestTypeTestRouter
(
repo
)
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/admin/dashboard/trend?request_type=ws_v2&stream=bad"
,
nil
)
rec
:=
httptest
.
NewRecorder
()
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
NotNil
(
t
,
repo
.
trendRequestType
)
require
.
Equal
(
t
,
int16
(
service
.
RequestTypeWSV2
),
*
repo
.
trendRequestType
)
require
.
Nil
(
t
,
repo
.
trendStream
)
}
func
TestDashboardTrendInvalidRequestType
(
t
*
testing
.
T
)
{
repo
:=
&
dashboardUsageRepoCapture
{}
router
:=
newDashboardRequestTypeTestRouter
(
repo
)
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/admin/dashboard/trend?request_type=bad"
,
nil
)
rec
:=
httptest
.
NewRecorder
()
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
}
func
TestDashboardTrendInvalidStream
(
t
*
testing
.
T
)
{
repo
:=
&
dashboardUsageRepoCapture
{}
router
:=
newDashboardRequestTypeTestRouter
(
repo
)
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/admin/dashboard/trend?stream=bad"
,
nil
)
rec
:=
httptest
.
NewRecorder
()
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
}
func
TestDashboardModelStatsRequestTypePriority
(
t
*
testing
.
T
)
{
repo
:=
&
dashboardUsageRepoCapture
{}
router
:=
newDashboardRequestTypeTestRouter
(
repo
)
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/admin/dashboard/models?request_type=sync&stream=bad"
,
nil
)
rec
:=
httptest
.
NewRecorder
()
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
NotNil
(
t
,
repo
.
modelRequestType
)
require
.
Equal
(
t
,
int16
(
service
.
RequestTypeSync
),
*
repo
.
modelRequestType
)
require
.
Nil
(
t
,
repo
.
modelStream
)
}
func
TestDashboardModelStatsInvalidRequestType
(
t
*
testing
.
T
)
{
repo
:=
&
dashboardUsageRepoCapture
{}
router
:=
newDashboardRequestTypeTestRouter
(
repo
)
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/admin/dashboard/models?request_type=bad"
,
nil
)
rec
:=
httptest
.
NewRecorder
()
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
}
func
TestDashboardModelStatsInvalidStream
(
t
*
testing
.
T
)
{
repo
:=
&
dashboardUsageRepoCapture
{}
router
:=
newDashboardRequestTypeTestRouter
(
repo
)
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/admin/dashboard/models?stream=bad"
,
nil
)
rec
:=
httptest
.
NewRecorder
()
router
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
}
backend/internal/handler/admin/dashboard_snapshot_v2_handler.go
0 → 100644
View file @
3d79773b
package
admin
import
(
"encoding/json"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
var
dashboardSnapshotV2Cache
=
newSnapshotCache
(
30
*
time
.
Second
)
type
dashboardSnapshotV2Stats
struct
{
usagestats
.
DashboardStats
Uptime
int64
`json:"uptime"`
}
type
dashboardSnapshotV2Response
struct
{
GeneratedAt
string
`json:"generated_at"`
StartDate
string
`json:"start_date"`
EndDate
string
`json:"end_date"`
Granularity
string
`json:"granularity"`
Stats
*
dashboardSnapshotV2Stats
`json:"stats,omitempty"`
Trend
[]
usagestats
.
TrendDataPoint
`json:"trend,omitempty"`
Models
[]
usagestats
.
ModelStat
`json:"models,omitempty"`
Groups
[]
usagestats
.
GroupStat
`json:"groups,omitempty"`
UsersTrend
[]
usagestats
.
UserUsageTrendPoint
`json:"users_trend,omitempty"`
}
type
dashboardSnapshotV2Filters
struct
{
UserID
int64
APIKeyID
int64
AccountID
int64
GroupID
int64
Model
string
RequestType
*
int16
Stream
*
bool
BillingType
*
int8
}
type
dashboardSnapshotV2CacheKey
struct
{
StartTime
string
`json:"start_time"`
EndTime
string
`json:"end_time"`
Granularity
string
`json:"granularity"`
UserID
int64
`json:"user_id"`
APIKeyID
int64
`json:"api_key_id"`
AccountID
int64
`json:"account_id"`
GroupID
int64
`json:"group_id"`
Model
string
`json:"model"`
RequestType
*
int16
`json:"request_type"`
Stream
*
bool
`json:"stream"`
BillingType
*
int8
`json:"billing_type"`
IncludeStats
bool
`json:"include_stats"`
IncludeTrend
bool
`json:"include_trend"`
IncludeModels
bool
`json:"include_models"`
IncludeGroups
bool
`json:"include_groups"`
IncludeUsersTrend
bool
`json:"include_users_trend"`
UsersTrendLimit
int
`json:"users_trend_limit"`
}
func
(
h
*
DashboardHandler
)
GetSnapshotV2
(
c
*
gin
.
Context
)
{
startTime
,
endTime
:=
parseTimeRange
(
c
)
granularity
:=
strings
.
TrimSpace
(
c
.
DefaultQuery
(
"granularity"
,
"day"
))
if
granularity
!=
"hour"
{
granularity
=
"day"
}
includeStats
:=
parseBoolQueryWithDefault
(
c
.
Query
(
"include_stats"
),
true
)
includeTrend
:=
parseBoolQueryWithDefault
(
c
.
Query
(
"include_trend"
),
true
)
includeModels
:=
parseBoolQueryWithDefault
(
c
.
Query
(
"include_model_stats"
),
true
)
includeGroups
:=
parseBoolQueryWithDefault
(
c
.
Query
(
"include_group_stats"
),
false
)
includeUsersTrend
:=
parseBoolQueryWithDefault
(
c
.
Query
(
"include_users_trend"
),
false
)
usersTrendLimit
:=
12
if
raw
:=
strings
.
TrimSpace
(
c
.
Query
(
"users_trend_limit"
));
raw
!=
""
{
if
parsed
,
err
:=
strconv
.
Atoi
(
raw
);
err
==
nil
&&
parsed
>
0
&&
parsed
<=
50
{
usersTrendLimit
=
parsed
}
}
filters
,
err
:=
parseDashboardSnapshotV2Filters
(
c
)
if
err
!=
nil
{
response
.
BadRequest
(
c
,
err
.
Error
())
return
}
keyRaw
,
_
:=
json
.
Marshal
(
dashboardSnapshotV2CacheKey
{
StartTime
:
startTime
.
UTC
()
.
Format
(
time
.
RFC3339
),
EndTime
:
endTime
.
UTC
()
.
Format
(
time
.
RFC3339
),
Granularity
:
granularity
,
UserID
:
filters
.
UserID
,
APIKeyID
:
filters
.
APIKeyID
,
AccountID
:
filters
.
AccountID
,
GroupID
:
filters
.
GroupID
,
Model
:
filters
.
Model
,
RequestType
:
filters
.
RequestType
,
Stream
:
filters
.
Stream
,
BillingType
:
filters
.
BillingType
,
IncludeStats
:
includeStats
,
IncludeTrend
:
includeTrend
,
IncludeModels
:
includeModels
,
IncludeGroups
:
includeGroups
,
IncludeUsersTrend
:
includeUsersTrend
,
UsersTrendLimit
:
usersTrendLimit
,
})
cacheKey
:=
string
(
keyRaw
)
if
cached
,
ok
:=
dashboardSnapshotV2Cache
.
Get
(
cacheKey
);
ok
{
if
cached
.
ETag
!=
""
{
c
.
Header
(
"ETag"
,
cached
.
ETag
)
c
.
Header
(
"Vary"
,
"If-None-Match"
)
if
ifNoneMatchMatched
(
c
.
GetHeader
(
"If-None-Match"
),
cached
.
ETag
)
{
c
.
Status
(
http
.
StatusNotModified
)
return
}
}
c
.
Header
(
"X-Snapshot-Cache"
,
"hit"
)
response
.
Success
(
c
,
cached
.
Payload
)
return
}
resp
:=
&
dashboardSnapshotV2Response
{
GeneratedAt
:
time
.
Now
()
.
UTC
()
.
Format
(
time
.
RFC3339
),
StartDate
:
startTime
.
Format
(
"2006-01-02"
),
EndDate
:
endTime
.
Add
(
-
24
*
time
.
Hour
)
.
Format
(
"2006-01-02"
),
Granularity
:
granularity
,
}
if
includeStats
{
stats
,
err
:=
h
.
dashboardService
.
GetDashboardStats
(
c
.
Request
.
Context
())
if
err
!=
nil
{
response
.
Error
(
c
,
500
,
"Failed to get dashboard statistics"
)
return
}
resp
.
Stats
=
&
dashboardSnapshotV2Stats
{
DashboardStats
:
*
stats
,
Uptime
:
int64
(
time
.
Since
(
h
.
startTime
)
.
Seconds
()),
}
}
if
includeTrend
{
trend
,
err
:=
h
.
dashboardService
.
GetUsageTrendWithFilters
(
c
.
Request
.
Context
(),
startTime
,
endTime
,
granularity
,
filters
.
UserID
,
filters
.
APIKeyID
,
filters
.
AccountID
,
filters
.
GroupID
,
filters
.
Model
,
filters
.
RequestType
,
filters
.
Stream
,
filters
.
BillingType
,
)
if
err
!=
nil
{
response
.
Error
(
c
,
500
,
"Failed to get usage trend"
)
return
}
resp
.
Trend
=
trend
}
if
includeModels
{
models
,
err
:=
h
.
dashboardService
.
GetModelStatsWithFilters
(
c
.
Request
.
Context
(),
startTime
,
endTime
,
filters
.
UserID
,
filters
.
APIKeyID
,
filters
.
AccountID
,
filters
.
GroupID
,
filters
.
RequestType
,
filters
.
Stream
,
filters
.
BillingType
,
)
if
err
!=
nil
{
response
.
Error
(
c
,
500
,
"Failed to get model statistics"
)
return
}
resp
.
Models
=
models
}
if
includeGroups
{
groups
,
err
:=
h
.
dashboardService
.
GetGroupStatsWithFilters
(
c
.
Request
.
Context
(),
startTime
,
endTime
,
filters
.
UserID
,
filters
.
APIKeyID
,
filters
.
AccountID
,
filters
.
GroupID
,
filters
.
RequestType
,
filters
.
Stream
,
filters
.
BillingType
,
)
if
err
!=
nil
{
response
.
Error
(
c
,
500
,
"Failed to get group statistics"
)
return
}
resp
.
Groups
=
groups
}
if
includeUsersTrend
{
usersTrend
,
err
:=
h
.
dashboardService
.
GetUserUsageTrend
(
c
.
Request
.
Context
(),
startTime
,
endTime
,
granularity
,
usersTrendLimit
,
)
if
err
!=
nil
{
response
.
Error
(
c
,
500
,
"Failed to get user usage trend"
)
return
}
resp
.
UsersTrend
=
usersTrend
}
cached
:=
dashboardSnapshotV2Cache
.
Set
(
cacheKey
,
resp
)
if
cached
.
ETag
!=
""
{
c
.
Header
(
"ETag"
,
cached
.
ETag
)
c
.
Header
(
"Vary"
,
"If-None-Match"
)
}
c
.
Header
(
"X-Snapshot-Cache"
,
"miss"
)
response
.
Success
(
c
,
resp
)
}
func
parseDashboardSnapshotV2Filters
(
c
*
gin
.
Context
)
(
*
dashboardSnapshotV2Filters
,
error
)
{
filters
:=
&
dashboardSnapshotV2Filters
{
Model
:
strings
.
TrimSpace
(
c
.
Query
(
"model"
)),
}
if
userIDStr
:=
strings
.
TrimSpace
(
c
.
Query
(
"user_id"
));
userIDStr
!=
""
{
id
,
err
:=
strconv
.
ParseInt
(
userIDStr
,
10
,
64
)
if
err
!=
nil
{
return
nil
,
err
}
filters
.
UserID
=
id
}
if
apiKeyIDStr
:=
strings
.
TrimSpace
(
c
.
Query
(
"api_key_id"
));
apiKeyIDStr
!=
""
{
id
,
err
:=
strconv
.
ParseInt
(
apiKeyIDStr
,
10
,
64
)
if
err
!=
nil
{
return
nil
,
err
}
filters
.
APIKeyID
=
id
}
if
accountIDStr
:=
strings
.
TrimSpace
(
c
.
Query
(
"account_id"
));
accountIDStr
!=
""
{
id
,
err
:=
strconv
.
ParseInt
(
accountIDStr
,
10
,
64
)
if
err
!=
nil
{
return
nil
,
err
}
filters
.
AccountID
=
id
}
if
groupIDStr
:=
strings
.
TrimSpace
(
c
.
Query
(
"group_id"
));
groupIDStr
!=
""
{
id
,
err
:=
strconv
.
ParseInt
(
groupIDStr
,
10
,
64
)
if
err
!=
nil
{
return
nil
,
err
}
filters
.
GroupID
=
id
}
if
requestTypeStr
:=
strings
.
TrimSpace
(
c
.
Query
(
"request_type"
));
requestTypeStr
!=
""
{
parsed
,
err
:=
service
.
ParseUsageRequestType
(
requestTypeStr
)
if
err
!=
nil
{
return
nil
,
err
}
value
:=
int16
(
parsed
)
filters
.
RequestType
=
&
value
}
else
if
streamStr
:=
strings
.
TrimSpace
(
c
.
Query
(
"stream"
));
streamStr
!=
""
{
streamVal
,
err
:=
strconv
.
ParseBool
(
streamStr
)
if
err
!=
nil
{
return
nil
,
err
}
filters
.
Stream
=
&
streamVal
}
if
billingTypeStr
:=
strings
.
TrimSpace
(
c
.
Query
(
"billing_type"
));
billingTypeStr
!=
""
{
v
,
err
:=
strconv
.
ParseInt
(
billingTypeStr
,
10
,
8
)
if
err
!=
nil
{
return
nil
,
err
}
bt
:=
int8
(
v
)
filters
.
BillingType
=
&
bt
}
return
filters
,
nil
}
backend/internal/handler/admin/data_management_handler.go
0 → 100644
View file @
3d79773b
package
admin
import
(
"context"
"strconv"
"strings"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type
DataManagementHandler
struct
{
dataManagementService
dataManagementService
}
func
NewDataManagementHandler
(
dataManagementService
*
service
.
DataManagementService
)
*
DataManagementHandler
{
return
&
DataManagementHandler
{
dataManagementService
:
dataManagementService
}
}
type
dataManagementService
interface
{
GetConfig
(
ctx
context
.
Context
)
(
service
.
DataManagementConfig
,
error
)
UpdateConfig
(
ctx
context
.
Context
,
cfg
service
.
DataManagementConfig
)
(
service
.
DataManagementConfig
,
error
)
ValidateS3
(
ctx
context
.
Context
,
cfg
service
.
DataManagementS3Config
)
(
service
.
DataManagementTestS3Result
,
error
)
CreateBackupJob
(
ctx
context
.
Context
,
input
service
.
DataManagementCreateBackupJobInput
)
(
service
.
DataManagementBackupJob
,
error
)
ListSourceProfiles
(
ctx
context
.
Context
,
sourceType
string
)
([]
service
.
DataManagementSourceProfile
,
error
)
CreateSourceProfile
(
ctx
context
.
Context
,
input
service
.
DataManagementCreateSourceProfileInput
)
(
service
.
DataManagementSourceProfile
,
error
)
UpdateSourceProfile
(
ctx
context
.
Context
,
input
service
.
DataManagementUpdateSourceProfileInput
)
(
service
.
DataManagementSourceProfile
,
error
)
DeleteSourceProfile
(
ctx
context
.
Context
,
sourceType
,
profileID
string
)
error
SetActiveSourceProfile
(
ctx
context
.
Context
,
sourceType
,
profileID
string
)
(
service
.
DataManagementSourceProfile
,
error
)
ListS3Profiles
(
ctx
context
.
Context
)
([]
service
.
DataManagementS3Profile
,
error
)
CreateS3Profile
(
ctx
context
.
Context
,
input
service
.
DataManagementCreateS3ProfileInput
)
(
service
.
DataManagementS3Profile
,
error
)
UpdateS3Profile
(
ctx
context
.
Context
,
input
service
.
DataManagementUpdateS3ProfileInput
)
(
service
.
DataManagementS3Profile
,
error
)
DeleteS3Profile
(
ctx
context
.
Context
,
profileID
string
)
error
SetActiveS3Profile
(
ctx
context
.
Context
,
profileID
string
)
(
service
.
DataManagementS3Profile
,
error
)
ListBackupJobs
(
ctx
context
.
Context
,
input
service
.
DataManagementListBackupJobsInput
)
(
service
.
DataManagementListBackupJobsResult
,
error
)
GetBackupJob
(
ctx
context
.
Context
,
jobID
string
)
(
service
.
DataManagementBackupJob
,
error
)
EnsureAgentEnabled
(
ctx
context
.
Context
)
error
GetAgentHealth
(
ctx
context
.
Context
)
service
.
DataManagementAgentHealth
}
type
TestS3ConnectionRequest
struct
{
Endpoint
string
`json:"endpoint"`
Region
string
`json:"region" binding:"required"`
Bucket
string
`json:"bucket" binding:"required"`
AccessKeyID
string
`json:"access_key_id"`
SecretAccessKey
string
`json:"secret_access_key"`
Prefix
string
`json:"prefix"`
ForcePathStyle
bool
`json:"force_path_style"`
UseSSL
bool
`json:"use_ssl"`
}
type
CreateBackupJobRequest
struct
{
BackupType
string
`json:"backup_type" binding:"required,oneof=postgres redis full"`
UploadToS3
bool
`json:"upload_to_s3"`
S3ProfileID
string
`json:"s3_profile_id"`
PostgresID
string
`json:"postgres_profile_id"`
RedisID
string
`json:"redis_profile_id"`
IdempotencyKey
string
`json:"idempotency_key"`
}
type
CreateSourceProfileRequest
struct
{
ProfileID
string
`json:"profile_id" binding:"required"`
Name
string
`json:"name" binding:"required"`
Config
service
.
DataManagementSourceConfig
`json:"config" binding:"required"`
SetActive
bool
`json:"set_active"`
}
type
UpdateSourceProfileRequest
struct
{
Name
string
`json:"name" binding:"required"`
Config
service
.
DataManagementSourceConfig
`json:"config" binding:"required"`
}
type
CreateS3ProfileRequest
struct
{
ProfileID
string
`json:"profile_id" binding:"required"`
Name
string
`json:"name" binding:"required"`
Enabled
bool
`json:"enabled"`
Endpoint
string
`json:"endpoint"`
Region
string
`json:"region"`
Bucket
string
`json:"bucket"`
AccessKeyID
string
`json:"access_key_id"`
SecretAccessKey
string
`json:"secret_access_key"`
Prefix
string
`json:"prefix"`
ForcePathStyle
bool
`json:"force_path_style"`
UseSSL
bool
`json:"use_ssl"`
SetActive
bool
`json:"set_active"`
}
type
UpdateS3ProfileRequest
struct
{
Name
string
`json:"name" binding:"required"`
Enabled
bool
`json:"enabled"`
Endpoint
string
`json:"endpoint"`
Region
string
`json:"region"`
Bucket
string
`json:"bucket"`
AccessKeyID
string
`json:"access_key_id"`
SecretAccessKey
string
`json:"secret_access_key"`
Prefix
string
`json:"prefix"`
ForcePathStyle
bool
`json:"force_path_style"`
UseSSL
bool
`json:"use_ssl"`
}
func
(
h
*
DataManagementHandler
)
GetAgentHealth
(
c
*
gin
.
Context
)
{
health
:=
h
.
getAgentHealth
(
c
)
payload
:=
gin
.
H
{
"enabled"
:
health
.
Enabled
,
"reason"
:
health
.
Reason
,
"socket_path"
:
health
.
SocketPath
,
}
if
health
.
Agent
!=
nil
{
payload
[
"agent"
]
=
gin
.
H
{
"status"
:
health
.
Agent
.
Status
,
"version"
:
health
.
Agent
.
Version
,
"uptime_seconds"
:
health
.
Agent
.
UptimeSeconds
,
}
}
response
.
Success
(
c
,
payload
)
}
func
(
h
*
DataManagementHandler
)
GetConfig
(
c
*
gin
.
Context
)
{
if
!
h
.
requireAgentEnabled
(
c
)
{
return
}
cfg
,
err
:=
h
.
dataManagementService
.
GetConfig
(
c
.
Request
.
Context
())
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
cfg
)
}
func
(
h
*
DataManagementHandler
)
UpdateConfig
(
c
*
gin
.
Context
)
{
var
req
service
.
DataManagementConfig
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
if
!
h
.
requireAgentEnabled
(
c
)
{
return
}
cfg
,
err
:=
h
.
dataManagementService
.
UpdateConfig
(
c
.
Request
.
Context
(),
req
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
cfg
)
}
func
(
h
*
DataManagementHandler
)
TestS3
(
c
*
gin
.
Context
)
{
var
req
TestS3ConnectionRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
if
!
h
.
requireAgentEnabled
(
c
)
{
return
}
result
,
err
:=
h
.
dataManagementService
.
ValidateS3
(
c
.
Request
.
Context
(),
service
.
DataManagementS3Config
{
Enabled
:
true
,
Endpoint
:
req
.
Endpoint
,
Region
:
req
.
Region
,
Bucket
:
req
.
Bucket
,
AccessKeyID
:
req
.
AccessKeyID
,
SecretAccessKey
:
req
.
SecretAccessKey
,
Prefix
:
req
.
Prefix
,
ForcePathStyle
:
req
.
ForcePathStyle
,
UseSSL
:
req
.
UseSSL
,
})
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
gin
.
H
{
"ok"
:
result
.
OK
,
"message"
:
result
.
Message
})
}
func
(
h
*
DataManagementHandler
)
CreateBackupJob
(
c
*
gin
.
Context
)
{
var
req
CreateBackupJobRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
req
.
IdempotencyKey
=
normalizeBackupIdempotencyKey
(
c
.
GetHeader
(
"X-Idempotency-Key"
),
req
.
IdempotencyKey
)
if
!
h
.
requireAgentEnabled
(
c
)
{
return
}
triggeredBy
:=
"admin:unknown"
if
subject
,
ok
:=
middleware2
.
GetAuthSubjectFromContext
(
c
);
ok
{
triggeredBy
=
"admin:"
+
strconv
.
FormatInt
(
subject
.
UserID
,
10
)
}
job
,
err
:=
h
.
dataManagementService
.
CreateBackupJob
(
c
.
Request
.
Context
(),
service
.
DataManagementCreateBackupJobInput
{
BackupType
:
req
.
BackupType
,
UploadToS3
:
req
.
UploadToS3
,
S3ProfileID
:
req
.
S3ProfileID
,
PostgresID
:
req
.
PostgresID
,
RedisID
:
req
.
RedisID
,
TriggeredBy
:
triggeredBy
,
IdempotencyKey
:
req
.
IdempotencyKey
,
})
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
gin
.
H
{
"job_id"
:
job
.
JobID
,
"status"
:
job
.
Status
})
}
func
(
h
*
DataManagementHandler
)
ListSourceProfiles
(
c
*
gin
.
Context
)
{
sourceType
:=
strings
.
TrimSpace
(
c
.
Param
(
"source_type"
))
if
sourceType
==
""
{
response
.
BadRequest
(
c
,
"Invalid source_type"
)
return
}
if
sourceType
!=
"postgres"
&&
sourceType
!=
"redis"
{
response
.
BadRequest
(
c
,
"source_type must be postgres or redis"
)
return
}
if
!
h
.
requireAgentEnabled
(
c
)
{
return
}
items
,
err
:=
h
.
dataManagementService
.
ListSourceProfiles
(
c
.
Request
.
Context
(),
sourceType
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
gin
.
H
{
"items"
:
items
})
}
func
(
h
*
DataManagementHandler
)
CreateSourceProfile
(
c
*
gin
.
Context
)
{
sourceType
:=
strings
.
TrimSpace
(
c
.
Param
(
"source_type"
))
if
sourceType
!=
"postgres"
&&
sourceType
!=
"redis"
{
response
.
BadRequest
(
c
,
"source_type must be postgres or redis"
)
return
}
var
req
CreateSourceProfileRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
if
!
h
.
requireAgentEnabled
(
c
)
{
return
}
profile
,
err
:=
h
.
dataManagementService
.
CreateSourceProfile
(
c
.
Request
.
Context
(),
service
.
DataManagementCreateSourceProfileInput
{
SourceType
:
sourceType
,
ProfileID
:
req
.
ProfileID
,
Name
:
req
.
Name
,
Config
:
req
.
Config
,
SetActive
:
req
.
SetActive
,
})
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
profile
)
}
func
(
h
*
DataManagementHandler
)
UpdateSourceProfile
(
c
*
gin
.
Context
)
{
sourceType
:=
strings
.
TrimSpace
(
c
.
Param
(
"source_type"
))
if
sourceType
!=
"postgres"
&&
sourceType
!=
"redis"
{
response
.
BadRequest
(
c
,
"source_type must be postgres or redis"
)
return
}
profileID
:=
strings
.
TrimSpace
(
c
.
Param
(
"profile_id"
))
if
profileID
==
""
{
response
.
BadRequest
(
c
,
"Invalid profile_id"
)
return
}
var
req
UpdateSourceProfileRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
if
!
h
.
requireAgentEnabled
(
c
)
{
return
}
profile
,
err
:=
h
.
dataManagementService
.
UpdateSourceProfile
(
c
.
Request
.
Context
(),
service
.
DataManagementUpdateSourceProfileInput
{
SourceType
:
sourceType
,
ProfileID
:
profileID
,
Name
:
req
.
Name
,
Config
:
req
.
Config
,
})
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
profile
)
}
func
(
h
*
DataManagementHandler
)
DeleteSourceProfile
(
c
*
gin
.
Context
)
{
sourceType
:=
strings
.
TrimSpace
(
c
.
Param
(
"source_type"
))
if
sourceType
!=
"postgres"
&&
sourceType
!=
"redis"
{
response
.
BadRequest
(
c
,
"source_type must be postgres or redis"
)
return
}
profileID
:=
strings
.
TrimSpace
(
c
.
Param
(
"profile_id"
))
if
profileID
==
""
{
response
.
BadRequest
(
c
,
"Invalid profile_id"
)
return
}
if
!
h
.
requireAgentEnabled
(
c
)
{
return
}
if
err
:=
h
.
dataManagementService
.
DeleteSourceProfile
(
c
.
Request
.
Context
(),
sourceType
,
profileID
);
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
gin
.
H
{
"deleted"
:
true
})
}
func
(
h
*
DataManagementHandler
)
SetActiveSourceProfile
(
c
*
gin
.
Context
)
{
sourceType
:=
strings
.
TrimSpace
(
c
.
Param
(
"source_type"
))
if
sourceType
!=
"postgres"
&&
sourceType
!=
"redis"
{
response
.
BadRequest
(
c
,
"source_type must be postgres or redis"
)
return
}
profileID
:=
strings
.
TrimSpace
(
c
.
Param
(
"profile_id"
))
if
profileID
==
""
{
response
.
BadRequest
(
c
,
"Invalid profile_id"
)
return
}
if
!
h
.
requireAgentEnabled
(
c
)
{
return
}
profile
,
err
:=
h
.
dataManagementService
.
SetActiveSourceProfile
(
c
.
Request
.
Context
(),
sourceType
,
profileID
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
profile
)
}
func
(
h
*
DataManagementHandler
)
ListS3Profiles
(
c
*
gin
.
Context
)
{
if
!
h
.
requireAgentEnabled
(
c
)
{
return
}
items
,
err
:=
h
.
dataManagementService
.
ListS3Profiles
(
c
.
Request
.
Context
())
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
gin
.
H
{
"items"
:
items
})
}
func
(
h
*
DataManagementHandler
)
CreateS3Profile
(
c
*
gin
.
Context
)
{
var
req
CreateS3ProfileRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
if
!
h
.
requireAgentEnabled
(
c
)
{
return
}
profile
,
err
:=
h
.
dataManagementService
.
CreateS3Profile
(
c
.
Request
.
Context
(),
service
.
DataManagementCreateS3ProfileInput
{
ProfileID
:
req
.
ProfileID
,
Name
:
req
.
Name
,
SetActive
:
req
.
SetActive
,
S3
:
service
.
DataManagementS3Config
{
Enabled
:
req
.
Enabled
,
Endpoint
:
req
.
Endpoint
,
Region
:
req
.
Region
,
Bucket
:
req
.
Bucket
,
AccessKeyID
:
req
.
AccessKeyID
,
SecretAccessKey
:
req
.
SecretAccessKey
,
Prefix
:
req
.
Prefix
,
ForcePathStyle
:
req
.
ForcePathStyle
,
UseSSL
:
req
.
UseSSL
,
},
})
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
profile
)
}
func
(
h
*
DataManagementHandler
)
UpdateS3Profile
(
c
*
gin
.
Context
)
{
var
req
UpdateS3ProfileRequest
if
err
:=
c
.
ShouldBindJSON
(
&
req
);
err
!=
nil
{
response
.
BadRequest
(
c
,
"Invalid request: "
+
err
.
Error
())
return
}
profileID
:=
strings
.
TrimSpace
(
c
.
Param
(
"profile_id"
))
if
profileID
==
""
{
response
.
BadRequest
(
c
,
"Invalid profile_id"
)
return
}
if
!
h
.
requireAgentEnabled
(
c
)
{
return
}
profile
,
err
:=
h
.
dataManagementService
.
UpdateS3Profile
(
c
.
Request
.
Context
(),
service
.
DataManagementUpdateS3ProfileInput
{
ProfileID
:
profileID
,
Name
:
req
.
Name
,
S3
:
service
.
DataManagementS3Config
{
Enabled
:
req
.
Enabled
,
Endpoint
:
req
.
Endpoint
,
Region
:
req
.
Region
,
Bucket
:
req
.
Bucket
,
AccessKeyID
:
req
.
AccessKeyID
,
SecretAccessKey
:
req
.
SecretAccessKey
,
Prefix
:
req
.
Prefix
,
ForcePathStyle
:
req
.
ForcePathStyle
,
UseSSL
:
req
.
UseSSL
,
},
})
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
profile
)
}
func
(
h
*
DataManagementHandler
)
DeleteS3Profile
(
c
*
gin
.
Context
)
{
profileID
:=
strings
.
TrimSpace
(
c
.
Param
(
"profile_id"
))
if
profileID
==
""
{
response
.
BadRequest
(
c
,
"Invalid profile_id"
)
return
}
if
!
h
.
requireAgentEnabled
(
c
)
{
return
}
if
err
:=
h
.
dataManagementService
.
DeleteS3Profile
(
c
.
Request
.
Context
(),
profileID
);
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
gin
.
H
{
"deleted"
:
true
})
}
func
(
h
*
DataManagementHandler
)
SetActiveS3Profile
(
c
*
gin
.
Context
)
{
profileID
:=
strings
.
TrimSpace
(
c
.
Param
(
"profile_id"
))
if
profileID
==
""
{
response
.
BadRequest
(
c
,
"Invalid profile_id"
)
return
}
if
!
h
.
requireAgentEnabled
(
c
)
{
return
}
profile
,
err
:=
h
.
dataManagementService
.
SetActiveS3Profile
(
c
.
Request
.
Context
(),
profileID
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
profile
)
}
func
(
h
*
DataManagementHandler
)
ListBackupJobs
(
c
*
gin
.
Context
)
{
if
!
h
.
requireAgentEnabled
(
c
)
{
return
}
pageSize
:=
int32
(
20
)
if
raw
:=
strings
.
TrimSpace
(
c
.
Query
(
"page_size"
));
raw
!=
""
{
v
,
err
:=
strconv
.
Atoi
(
raw
)
if
err
!=
nil
||
v
<=
0
{
response
.
BadRequest
(
c
,
"Invalid page_size"
)
return
}
pageSize
=
int32
(
v
)
}
result
,
err
:=
h
.
dataManagementService
.
ListBackupJobs
(
c
.
Request
.
Context
(),
service
.
DataManagementListBackupJobsInput
{
PageSize
:
pageSize
,
PageToken
:
c
.
Query
(
"page_token"
),
Status
:
c
.
Query
(
"status"
),
BackupType
:
c
.
Query
(
"backup_type"
),
})
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
result
)
}
func
(
h
*
DataManagementHandler
)
GetBackupJob
(
c
*
gin
.
Context
)
{
jobID
:=
strings
.
TrimSpace
(
c
.
Param
(
"job_id"
))
if
jobID
==
""
{
response
.
BadRequest
(
c
,
"Invalid backup job ID"
)
return
}
if
!
h
.
requireAgentEnabled
(
c
)
{
return
}
job
,
err
:=
h
.
dataManagementService
.
GetBackupJob
(
c
.
Request
.
Context
(),
jobID
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
response
.
Success
(
c
,
job
)
}
func
(
h
*
DataManagementHandler
)
requireAgentEnabled
(
c
*
gin
.
Context
)
bool
{
if
h
.
dataManagementService
==
nil
{
err
:=
infraerrors
.
ServiceUnavailable
(
service
.
DataManagementAgentUnavailableReason
,
"data management agent service is not configured"
,
)
.
WithMetadata
(
map
[
string
]
string
{
"socket_path"
:
service
.
DefaultDataManagementAgentSocketPath
})
response
.
ErrorFrom
(
c
,
err
)
return
false
}
if
err
:=
h
.
dataManagementService
.
EnsureAgentEnabled
(
c
.
Request
.
Context
());
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
false
}
return
true
}
func
(
h
*
DataManagementHandler
)
getAgentHealth
(
c
*
gin
.
Context
)
service
.
DataManagementAgentHealth
{
if
h
.
dataManagementService
==
nil
{
return
service
.
DataManagementAgentHealth
{
Enabled
:
false
,
Reason
:
service
.
DataManagementAgentUnavailableReason
,
SocketPath
:
service
.
DefaultDataManagementAgentSocketPath
,
}
}
return
h
.
dataManagementService
.
GetAgentHealth
(
c
.
Request
.
Context
())
}
func
normalizeBackupIdempotencyKey
(
headerValue
,
bodyValue
string
)
string
{
headerKey
:=
strings
.
TrimSpace
(
headerValue
)
if
headerKey
!=
""
{
return
headerKey
}
return
strings
.
TrimSpace
(
bodyValue
)
}
backend/internal/handler/admin/data_management_handler_test.go
0 → 100644
View file @
3d79773b
package
admin
import
(
"encoding/json"
"net/http"
"net/http/httptest"
"path/filepath"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type
apiEnvelope
struct
{
Code
int
`json:"code"`
Message
string
`json:"message"`
Reason
string
`json:"reason"`
Data
json
.
RawMessage
`json:"data"`
}
func
TestDataManagementHandler_AgentHealthAlways200
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
svc
:=
service
.
NewDataManagementServiceWithOptions
(
filepath
.
Join
(
t
.
TempDir
(),
"missing.sock"
),
50
*
time
.
Millisecond
)
h
:=
NewDataManagementHandler
(
svc
)
r
:=
gin
.
New
()
r
.
GET
(
"/api/v1/admin/data-management/agent/health"
,
h
.
GetAgentHealth
)
rec
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/api/v1/admin/data-management/agent/health"
,
nil
)
r
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
var
envelope
apiEnvelope
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
envelope
))
require
.
Equal
(
t
,
0
,
envelope
.
Code
)
var
data
struct
{
Enabled
bool
`json:"enabled"`
Reason
string
`json:"reason"`
SocketPath
string
`json:"socket_path"`
}
require
.
NoError
(
t
,
json
.
Unmarshal
(
envelope
.
Data
,
&
data
))
require
.
False
(
t
,
data
.
Enabled
)
require
.
Equal
(
t
,
service
.
DataManagementDeprecatedReason
,
data
.
Reason
)
require
.
Equal
(
t
,
svc
.
SocketPath
(),
data
.
SocketPath
)
}
func
TestDataManagementHandler_NonHealthRouteReturns503WhenDisabled
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
svc
:=
service
.
NewDataManagementServiceWithOptions
(
filepath
.
Join
(
t
.
TempDir
(),
"missing.sock"
),
50
*
time
.
Millisecond
)
h
:=
NewDataManagementHandler
(
svc
)
r
:=
gin
.
New
()
r
.
GET
(
"/api/v1/admin/data-management/config"
,
h
.
GetConfig
)
rec
:=
httptest
.
NewRecorder
()
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/api/v1/admin/data-management/config"
,
nil
)
r
.
ServeHTTP
(
rec
,
req
)
require
.
Equal
(
t
,
http
.
StatusServiceUnavailable
,
rec
.
Code
)
var
envelope
apiEnvelope
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
envelope
))
require
.
Equal
(
t
,
http
.
StatusServiceUnavailable
,
envelope
.
Code
)
require
.
Equal
(
t
,
service
.
DataManagementDeprecatedReason
,
envelope
.
Reason
)
}
func
TestNormalizeBackupIdempotencyKey
(
t
*
testing
.
T
)
{
require
.
Equal
(
t
,
"from-header"
,
normalizeBackupIdempotencyKey
(
"from-header"
,
"from-body"
))
require
.
Equal
(
t
,
"from-body"
,
normalizeBackupIdempotencyKey
(
" "
,
" from-body "
))
require
.
Equal
(
t
,
""
,
normalizeBackupIdempotencyKey
(
""
,
""
))
}
backend/internal/handler/admin/gemini_oauth_handler.go
View file @
3d79773b
...
...
@@ -61,7 +61,11 @@ func (h *GeminiOAuthHandler) GenerateAuthURL(c *gin.Context) {
if
err
!=
nil
{
msg
:=
err
.
Error
()
// Treat missing/invalid OAuth client configuration as a user/config error.
if
strings
.
Contains
(
msg
,
"OAuth client not configured"
)
||
strings
.
Contains
(
msg
,
"requires your own OAuth Client"
)
{
if
strings
.
Contains
(
msg
,
"OAuth client not configured"
)
||
strings
.
Contains
(
msg
,
"requires your own OAuth Client"
)
||
strings
.
Contains
(
msg
,
"requires a custom OAuth Client"
)
||
strings
.
Contains
(
msg
,
"GEMINI_CLI_OAUTH_CLIENT_SECRET_MISSING"
)
||
strings
.
Contains
(
msg
,
"built-in Gemini CLI OAuth client_secret is not configured"
)
{
response
.
BadRequest
(
c
,
"Failed to generate auth URL: "
+
msg
)
return
}
...
...
backend/internal/handler/admin/group_handler.go
View file @
3d79773b
...
...
@@ -27,7 +27,7 @@ func NewGroupHandler(adminService service.AdminService) *GroupHandler {
type
CreateGroupRequest
struct
{
Name
string
`json:"name" binding:"required"`
Description
string
`json:"description"`
Platform
string
`json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
Platform
string
`json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity
sora
"`
RateMultiplier
float64
`json:"rate_multiplier"`
IsExclusive
bool
`json:"is_exclusive"`
SubscriptionType
string
`json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
...
...
@@ -38,6 +38,10 @@ type CreateGroupRequest struct {
ImagePrice1K
*
float64
`json:"image_price_1k"`
ImagePrice2K
*
float64
`json:"image_price_2k"`
ImagePrice4K
*
float64
`json:"image_price_4k"`
SoraImagePrice360
*
float64
`json:"sora_image_price_360"`
SoraImagePrice540
*
float64
`json:"sora_image_price_540"`
SoraVideoPricePerRequest
*
float64
`json:"sora_video_price_per_request"`
SoraVideoPricePerRequestHD
*
float64
`json:"sora_video_price_per_request_hd"`
ClaudeCodeOnly
bool
`json:"claude_code_only"`
FallbackGroupID
*
int64
`json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest
*
int64
`json:"fallback_group_id_on_invalid_request"`
...
...
@@ -47,6 +51,8 @@ type CreateGroupRequest struct {
MCPXMLInject
*
bool
`json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes
[]
string
`json:"supported_model_scopes"`
// Sora 存储配额
SoraStorageQuotaBytes
int64
`json:"sora_storage_quota_bytes"`
// 从指定分组复制账号(创建后自动绑定)
CopyAccountsFromGroupIDs
[]
int64
`json:"copy_accounts_from_group_ids"`
}
...
...
@@ -55,7 +61,7 @@ type CreateGroupRequest struct {
type
UpdateGroupRequest
struct
{
Name
string
`json:"name"`
Description
string
`json:"description"`
Platform
string
`json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
Platform
string
`json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity
sora
"`
RateMultiplier
*
float64
`json:"rate_multiplier"`
IsExclusive
*
bool
`json:"is_exclusive"`
Status
string
`json:"status" binding:"omitempty,oneof=active inactive"`
...
...
@@ -67,6 +73,10 @@ type UpdateGroupRequest struct {
ImagePrice1K
*
float64
`json:"image_price_1k"`
ImagePrice2K
*
float64
`json:"image_price_2k"`
ImagePrice4K
*
float64
`json:"image_price_4k"`
SoraImagePrice360
*
float64
`json:"sora_image_price_360"`
SoraImagePrice540
*
float64
`json:"sora_image_price_540"`
SoraVideoPricePerRequest
*
float64
`json:"sora_video_price_per_request"`
SoraVideoPricePerRequestHD
*
float64
`json:"sora_video_price_per_request_hd"`
ClaudeCodeOnly
*
bool
`json:"claude_code_only"`
FallbackGroupID
*
int64
`json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest
*
int64
`json:"fallback_group_id_on_invalid_request"`
...
...
@@ -76,6 +86,8 @@ type UpdateGroupRequest struct {
MCPXMLInject
*
bool
`json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes
*
[]
string
`json:"supported_model_scopes"`
// Sora 存储配额
SoraStorageQuotaBytes
*
int64
`json:"sora_storage_quota_bytes"`
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs
[]
int64
`json:"copy_accounts_from_group_ids"`
}
...
...
@@ -179,6 +191,10 @@ func (h *GroupHandler) Create(c *gin.Context) {
ImagePrice1K
:
req
.
ImagePrice1K
,
ImagePrice2K
:
req
.
ImagePrice2K
,
ImagePrice4K
:
req
.
ImagePrice4K
,
SoraImagePrice360
:
req
.
SoraImagePrice360
,
SoraImagePrice540
:
req
.
SoraImagePrice540
,
SoraVideoPricePerRequest
:
req
.
SoraVideoPricePerRequest
,
SoraVideoPricePerRequestHD
:
req
.
SoraVideoPricePerRequestHD
,
ClaudeCodeOnly
:
req
.
ClaudeCodeOnly
,
FallbackGroupID
:
req
.
FallbackGroupID
,
FallbackGroupIDOnInvalidRequest
:
req
.
FallbackGroupIDOnInvalidRequest
,
...
...
@@ -186,6 +202,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
ModelRoutingEnabled
:
req
.
ModelRoutingEnabled
,
MCPXMLInject
:
req
.
MCPXMLInject
,
SupportedModelScopes
:
req
.
SupportedModelScopes
,
SoraStorageQuotaBytes
:
req
.
SoraStorageQuotaBytes
,
CopyAccountsFromGroupIDs
:
req
.
CopyAccountsFromGroupIDs
,
})
if
err
!=
nil
{
...
...
@@ -225,6 +242,10 @@ func (h *GroupHandler) Update(c *gin.Context) {
ImagePrice1K
:
req
.
ImagePrice1K
,
ImagePrice2K
:
req
.
ImagePrice2K
,
ImagePrice4K
:
req
.
ImagePrice4K
,
SoraImagePrice360
:
req
.
SoraImagePrice360
,
SoraImagePrice540
:
req
.
SoraImagePrice540
,
SoraVideoPricePerRequest
:
req
.
SoraVideoPricePerRequest
,
SoraVideoPricePerRequestHD
:
req
.
SoraVideoPricePerRequestHD
,
ClaudeCodeOnly
:
req
.
ClaudeCodeOnly
,
FallbackGroupID
:
req
.
FallbackGroupID
,
FallbackGroupIDOnInvalidRequest
:
req
.
FallbackGroupIDOnInvalidRequest
,
...
...
@@ -232,6 +253,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
ModelRoutingEnabled
:
req
.
ModelRoutingEnabled
,
MCPXMLInject
:
req
.
MCPXMLInject
,
SupportedModelScopes
:
req
.
SupportedModelScopes
,
SoraStorageQuotaBytes
:
req
.
SoraStorageQuotaBytes
,
CopyAccountsFromGroupIDs
:
req
.
CopyAccountsFromGroupIDs
,
})
if
err
!=
nil
{
...
...
backend/internal/handler/admin/id_list_utils.go
0 → 100644
View file @
3d79773b
package
admin
import
"sort"
func
normalizeInt64IDList
(
ids
[]
int64
)
[]
int64
{
if
len
(
ids
)
==
0
{
return
nil
}
out
:=
make
([]
int64
,
0
,
len
(
ids
))
seen
:=
make
(
map
[
int64
]
struct
{},
len
(
ids
))
for
_
,
id
:=
range
ids
{
if
id
<=
0
{
continue
}
if
_
,
ok
:=
seen
[
id
];
ok
{
continue
}
seen
[
id
]
=
struct
{}{}
out
=
append
(
out
,
id
)
}
sort
.
Slice
(
out
,
func
(
i
,
j
int
)
bool
{
return
out
[
i
]
<
out
[
j
]
})
return
out
}
Prev
1
2
3
4
5
6
7
8
9
…
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment