Commit 3d79773b authored by kyx236's avatar kyx236
Browse files

Merge branch 'main' of https://github.com/james-6-23/sub2api

parents 6aa8cbbf 742e73c9
......@@ -175,22 +175,28 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
return
}
dataPayload := req.Data
if err := validateDataHeader(dataPayload); err != nil {
if err := validateDataHeader(req.Data); err != nil {
response.BadRequest(c, err.Error())
return
}
executeAdminIdempotentJSON(c, "admin.accounts.import_data", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
return h.importData(ctx, req)
})
}
func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest) (DataImportResult, error) {
skipDefaultGroupBind := true
if req.SkipDefaultGroupBind != nil {
skipDefaultGroupBind = *req.SkipDefaultGroupBind
}
dataPayload := req.Data
result := DataImportResult{}
existingProxies, err := h.listAllProxies(c.Request.Context())
existingProxies, err := h.listAllProxies(ctx)
if err != nil {
response.ErrorFrom(c, err)
return
return result, err
}
proxyKeyToID := make(map[string]int64, len(existingProxies))
......@@ -221,8 +227,8 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
proxyKeyToID[key] = existingID
result.ProxyReused++
if normalizedStatus != "" {
if proxy, err := h.adminService.GetProxy(c.Request.Context(), existingID); err == nil && proxy != nil && proxy.Status != normalizedStatus {
_, _ = h.adminService.UpdateProxy(c.Request.Context(), existingID, &service.UpdateProxyInput{
if proxy, getErr := h.adminService.GetProxy(ctx, existingID); getErr == nil && proxy != nil && proxy.Status != normalizedStatus {
_, _ = h.adminService.UpdateProxy(ctx, existingID, &service.UpdateProxyInput{
Status: normalizedStatus,
})
}
......@@ -230,7 +236,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
continue
}
created, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
created, createErr := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{
Name: defaultProxyName(item.Name),
Protocol: item.Protocol,
Host: item.Host,
......@@ -238,13 +244,13 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
Username: item.Username,
Password: item.Password,
})
if err != nil {
if createErr != nil {
result.ProxyFailed++
result.Errors = append(result.Errors, DataImportError{
Kind: "proxy",
Name: item.Name,
ProxyKey: key,
Message: err.Error(),
Message: createErr.Error(),
})
continue
}
......@@ -252,7 +258,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
result.ProxyCreated++
if normalizedStatus != "" && normalizedStatus != created.Status {
_, _ = h.adminService.UpdateProxy(c.Request.Context(), created.ID, &service.UpdateProxyInput{
_, _ = h.adminService.UpdateProxy(ctx, created.ID, &service.UpdateProxyInput{
Status: normalizedStatus,
})
}
......@@ -303,7 +309,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
SkipDefaultGroupBind: skipDefaultGroupBind,
}
if _, err := h.adminService.CreateAccount(c.Request.Context(), accountInput); err != nil {
if _, err := h.adminService.CreateAccount(ctx, accountInput); err != nil {
result.AccountFailed++
result.Errors = append(result.Errors, DataImportError{
Kind: "account",
......@@ -315,7 +321,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
result.AccountCreated++
}
response.Success(c, result)
return result, nil
}
func (h *AccountHandler) listAllProxies(ctx context.Context) ([]service.Proxy, error) {
......
......@@ -64,6 +64,7 @@ func setupAccountDataRouter() (*gin.Engine, *stubAdminService) {
nil,
nil,
nil,
nil,
)
router.GET("/api/v1/admin/accounts/data", h.ExportData)
......
......@@ -2,7 +2,13 @@
package admin
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
......@@ -10,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
......@@ -46,6 +53,7 @@ type AccountHandler struct {
concurrencyService *service.ConcurrencyService
crsSyncService *service.CRSSyncService
sessionLimitCache service.SessionLimitCache
rpmCache service.RPMCache
tokenCacheInvalidator service.TokenCacheInvalidator
}
......@@ -62,6 +70,7 @@ func NewAccountHandler(
concurrencyService *service.ConcurrencyService,
crsSyncService *service.CRSSyncService,
sessionLimitCache service.SessionLimitCache,
rpmCache service.RPMCache,
tokenCacheInvalidator service.TokenCacheInvalidator,
) *AccountHandler {
return &AccountHandler{
......@@ -76,6 +85,7 @@ func NewAccountHandler(
concurrencyService: concurrencyService,
crsSyncService: crsSyncService,
sessionLimitCache: sessionLimitCache,
rpmCache: rpmCache,
tokenCacheInvalidator: tokenCacheInvalidator,
}
}
......@@ -133,6 +143,13 @@ type BulkUpdateAccountsRequest struct {
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
}
// CheckMixedChannelRequest represents check mixed channel risk request
type CheckMixedChannelRequest struct {
Platform string `json:"platform" binding:"required"`
GroupIDs []int64 `json:"group_ids"`
AccountID *int64 `json:"account_id"`
}
// AccountWithConcurrency extends Account with real-time concurrency info
type AccountWithConcurrency struct {
*dto.Account
......@@ -140,6 +157,51 @@ type AccountWithConcurrency struct {
// 以下字段仅对 Anthropic OAuth/SetupToken 账号有效,且仅在启用相应功能时返回
CurrentWindowCost *float64 `json:"current_window_cost,omitempty"` // 当前窗口费用
ActiveSessions *int `json:"active_sessions,omitempty"` // 当前活跃会话数
CurrentRPM *int `json:"current_rpm,omitempty"` // 当前分钟 RPM 计数
}
func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency {
item := AccountWithConcurrency{
Account: dto.AccountFromService(account),
CurrentConcurrency: 0,
}
if account == nil {
return item
}
if h.concurrencyService != nil {
if counts, err := h.concurrencyService.GetAccountConcurrencyBatch(ctx, []int64{account.ID}); err == nil {
item.CurrentConcurrency = counts[account.ID]
}
}
if account.IsAnthropicOAuthOrSetupToken() {
if h.accountUsageService != nil && account.GetWindowCostLimit() > 0 {
startTime := account.GetCurrentWindowStartTime()
if stats, err := h.accountUsageService.GetAccountWindowStats(ctx, account.ID, startTime); err == nil && stats != nil {
cost := stats.StandardCost
item.CurrentWindowCost = &cost
}
}
if h.sessionLimitCache != nil && account.GetMaxSessions() > 0 {
idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute
idleTimeouts := map[int64]time.Duration{account.ID: idleTimeout}
if sessions, err := h.sessionLimitCache.GetActiveSessionCountBatch(ctx, []int64{account.ID}, idleTimeouts); err == nil {
if count, ok := sessions[account.ID]; ok {
item.ActiveSessions = &count
}
}
}
if h.rpmCache != nil && account.GetBaseRPM() > 0 {
if rpm, err := h.rpmCache.GetRPM(ctx, account.ID); err == nil {
item.CurrentRPM = &rpm
}
}
}
return item
}
// List handles listing all accounts with pagination
......@@ -155,6 +217,7 @@ func (h *AccountHandler) List(c *gin.Context) {
if len(search) > 100 {
search = search[:100]
}
lite := parseBoolQueryWithDefault(c.Query("lite"), false)
var groupID int64
if groupIDStr := c.Query("group"); groupIDStr != "" {
......@@ -173,15 +236,21 @@ func (h *AccountHandler) List(c *gin.Context) {
accountIDs[i] = acc.ID
}
concurrencyCounts, err := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs)
if err != nil {
// Log error but don't fail the request, just use 0 for all
concurrencyCounts = make(map[int64]int)
concurrencyCounts := make(map[int64]int)
var windowCosts map[int64]float64
var activeSessions map[int64]int
var rpmCounts map[int64]int
if !lite {
// Get current concurrency counts for all accounts
if h.concurrencyService != nil {
if cc, ccErr := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs); ccErr == nil && cc != nil {
concurrencyCounts = cc
}
// 识别需要查询窗口费用会话数的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
}
// 识别需要查询窗口费用会话数和 RPM 的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
windowCostAccountIDs := make([]int64, 0)
sessionLimitAccountIDs := make([]int64, 0)
rpmAccountIDs := make([]int64, 0)
sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置
for i := range accounts {
acc := &accounts[i]
......@@ -193,12 +262,19 @@ func (h *AccountHandler) List(c *gin.Context) {
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
}
if acc.GetBaseRPM() > 0 {
rpmAccountIDs = append(rpmAccountIDs, acc.ID)
}
}
}
// 并行获取窗口费用和活跃会话数
var windowCosts map[int64]float64
var activeSessions map[int64]int
// 获取 RPM 计数(批量查询)
if len(rpmAccountIDs) > 0 && h.rpmCache != nil {
rpmCounts, _ = h.rpmCache.GetRPMBatch(c.Request.Context(), rpmAccountIDs)
if rpmCounts == nil {
rpmCounts = make(map[int64]int)
}
}
// 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置)
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
......@@ -235,6 +311,7 @@ func (h *AccountHandler) List(c *gin.Context) {
}
_ = g.Wait()
}
}
// Build response with concurrency info
result := make([]AccountWithConcurrency, len(accounts))
......@@ -259,12 +336,84 @@ func (h *AccountHandler) List(c *gin.Context) {
}
}
// 添加 RPM 计数(仅当启用时)
if rpmCounts != nil {
if rpm, ok := rpmCounts[acc.ID]; ok {
item.CurrentRPM = &rpm
}
}
result[i] = item
}
etag := buildAccountsListETag(result, total, page, pageSize, platform, accountType, status, search, lite)
if etag != "" {
c.Header("ETag", etag)
c.Header("Vary", "If-None-Match")
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), etag) {
c.Status(http.StatusNotModified)
return
}
}
response.Paginated(c, result, total, page, pageSize)
}
func buildAccountsListETag(
items []AccountWithConcurrency,
total int64,
page, pageSize int,
platform, accountType, status, search string,
lite bool,
) string {
payload := struct {
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
Platform string `json:"platform"`
AccountType string `json:"type"`
Status string `json:"status"`
Search string `json:"search"`
Lite bool `json:"lite"`
Items []AccountWithConcurrency `json:"items"`
}{
Total: total,
Page: page,
PageSize: pageSize,
Platform: platform,
AccountType: accountType,
Status: status,
Search: search,
Lite: lite,
Items: items,
}
raw, err := json.Marshal(payload)
if err != nil {
return ""
}
sum := sha256.Sum256(raw)
return "\"" + hex.EncodeToString(sum[:]) + "\""
}
func ifNoneMatchMatched(ifNoneMatch, etag string) bool {
if etag == "" || ifNoneMatch == "" {
return false
}
for _, token := range strings.Split(ifNoneMatch, ",") {
candidate := strings.TrimSpace(token)
if candidate == "*" {
return true
}
if candidate == etag {
return true
}
if strings.HasPrefix(candidate, "W/") && strings.TrimPrefix(candidate, "W/") == etag {
return true
}
}
return false
}
// GetByID handles getting an account by ID
// GET /api/v1/admin/accounts/:id
func (h *AccountHandler) GetByID(c *gin.Context) {
......@@ -280,7 +429,51 @@ func (h *AccountHandler) GetByID(c *gin.Context) {
return
}
response.Success(c, dto.AccountFromService(account))
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
}
// CheckMixedChannel handles checking mixed channel risk for account-group binding.
// POST /api/v1/admin/accounts/check-mixed-channel
func (h *AccountHandler) CheckMixedChannel(c *gin.Context) {
var req CheckMixedChannelRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.GroupIDs) == 0 {
response.Success(c, gin.H{"has_risk": false})
return
}
accountID := int64(0)
if req.AccountID != nil {
accountID = *req.AccountID
}
err := h.adminService.CheckMixedChannelRisk(c.Request.Context(), accountID, req.Platform, req.GroupIDs)
if err != nil {
var mixedErr *service.MixedChannelError
if errors.As(err, &mixedErr) {
response.Success(c, gin.H{
"has_risk": true,
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
})
return
}
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"has_risk": false})
}
// Create handles creating a new account
......@@ -295,11 +488,14 @@ func (h *AccountHandler) Create(c *gin.Context) {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
// base_rpm 输入校验:负值归零,超过 10000 截断
sanitizeExtraBaseRPM(req.Extra)
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
result, err := executeAdminIdempotent(c, "admin.accounts.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
account, execErr := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
Name: req.Name,
Notes: req.Notes,
Platform: req.Platform,
......@@ -315,30 +511,34 @@ func (h *AccountHandler) Create(c *gin.Context) {
AutoPauseOnExpired: req.AutoPauseOnExpired,
SkipMixedChannelCheck: skipCheck,
})
if execErr != nil {
return nil, execErr
}
return h.buildAccountResponseWithRuntime(ctx, account), nil
})
if err != nil {
// 检查是否为混合渠道错误
var mixedErr *service.MixedChannelError
if errors.As(err, &mixedErr) {
// 返回特殊错误码要求确认
// 创建接口仅返回最小必要字段,详细信息由专门检查接口提供
c.JSON(409, gin.H{
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
"require_confirmation": true,
})
return
}
if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.AccountFromService(account))
if result != nil && result.Replayed {
c.Header("X-Idempotency-Replayed", "true")
}
response.Success(c, result.Data)
}
// Update handles updating an account
......@@ -359,6 +559,8 @@ func (h *AccountHandler) Update(c *gin.Context) {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
// base_rpm 输入校验:负值归零,超过 10000 截断
sanitizeExtraBaseRPM(req.Extra)
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
......@@ -383,17 +585,10 @@ func (h *AccountHandler) Update(c *gin.Context) {
// 检查是否为混合渠道错误
var mixedErr *service.MixedChannelError
if errors.As(err, &mixedErr) {
// 返回特殊错误码要求确认
// 更新接口仅返回最小必要字段,详细信息由专门检查接口提供
c.JSON(409, gin.H{
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
"require_confirmation": true,
})
return
}
......@@ -402,7 +597,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
return
}
response.Success(c, dto.AccountFromService(account))
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
}
// Delete handles deleting an account
......@@ -660,7 +855,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
}
}
response.Success(c, dto.AccountFromService(updatedAccount))
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), updatedAccount))
}
// GetStats handles getting account statistics
......@@ -718,7 +913,7 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
}
}
response.Success(c, dto.AccountFromService(account))
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
}
// BatchCreate handles batch creating accounts
......@@ -732,7 +927,7 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
return
}
ctx := c.Request.Context()
executeAdminIdempotentJSON(c, "admin.accounts.batch_create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
success := 0
failed := 0
results := make([]gin.H, 0, len(req.Accounts))
......@@ -748,6 +943,9 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
continue
}
// base_rpm 输入校验:负值归零,超过 10000 截断
sanitizeExtraBaseRPM(item.Extra)
skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk
account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
......@@ -783,10 +981,11 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
})
}
response.Success(c, gin.H{
return gin.H{
"success": success,
"failed": failed,
"results": results,
}, nil
})
}
......@@ -824,49 +1023,48 @@ func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) {
}
ctx := c.Request.Context()
success := 0
failed := 0
results := []gin.H{}
// 阶段一:预验证所有账号存在,收集 credentials
type accountUpdate struct {
ID int64
Credentials map[string]any
}
updates := make([]accountUpdate, 0, len(req.AccountIDs))
for _, accountID := range req.AccountIDs {
// Get account
account, err := h.adminService.GetAccount(ctx, accountID)
if err != nil {
failed++
results = append(results, gin.H{
"account_id": accountID,
"success": false,
"error": "Account not found",
})
continue
response.Error(c, 404, fmt.Sprintf("Account %d not found", accountID))
return
}
// Update credentials field
if account.Credentials == nil {
account.Credentials = make(map[string]any)
}
account.Credentials[req.Field] = req.Value
// Update account
updateInput := &service.UpdateAccountInput{
Credentials: account.Credentials,
updates = append(updates, accountUpdate{ID: accountID, Credentials: account.Credentials})
}
_, err = h.adminService.UpdateAccount(ctx, accountID, updateInput)
if err != nil {
// 阶段二:依次更新,返回每个账号的成功/失败明细,便于调用方重试
success := 0
failed := 0
successIDs := make([]int64, 0, len(updates))
failedIDs := make([]int64, 0, len(updates))
results := make([]gin.H, 0, len(updates))
for _, u := range updates {
updateInput := &service.UpdateAccountInput{Credentials: u.Credentials}
if _, err := h.adminService.UpdateAccount(ctx, u.ID, updateInput); err != nil {
failed++
failedIDs = append(failedIDs, u.ID)
results = append(results, gin.H{
"account_id": accountID,
"account_id": u.ID,
"success": false,
"error": err.Error(),
})
continue
}
success++
successIDs = append(successIDs, u.ID)
results = append(results, gin.H{
"account_id": accountID,
"account_id": u.ID,
"success": true,
})
}
......@@ -874,6 +1072,8 @@ func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) {
response.Success(c, gin.H{
"success": success,
"failed": failed,
"success_ids": successIDs,
"failed_ids": failedIDs,
"results": results,
})
}
......@@ -890,6 +1090,8 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
// base_rpm 输入校验:负值归零,超过 10000 截断
sanitizeExtraBaseRPM(req.Extra)
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
......@@ -925,6 +1127,14 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
SkipMixedChannelCheck: skipCheck,
})
if err != nil {
var mixedErr *service.MixedChannelError
if errors.As(err, &mixedErr) {
c.JSON(409, gin.H{
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
})
return
}
response.ErrorFrom(c, err)
return
}
......@@ -1109,7 +1319,13 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
return
}
response.Success(c, gin.H{"message": "Rate limit cleared successfully"})
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
}
// GetTempUnschedulable handles getting temporary unschedulable status
......@@ -1173,6 +1389,57 @@ func (h *AccountHandler) GetTodayStats(c *gin.Context) {
response.Success(c, stats)
}
// BatchTodayStatsRequest 批量今日统计请求体。
type BatchTodayStatsRequest struct {
AccountIDs []int64 `json:"account_ids" binding:"required"`
}
// GetBatchTodayStats 批量获取多个账号的今日统计。
// POST /api/v1/admin/accounts/today-stats/batch
func (h *AccountHandler) GetBatchTodayStats(c *gin.Context) {
var req BatchTodayStatsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
accountIDs := normalizeInt64IDList(req.AccountIDs)
if len(accountIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
cacheKey := buildAccountTodayStatsBatchCacheKey(accountIDs)
if cached, ok := accountTodayStatsBatchCache.Get(cacheKey); ok {
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
c.Status(http.StatusNotModified)
return
}
}
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
return
}
stats, err := h.accountUsageService.GetTodayStatsBatch(c.Request.Context(), accountIDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
payload := gin.H{"stats": stats}
cached := accountTodayStatsBatchCache.Set(cacheKey, payload)
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
}
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, payload)
}
// SetSchedulableRequest represents the request body for setting schedulable status
type SetSchedulableRequest struct {
Schedulable bool `json:"schedulable"`
......@@ -1199,7 +1466,7 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) {
return
}
response.Success(c, dto.AccountFromService(account))
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
}
// GetAvailableModels handles getting available models for an account
......@@ -1296,32 +1563,14 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
// Handle Antigravity accounts: return Claude + Gemini models
if account.Platform == service.PlatformAntigravity {
// Antigravity 支持 Claude 和部分 Gemini 模型
type UnifiedModel struct {
ID string `json:"id"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
}
var models []UnifiedModel
// 添加 Claude 模型
for _, m := range claude.DefaultModels {
models = append(models, UnifiedModel{
ID: m.ID,
Type: m.Type,
DisplayName: m.DisplayName,
})
}
// 添加 Gemini 3 系列模型用于测试
geminiTestModels := []UnifiedModel{
{ID: "gemini-3-flash", Type: "model", DisplayName: "Gemini 3 Flash"},
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview"},
// 直接复用 antigravity.DefaultModels(),与 /v1/models 端点保持同步
response.Success(c, antigravity.DefaultModels())
return
}
models = append(models, geminiTestModels...)
response.Success(c, models)
// Handle Sora accounts
if account.Platform == service.PlatformSora {
response.Success(c, service.DefaultSoraModels(nil))
return
}
......@@ -1532,3 +1781,22 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
func (h *AccountHandler) GetAntigravityDefaultModelMapping(c *gin.Context) {
response.Success(c, domain.DefaultAntigravityModelMapping)
}
// sanitizeExtraBaseRPM 对 extra map 中的 base_rpm 值进行范围校验和归一化。
// 负值归零,超过 10000 截断为 10000。extra 为 nil 或不含 base_rpm 时无操作。
func sanitizeExtraBaseRPM(extra map[string]any) {
if extra == nil {
return
}
raw, ok := extra["base_rpm"]
if !ok {
return
}
v := service.ParseExtraInt(raw)
if v < 0 {
v = 0
} else if v > 10000 {
v = 10000
}
extra["base_rpm"] = v
}
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func setupAccountMixedChannelRouter(adminSvc *stubAdminService) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
accountHandler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
router.POST("/api/v1/admin/accounts/check-mixed-channel", accountHandler.CheckMixedChannel)
router.POST("/api/v1/admin/accounts", accountHandler.Create)
router.PUT("/api/v1/admin/accounts/:id", accountHandler.Update)
router.POST("/api/v1/admin/accounts/bulk-update", accountHandler.BulkUpdate)
return router
}
func TestAccountHandlerCheckMixedChannelNoRisk(t *testing.T) {
adminSvc := newStubAdminService()
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"platform": "antigravity",
"group_ids": []int64{27},
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/check-mixed-channel", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, float64(0), resp["code"])
data, ok := resp["data"].(map[string]any)
require.True(t, ok)
require.Equal(t, false, data["has_risk"])
require.Equal(t, int64(0), adminSvc.lastMixedCheck.accountID)
require.Equal(t, "antigravity", adminSvc.lastMixedCheck.platform)
require.Equal(t, []int64{27}, adminSvc.lastMixedCheck.groupIDs)
}
func TestAccountHandlerCheckMixedChannelWithRisk(t *testing.T) {
adminSvc := newStubAdminService()
adminSvc.checkMixedErr = &service.MixedChannelError{
GroupID: 27,
GroupName: "claude-max",
CurrentPlatform: "Antigravity",
OtherPlatform: "Anthropic",
}
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"platform": "antigravity",
"group_ids": []int64{27},
"account_id": 99,
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/check-mixed-channel", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, float64(0), resp["code"])
data, ok := resp["data"].(map[string]any)
require.True(t, ok)
require.Equal(t, true, data["has_risk"])
require.Equal(t, "mixed_channel_warning", data["error"])
details, ok := data["details"].(map[string]any)
require.True(t, ok)
require.Equal(t, float64(27), details["group_id"])
require.Equal(t, "claude-max", details["group_name"])
require.Equal(t, "Antigravity", details["current_platform"])
require.Equal(t, "Anthropic", details["other_platform"])
require.Equal(t, int64(99), adminSvc.lastMixedCheck.accountID)
}
func TestAccountHandlerCreateMixedChannelConflictSimplifiedResponse(t *testing.T) {
adminSvc := newStubAdminService()
adminSvc.createAccountErr = &service.MixedChannelError{
GroupID: 27,
GroupName: "claude-max",
CurrentPlatform: "Antigravity",
OtherPlatform: "Anthropic",
}
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"name": "ag-oauth-1",
"platform": "antigravity",
"type": "oauth",
"credentials": map[string]any{"refresh_token": "rt"},
"group_ids": []int64{27},
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusConflict, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, "mixed_channel_warning", resp["error"])
require.Contains(t, resp["message"], "mixed_channel_warning")
_, hasDetails := resp["details"]
_, hasRequireConfirmation := resp["require_confirmation"]
require.False(t, hasDetails)
require.False(t, hasRequireConfirmation)
}
func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T) {
adminSvc := newStubAdminService()
adminSvc.updateAccountErr = &service.MixedChannelError{
GroupID: 27,
GroupName: "claude-max",
CurrentPlatform: "Antigravity",
OtherPlatform: "Anthropic",
}
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"group_ids": []int64{27},
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/accounts/3", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusConflict, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, "mixed_channel_warning", resp["error"])
require.Contains(t, resp["message"], "mixed_channel_warning")
_, hasDetails := resp["details"]
_, hasRequireConfirmation := resp["require_confirmation"]
require.False(t, hasDetails)
require.False(t, hasRequireConfirmation)
}
func TestAccountHandlerBulkUpdateMixedChannelConflict(t *testing.T) {
adminSvc := newStubAdminService()
adminSvc.bulkUpdateAccountErr = &service.MixedChannelError{
GroupID: 27,
GroupName: "claude-max",
CurrentPlatform: "Antigravity",
OtherPlatform: "Anthropic",
}
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1, 2, 3},
"group_ids": []int64{27},
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusConflict, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, "mixed_channel_warning", resp["error"])
require.Contains(t, resp["message"], "claude-max")
}
func TestAccountHandlerBulkUpdateMixedChannelConfirmSkips(t *testing.T) {
adminSvc := newStubAdminService()
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1, 2},
"group_ids": []int64{27},
"confirm_mixed_channel_risk": true,
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, float64(0), resp["code"])
data, ok := resp["data"].(map[string]any)
require.True(t, ok)
require.Equal(t, float64(2), data["success"])
require.Equal(t, float64(0), data["failed"])
}
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestAccountHandler_Create_AnthropicAPIKeyPassthroughExtraForwarded(t *testing.T) {
gin.SetMode(gin.TestMode)
adminSvc := newStubAdminService()
handler := NewAccountHandler(
adminSvc,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
)
router := gin.New()
router.POST("/api/v1/admin/accounts", handler.Create)
body := map[string]any{
"name": "anthropic-key-1",
"platform": "anthropic",
"type": "apikey",
"credentials": map[string]any{
"api_key": "sk-ant-xxx",
"base_url": "https://api.anthropic.com",
},
"extra": map[string]any{
"anthropic_passthrough": true,
},
"concurrency": 1,
"priority": 1,
}
raw, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts", bytes.NewReader(raw))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Len(t, adminSvc.createdAccounts, 1)
created := adminSvc.createdAccounts[0]
require.Equal(t, "anthropic", created.Platform)
require.Equal(t, "apikey", created.Type)
require.NotNil(t, created.Extra)
require.Equal(t, true, created.Extra["anthropic_passthrough"])
}
package admin
import (
"strconv"
"strings"
"time"
)
var accountTodayStatsBatchCache = newSnapshotCache(30 * time.Second)
func buildAccountTodayStatsBatchCacheKey(accountIDs []int64) string {
if len(accountIDs) == 0 {
return "accounts_today_stats_empty"
}
var b strings.Builder
b.Grow(len(accountIDs) * 6)
_, _ = b.WriteString("accounts_today_stats:")
for i, id := range accountIDs {
if i > 0 {
_ = b.WriteByte(',')
}
_, _ = b.WriteString(strconv.FormatInt(id, 10))
}
return b.String()
}
......@@ -19,7 +19,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
userHandler := NewUserHandler(adminSvc, nil)
groupHandler := NewGroupHandler(adminSvc)
proxyHandler := NewProxyHandler(adminSvc)
redeemHandler := NewRedeemHandler(adminSvc)
redeemHandler := NewRedeemHandler(adminSvc, nil)
router.GET("/api/v1/admin/users", userHandler.List)
router.GET("/api/v1/admin/users/:id", userHandler.GetByID)
......@@ -47,6 +47,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
router.DELETE("/api/v1/admin/proxies/:id", proxyHandler.Delete)
router.POST("/api/v1/admin/proxies/batch-delete", proxyHandler.BatchDelete)
router.POST("/api/v1/admin/proxies/:id/test", proxyHandler.Test)
router.POST("/api/v1/admin/proxies/:id/quality-check", proxyHandler.CheckQuality)
router.GET("/api/v1/admin/proxies/:id/stats", proxyHandler.GetStats)
router.GET("/api/v1/admin/proxies/:id/accounts", proxyHandler.GetProxyAccounts)
......@@ -208,6 +209,11 @@ func TestProxyHandlerEndpoints(t *testing.T) {
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/4/quality-check", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/stats", nil)
router.ServeHTTP(rec, req)
......
......@@ -58,6 +58,96 @@ func TestParseOpsDuration(t *testing.T) {
require.False(t, ok)
}
func TestParseOpsOpenAITokenStatsDuration(t *testing.T) {
tests := []struct {
input string
want time.Duration
ok bool
}{
{input: "30m", want: 30 * time.Minute, ok: true},
{input: "1h", want: time.Hour, ok: true},
{input: "1d", want: 24 * time.Hour, ok: true},
{input: "15d", want: 15 * 24 * time.Hour, ok: true},
{input: "30d", want: 30 * 24 * time.Hour, ok: true},
{input: "7d", want: 0, ok: false},
}
for _, tt := range tests {
got, ok := parseOpsOpenAITokenStatsDuration(tt.input)
require.Equal(t, tt.ok, ok, "input=%s", tt.input)
require.Equal(t, tt.want, got, "input=%s", tt.input)
}
}
func TestParseOpsOpenAITokenStatsFilter_Defaults(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
before := time.Now().UTC()
filter, err := parseOpsOpenAITokenStatsFilter(c)
after := time.Now().UTC()
require.NoError(t, err)
require.NotNil(t, filter)
require.Equal(t, "30d", filter.TimeRange)
require.Equal(t, 1, filter.Page)
require.Equal(t, 20, filter.PageSize)
require.Equal(t, 0, filter.TopN)
require.Nil(t, filter.GroupID)
require.Equal(t, "", filter.Platform)
require.True(t, filter.StartTime.Before(filter.EndTime))
require.WithinDuration(t, before.Add(-30*24*time.Hour), filter.StartTime, 2*time.Second)
require.WithinDuration(t, after, filter.EndTime, 2*time.Second)
}
func TestParseOpsOpenAITokenStatsFilter_WithTopN(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(
http.MethodGet,
"/?time_range=1h&platform=openai&group_id=12&top_n=50",
nil,
)
filter, err := parseOpsOpenAITokenStatsFilter(c)
require.NoError(t, err)
require.Equal(t, "1h", filter.TimeRange)
require.Equal(t, "openai", filter.Platform)
require.NotNil(t, filter.GroupID)
require.Equal(t, int64(12), *filter.GroupID)
require.Equal(t, 50, filter.TopN)
require.Equal(t, 0, filter.Page)
require.Equal(t, 0, filter.PageSize)
}
func TestParseOpsOpenAITokenStatsFilter_InvalidParams(t *testing.T) {
tests := []string{
"/?time_range=7d",
"/?group_id=0",
"/?group_id=abc",
"/?top_n=0",
"/?top_n=101",
"/?top_n=10&page=1",
"/?top_n=10&page_size=20",
"/?page=0",
"/?page_size=0",
"/?page_size=101",
}
gin.SetMode(gin.TestMode)
for _, rawURL := range tests {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, rawURL, nil)
_, err := parseOpsOpenAITokenStatsFilter(c)
require.Error(t, err, "url=%s", rawURL)
}
}
func TestParseOpsTimeRange(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
......
......@@ -22,6 +22,15 @@ type stubAdminService struct {
updatedProxyIDs []int64
updatedProxies []*service.UpdateProxyInput
testedProxyIDs []int64
createAccountErr error
updateAccountErr error
bulkUpdateAccountErr error
checkMixedErr error
lastMixedCheck struct {
accountID int64
platform string
groupIDs []int64
}
mu sync.Mutex
}
......@@ -188,11 +197,17 @@ func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.Cre
s.mu.Lock()
s.createdAccounts = append(s.createdAccounts, input)
s.mu.Unlock()
if s.createAccountErr != nil {
return nil, s.createAccountErr
}
account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive}
return &account, nil
}
func (s *stubAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
if s.updateAccountErr != nil {
return nil, s.updateAccountErr
}
account := service.Account{ID: id, Name: input.Name, Status: service.StatusActive}
return &account, nil
}
......@@ -221,7 +236,17 @@ func (s *stubAdminService) SetAccountSchedulable(ctx context.Context, id int64,
}
func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *service.BulkUpdateAccountsInput) (*service.BulkUpdateAccountsResult, error) {
return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil
if s.bulkUpdateAccountErr != nil {
return nil, s.bulkUpdateAccountErr
}
return &service.BulkUpdateAccountsResult{Success: len(input.AccountIDs), Failed: 0, SuccessIDs: input.AccountIDs}, nil
}
func (s *stubAdminService) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
s.lastMixedCheck.accountID = currentAccountID
s.lastMixedCheck.platform = currentAccountPlatform
s.lastMixedCheck.groupIDs = append([]int64(nil), groupIDs...)
return s.checkMixedErr
}
func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) {
......@@ -327,6 +352,27 @@ func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.Pr
return &service.ProxyTestResult{Success: true, Message: "ok"}, nil
}
func (s *stubAdminService) CheckProxyQuality(ctx context.Context, id int64) (*service.ProxyQualityCheckResult, error) {
return &service.ProxyQualityCheckResult{
ProxyID: id,
Score: 95,
Grade: "A",
Summary: "通过 5 项,告警 0 项,失败 0 项,挑战 0 项",
PassedCount: 5,
WarnCount: 0,
FailedCount: 0,
ChallengeCount: 0,
CheckedAt: time.Now().Unix(),
Items: []service.ProxyQualityCheckItem{
{Target: "base_connectivity", Status: "pass", Message: "ok"},
{Target: "openai", Status: "pass", HTTPStatus: 401},
{Target: "anthropic", Status: "pass", HTTPStatus: 401},
{Target: "gemini", Status: "pass", HTTPStatus: 200},
{Target: "sora", Status: "pass", HTTPStatus: 401},
},
}, nil
}
func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) {
return s.redeems, int64(len(s.redeems)), nil
}
......@@ -361,5 +407,23 @@ func (s *stubAdminService) UpdateGroupSortOrders(ctx context.Context, updates []
return nil
}
func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*service.AdminUpdateAPIKeyGroupIDResult, error) {
for i := range s.apiKeys {
if s.apiKeys[i].ID == keyID {
k := s.apiKeys[i]
if groupID != nil {
if *groupID == 0 {
k.GroupID = nil
} else {
gid := *groupID
k.GroupID = &gid
}
}
return &service.AdminUpdateAPIKeyGroupIDResult{APIKey: &k}, nil
}
}
return nil, service.ErrAPIKeyNotFound
}
// Ensure stub implements interface.
var _ service.AdminService = (*stubAdminService)(nil)
package admin
import (
"strconv"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// AdminAPIKeyHandler handles admin API key management
type AdminAPIKeyHandler struct {
adminService service.AdminService
}
// NewAdminAPIKeyHandler creates a new admin API key handler
func NewAdminAPIKeyHandler(adminService service.AdminService) *AdminAPIKeyHandler {
return &AdminAPIKeyHandler{
adminService: adminService,
}
}
// AdminUpdateAPIKeyGroupRequest represents the request to update an API key's group
type AdminUpdateAPIKeyGroupRequest struct {
GroupID *int64 `json:"group_id"` // nil=不修改, 0=解绑, >0=绑定到目标分组
}
// UpdateGroup handles updating an API key's group binding
// PUT /api/v1/admin/api-keys/:id
func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) {
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid API key ID")
return
}
var req AdminUpdateAPIKeyGroupRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
result, err := h.adminService.AdminUpdateAPIKeyGroupID(c.Request.Context(), keyID, req.GroupID)
if err != nil {
response.ErrorFrom(c, err)
return
}
resp := struct {
APIKey *dto.APIKey `json:"api_key"`
AutoGrantedGroupAccess bool `json:"auto_granted_group_access"`
GrantedGroupID *int64 `json:"granted_group_id,omitempty"`
GrantedGroupName string `json:"granted_group_name,omitempty"`
}{
APIKey: dto.APIKeyFromService(result.APIKey),
AutoGrantedGroupAccess: result.AutoGrantedGroupAccess,
GrantedGroupID: result.GrantedGroupID,
GrantedGroupName: result.GrantedGroupName,
}
response.Success(c, resp)
}
package admin
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func setupAPIKeyHandler(adminSvc service.AdminService) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
h := NewAdminAPIKeyHandler(adminSvc)
router.PUT("/api/v1/admin/api-keys/:id", h.UpdateGroup)
return router
}
func TestAdminAPIKeyHandler_UpdateGroup_InvalidID(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
body := `{"group_id": 2}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/abc", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, rec.Body.String(), "Invalid API key ID")
}
func TestAdminAPIKeyHandler_UpdateGroup_InvalidJSON(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{bad json`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, rec.Body.String(), "Invalid request")
}
func TestAdminAPIKeyHandler_UpdateGroup_KeyNotFound(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
body := `{"group_id": 2}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/999", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
// ErrAPIKeyNotFound maps to 404
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestAdminAPIKeyHandler_UpdateGroup_BindGroup(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
body := `{"group_id": 2}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Code int `json:"code"`
Data json.RawMessage `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
var data struct {
APIKey struct {
ID int64 `json:"id"`
GroupID *int64 `json:"group_id"`
} `json:"api_key"`
AutoGrantedGroupAccess bool `json:"auto_granted_group_access"`
}
require.NoError(t, json.Unmarshal(resp.Data, &data))
require.Equal(t, int64(10), data.APIKey.ID)
require.NotNil(t, data.APIKey.GroupID)
require.Equal(t, int64(2), *data.APIKey.GroupID)
}
func TestAdminAPIKeyHandler_UpdateGroup_Unbind(t *testing.T) {
svc := newStubAdminService()
gid := int64(2)
svc.apiKeys[0].GroupID = &gid
router := setupAPIKeyHandler(svc)
body := `{"group_id": 0}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Data struct {
APIKey struct {
GroupID *int64 `json:"group_id"`
} `json:"api_key"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Nil(t, resp.Data.APIKey.GroupID)
}
func TestAdminAPIKeyHandler_UpdateGroup_ServiceError(t *testing.T) {
svc := &failingUpdateGroupService{
stubAdminService: newStubAdminService(),
err: errors.New("internal failure"),
}
router := setupAPIKeyHandler(svc)
body := `{"group_id": 2}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusInternalServerError, rec.Code)
}
// H2: empty body → group_id is nil → no-op, returns original key
func TestAdminAPIKeyHandler_UpdateGroup_EmptyBody_NoChange(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Code int `json:"code"`
Data struct {
APIKey struct {
ID int64 `json:"id"`
} `json:"api_key"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Equal(t, int64(10), resp.Data.APIKey.ID)
}
// M2: service returns GROUP_NOT_ACTIVE → handler maps to 400
func TestAdminAPIKeyHandler_UpdateGroup_GroupNotActive(t *testing.T) {
svc := &failingUpdateGroupService{
stubAdminService: newStubAdminService(),
err: infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active"),
}
router := setupAPIKeyHandler(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"group_id": 5}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, rec.Body.String(), "GROUP_NOT_ACTIVE")
}
// M2: service returns INVALID_GROUP_ID → handler maps to 400
func TestAdminAPIKeyHandler_UpdateGroup_NegativeGroupID(t *testing.T) {
svc := &failingUpdateGroupService{
stubAdminService: newStubAdminService(),
err: infraerrors.BadRequest("INVALID_GROUP_ID", "group_id must be non-negative"),
}
router := setupAPIKeyHandler(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"group_id": -5}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, rec.Body.String(), "INVALID_GROUP_ID")
}
// failingUpdateGroupService overrides AdminUpdateAPIKeyGroupID to return an error.
type failingUpdateGroupService struct {
*stubAdminService
err error
}
func (f *failingUpdateGroupService) AdminUpdateAPIKeyGroupID(_ context.Context, _ int64, _ *int64) (*service.AdminUpdateAPIKeyGroupIDResult, error) {
return nil, f.err
}
//go:build unit
package admin
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// failingAdminService 嵌入 stubAdminService,可配置 UpdateAccount 在指定 ID 时失败。
type failingAdminService struct {
*stubAdminService
failOnAccountID int64
updateCallCount atomic.Int64
}
func (f *failingAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
f.updateCallCount.Add(1)
if id == f.failOnAccountID {
return nil, errors.New("database error")
}
return f.stubAdminService.UpdateAccount(ctx, id, input)
}
func setupAccountHandlerWithService(adminSvc service.AdminService) (*gin.Engine, *AccountHandler) {
gin.SetMode(gin.TestMode)
router := gin.New()
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
router.POST("/api/v1/admin/accounts/batch-update-credentials", handler.BatchUpdateCredentials)
return router, handler
}
func TestBatchUpdateCredentials_AllSuccess(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
AccountIDs: []int64{1, 2, 3},
Field: "account_uuid",
Value: "test-uuid",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code, "全部成功时应返回 200")
require.Equal(t, int64(3), svc.updateCallCount.Load(), "应调用 3 次 UpdateAccount")
}
func TestBatchUpdateCredentials_PartialFailure(t *testing.T) {
// 让第 2 个账号(ID=2)更新时失败
svc := &failingAdminService{
stubAdminService: newStubAdminService(),
failOnAccountID: 2,
}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
AccountIDs: []int64{1, 2, 3},
Field: "org_uuid",
Value: "test-org",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
// 实现采用"部分成功"模式:总是返回 200 + 成功/失败明细
require.Equal(t, http.StatusOK, w.Code, "批量更新返回 200 + 成功/失败明细")
var resp map[string]any
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
data := resp["data"].(map[string]any)
require.Equal(t, float64(2), data["success"], "应有 2 个成功")
require.Equal(t, float64(1), data["failed"], "应有 1 个失败")
// 所有 3 个账号都会被尝试更新(非 fail-fast)
require.Equal(t, int64(3), svc.updateCallCount.Load(),
"应调用 3 次 UpdateAccount(逐个尝试,失败后继续)")
}
func TestBatchUpdateCredentials_FirstAccountNotFound(t *testing.T) {
// GetAccount 在 stubAdminService 中总是成功的,需要创建一个 GetAccount 会失败的 stub
svc := &getAccountFailingService{
stubAdminService: newStubAdminService(),
failOnAccountID: 1,
}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
AccountIDs: []int64{1, 2, 3},
Field: "account_uuid",
Value: "test",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusNotFound, w.Code, "第一阶段验证失败应返回 404")
}
// getAccountFailingService 模拟 GetAccount 在特定 ID 时返回 not found。
type getAccountFailingService struct {
*stubAdminService
failOnAccountID int64
}
func (f *getAccountFailingService) GetAccount(ctx context.Context, id int64) (*service.Account, error) {
if id == f.failOnAccountID {
return nil, errors.New("not found")
}
return f.stubAdminService.GetAccount(ctx, id)
}
func TestBatchUpdateCredentials_InterceptWarmupRequests_NonBool(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
// intercept_warmup_requests 传入非 bool 类型(string),应返回 400
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "intercept_warmup_requests",
"value": "not-a-bool",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusBadRequest, w.Code,
"intercept_warmup_requests 传入非 bool 值应返回 400")
}
func TestBatchUpdateCredentials_InterceptWarmupRequests_ValidBool(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "intercept_warmup_requests",
"value": true,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code,
"intercept_warmup_requests 传入合法 bool 值应返回 200")
}
func TestBatchUpdateCredentials_AccountUUID_NonString(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
// account_uuid 传入非 string 类型(number),应返回 400
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "account_uuid",
"value": 12345,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusBadRequest, w.Code,
"account_uuid 传入非 string 值应返回 400")
}
func TestBatchUpdateCredentials_AccountUUID_NullValue(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
// account_uuid 传入 null(设置为空),应正常通过
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "account_uuid",
"value": nil,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code,
"account_uuid 传入 null 应返回 200")
}
package admin
import (
"encoding/json"
"errors"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
......@@ -186,7 +188,7 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
// GetUsageTrend handles getting usage trend data
// GET /api/v1/admin/dashboard/trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream, billing_type
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, request_type, stream, billing_type
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
......@@ -194,6 +196,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
// Parse optional filter params
var userID, apiKeyID, accountID, groupID int64
var model string
var requestType *int16
var stream *bool
var billingType *int8
......@@ -220,9 +223,20 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
if modelStr := c.Query("model"); modelStr != "" {
model = modelStr
}
if streamStr := c.Query("stream"); streamStr != "" {
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
response.BadRequest(c, err.Error())
return
}
value := int16(parsed)
requestType = &value
} else if streamStr := c.Query("stream"); streamStr != "" {
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
stream = &streamVal
} else {
response.BadRequest(c, "Invalid stream value, use true or false")
return
}
}
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
......@@ -235,7 +249,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
}
}
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get usage trend")
return
......@@ -251,12 +265,13 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
// GetModelStats handles getting model usage statistics
// GET /api/v1/admin/dashboard/models
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream, billing_type
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, request_type, stream, billing_type
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
// Parse optional filter params
var userID, apiKeyID, accountID, groupID int64
var requestType *int16
var stream *bool
var billingType *int8
......@@ -280,9 +295,20 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
groupID = id
}
}
if streamStr := c.Query("stream"); streamStr != "" {
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
response.BadRequest(c, err.Error())
return
}
value := int16(parsed)
requestType = &value
} else if streamStr := c.Query("stream"); streamStr != "" {
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
stream = &streamVal
} else {
response.BadRequest(c, "Invalid stream value, use true or false")
return
}
}
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
......@@ -295,7 +321,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
}
}
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get model statistics")
return
......@@ -308,6 +334,76 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
})
}
// GetGroupStats handles getting group usage statistics
// GET /api/v1/admin/dashboard/groups
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, request_type, stream, billing_type
func (h *DashboardHandler) GetGroupStats(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
var userID, apiKeyID, accountID, groupID int64
var requestType *int16
var stream *bool
var billingType *int8
if userIDStr := c.Query("user_id"); userIDStr != "" {
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
userID = id
}
}
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
if id, err := strconv.ParseInt(apiKeyIDStr, 10, 64); err == nil {
apiKeyID = id
}
}
if accountIDStr := c.Query("account_id"); accountIDStr != "" {
if id, err := strconv.ParseInt(accountIDStr, 10, 64); err == nil {
accountID = id
}
}
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil {
groupID = id
}
}
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
response.BadRequest(c, err.Error())
return
}
value := int16(parsed)
requestType = &value
} else if streamStr := c.Query("stream"); streamStr != "" {
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
stream = &streamVal
} else {
response.BadRequest(c, "Invalid stream value, use true or false")
return
}
}
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil {
bt := int8(v)
billingType = &bt
} else {
response.BadRequest(c, "Invalid billing_type")
return
}
}
stats, err := h.dashboardService.GetGroupStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get group statistics")
return
}
response.Success(c, gin.H{
"groups": stats,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
})
}
// GetAPIKeyUsageTrend handles getting API key usage trend data
// GET /api/v1/admin/dashboard/api-keys-trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), limit (default 5)
......@@ -365,6 +461,9 @@ type BatchUsersUsageRequest struct {
UserIDs []int64 `json:"user_ids" binding:"required"`
}
var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second)
var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second)
// GetBatchUsersUsage handles getting usage stats for multiple users
// POST /api/v1/admin/dashboard/users-usage
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
......@@ -374,18 +473,34 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
return
}
if len(req.UserIDs) == 0 {
userIDs := normalizeInt64IDList(req.UserIDs)
if len(userIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs)
keyRaw, _ := json.Marshal(struct {
UserIDs []int64 `json:"user_ids"`
}{
UserIDs: userIDs,
})
cacheKey := string(keyRaw)
if cached, ok := dashboardBatchUsersUsageCache.Get(cacheKey); ok {
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
return
}
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), userIDs, time.Time{}, time.Time{})
if err != nil {
response.Error(c, 500, "Failed to get user usage stats")
return
}
response.Success(c, gin.H{"stats": stats})
payload := gin.H{"stats": stats}
dashboardBatchUsersUsageCache.Set(cacheKey, payload)
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, payload)
}
// BatchAPIKeysUsageRequest represents the request body for batch api key usage stats
......@@ -402,16 +517,32 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
return
}
if len(req.APIKeyIDs) == 0 {
apiKeyIDs := normalizeInt64IDList(req.APIKeyIDs)
if len(apiKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs)
keyRaw, _ := json.Marshal(struct {
APIKeyIDs []int64 `json:"api_key_ids"`
}{
APIKeyIDs: apiKeyIDs,
})
cacheKey := string(keyRaw)
if cached, ok := dashboardBatchAPIKeysUsageCache.Get(cacheKey); ok {
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
return
}
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), apiKeyIDs, time.Time{}, time.Time{})
if err != nil {
response.Error(c, 500, "Failed to get API key usage stats")
return
}
response.Success(c, gin.H{"stats": stats})
payload := gin.H{"stats": stats}
dashboardBatchAPIKeysUsageCache.Set(cacheKey, payload)
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, payload)
}
package admin
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type dashboardUsageRepoCapture struct {
service.UsageLogRepository
trendRequestType *int16
trendStream *bool
modelRequestType *int16
modelStream *bool
}
func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters(
ctx context.Context,
startTime, endTime time.Time,
granularity string,
userID, apiKeyID, accountID, groupID int64,
model string,
requestType *int16,
stream *bool,
billingType *int8,
) ([]usagestats.TrendDataPoint, error) {
s.trendRequestType = requestType
s.trendStream = stream
return []usagestats.TrendDataPoint{}, nil
}
func (s *dashboardUsageRepoCapture) GetModelStatsWithFilters(
ctx context.Context,
startTime, endTime time.Time,
userID, apiKeyID, accountID, groupID int64,
requestType *int16,
stream *bool,
billingType *int8,
) ([]usagestats.ModelStat, error) {
s.modelRequestType = requestType
s.modelStream = stream
return []usagestats.ModelStat{}, nil
}
func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine {
gin.SetMode(gin.TestMode)
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
handler := NewDashboardHandler(dashboardSvc, nil)
router := gin.New()
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
router.GET("/admin/dashboard/models", handler.GetModelStats)
return router
}
func TestDashboardTrendRequestTypePriority(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?request_type=ws_v2&stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.NotNil(t, repo.trendRequestType)
require.Equal(t, int16(service.RequestTypeWSV2), *repo.trendRequestType)
require.Nil(t, repo.trendStream)
}
func TestDashboardTrendInvalidRequestType(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?request_type=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardTrendInvalidStream(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardModelStatsRequestTypePriority(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?request_type=sync&stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.NotNil(t, repo.modelRequestType)
require.Equal(t, int16(service.RequestTypeSync), *repo.modelRequestType)
require.Nil(t, repo.modelStream)
}
func TestDashboardModelStatsInvalidRequestType(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?request_type=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardModelStatsInvalidStream(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
package admin
import (
"encoding/json"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
var dashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second)
type dashboardSnapshotV2Stats struct {
usagestats.DashboardStats
Uptime int64 `json:"uptime"`
}
type dashboardSnapshotV2Response struct {
GeneratedAt string `json:"generated_at"`
StartDate string `json:"start_date"`
EndDate string `json:"end_date"`
Granularity string `json:"granularity"`
Stats *dashboardSnapshotV2Stats `json:"stats,omitempty"`
Trend []usagestats.TrendDataPoint `json:"trend,omitempty"`
Models []usagestats.ModelStat `json:"models,omitempty"`
Groups []usagestats.GroupStat `json:"groups,omitempty"`
UsersTrend []usagestats.UserUsageTrendPoint `json:"users_trend,omitempty"`
}
type dashboardSnapshotV2Filters struct {
UserID int64
APIKeyID int64
AccountID int64
GroupID int64
Model string
RequestType *int16
Stream *bool
BillingType *int8
}
type dashboardSnapshotV2CacheKey struct {
StartTime string `json:"start_time"`
EndTime string `json:"end_time"`
Granularity string `json:"granularity"`
UserID int64 `json:"user_id"`
APIKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"`
GroupID int64 `json:"group_id"`
Model string `json:"model"`
RequestType *int16 `json:"request_type"`
Stream *bool `json:"stream"`
BillingType *int8 `json:"billing_type"`
IncludeStats bool `json:"include_stats"`
IncludeTrend bool `json:"include_trend"`
IncludeModels bool `json:"include_models"`
IncludeGroups bool `json:"include_groups"`
IncludeUsersTrend bool `json:"include_users_trend"`
UsersTrendLimit int `json:"users_trend_limit"`
}
func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := strings.TrimSpace(c.DefaultQuery("granularity", "day"))
if granularity != "hour" {
granularity = "day"
}
includeStats := parseBoolQueryWithDefault(c.Query("include_stats"), true)
includeTrend := parseBoolQueryWithDefault(c.Query("include_trend"), true)
includeModels := parseBoolQueryWithDefault(c.Query("include_model_stats"), true)
includeGroups := parseBoolQueryWithDefault(c.Query("include_group_stats"), false)
includeUsersTrend := parseBoolQueryWithDefault(c.Query("include_users_trend"), false)
usersTrendLimit := 12
if raw := strings.TrimSpace(c.Query("users_trend_limit")); raw != "" {
if parsed, err := strconv.Atoi(raw); err == nil && parsed > 0 && parsed <= 50 {
usersTrendLimit = parsed
}
}
filters, err := parseDashboardSnapshotV2Filters(c)
if err != nil {
response.BadRequest(c, err.Error())
return
}
keyRaw, _ := json.Marshal(dashboardSnapshotV2CacheKey{
StartTime: startTime.UTC().Format(time.RFC3339),
EndTime: endTime.UTC().Format(time.RFC3339),
Granularity: granularity,
UserID: filters.UserID,
APIKeyID: filters.APIKeyID,
AccountID: filters.AccountID,
GroupID: filters.GroupID,
Model: filters.Model,
RequestType: filters.RequestType,
Stream: filters.Stream,
BillingType: filters.BillingType,
IncludeStats: includeStats,
IncludeTrend: includeTrend,
IncludeModels: includeModels,
IncludeGroups: includeGroups,
IncludeUsersTrend: includeUsersTrend,
UsersTrendLimit: usersTrendLimit,
})
cacheKey := string(keyRaw)
if cached, ok := dashboardSnapshotV2Cache.Get(cacheKey); ok {
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
c.Status(http.StatusNotModified)
return
}
}
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
return
}
resp := &dashboardSnapshotV2Response{
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
StartDate: startTime.Format("2006-01-02"),
EndDate: endTime.Add(-24 * time.Hour).Format("2006-01-02"),
Granularity: granularity,
}
if includeStats {
stats, err := h.dashboardService.GetDashboardStats(c.Request.Context())
if err != nil {
response.Error(c, 500, "Failed to get dashboard statistics")
return
}
resp.Stats = &dashboardSnapshotV2Stats{
DashboardStats: *stats,
Uptime: int64(time.Since(h.startTime).Seconds()),
}
}
if includeTrend {
trend, err := h.dashboardService.GetUsageTrendWithFilters(
c.Request.Context(),
startTime,
endTime,
granularity,
filters.UserID,
filters.APIKeyID,
filters.AccountID,
filters.GroupID,
filters.Model,
filters.RequestType,
filters.Stream,
filters.BillingType,
)
if err != nil {
response.Error(c, 500, "Failed to get usage trend")
return
}
resp.Trend = trend
}
if includeModels {
models, err := h.dashboardService.GetModelStatsWithFilters(
c.Request.Context(),
startTime,
endTime,
filters.UserID,
filters.APIKeyID,
filters.AccountID,
filters.GroupID,
filters.RequestType,
filters.Stream,
filters.BillingType,
)
if err != nil {
response.Error(c, 500, "Failed to get model statistics")
return
}
resp.Models = models
}
if includeGroups {
groups, err := h.dashboardService.GetGroupStatsWithFilters(
c.Request.Context(),
startTime,
endTime,
filters.UserID,
filters.APIKeyID,
filters.AccountID,
filters.GroupID,
filters.RequestType,
filters.Stream,
filters.BillingType,
)
if err != nil {
response.Error(c, 500, "Failed to get group statistics")
return
}
resp.Groups = groups
}
if includeUsersTrend {
usersTrend, err := h.dashboardService.GetUserUsageTrend(
c.Request.Context(),
startTime,
endTime,
granularity,
usersTrendLimit,
)
if err != nil {
response.Error(c, 500, "Failed to get user usage trend")
return
}
resp.UsersTrend = usersTrend
}
cached := dashboardSnapshotV2Cache.Set(cacheKey, resp)
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
}
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, resp)
}
func parseDashboardSnapshotV2Filters(c *gin.Context) (*dashboardSnapshotV2Filters, error) {
filters := &dashboardSnapshotV2Filters{
Model: strings.TrimSpace(c.Query("model")),
}
if userIDStr := strings.TrimSpace(c.Query("user_id")); userIDStr != "" {
id, err := strconv.ParseInt(userIDStr, 10, 64)
if err != nil {
return nil, err
}
filters.UserID = id
}
if apiKeyIDStr := strings.TrimSpace(c.Query("api_key_id")); apiKeyIDStr != "" {
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
if err != nil {
return nil, err
}
filters.APIKeyID = id
}
if accountIDStr := strings.TrimSpace(c.Query("account_id")); accountIDStr != "" {
id, err := strconv.ParseInt(accountIDStr, 10, 64)
if err != nil {
return nil, err
}
filters.AccountID = id
}
if groupIDStr := strings.TrimSpace(c.Query("group_id")); groupIDStr != "" {
id, err := strconv.ParseInt(groupIDStr, 10, 64)
if err != nil {
return nil, err
}
filters.GroupID = id
}
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
return nil, err
}
value := int16(parsed)
filters.RequestType = &value
} else if streamStr := strings.TrimSpace(c.Query("stream")); streamStr != "" {
streamVal, err := strconv.ParseBool(streamStr)
if err != nil {
return nil, err
}
filters.Stream = &streamVal
}
if billingTypeStr := strings.TrimSpace(c.Query("billing_type")); billingTypeStr != "" {
v, err := strconv.ParseInt(billingTypeStr, 10, 8)
if err != nil {
return nil, err
}
bt := int8(v)
filters.BillingType = &bt
}
return filters, nil
}
package admin
import (
"context"
"strconv"
"strings"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type DataManagementHandler struct {
dataManagementService dataManagementService
}
func NewDataManagementHandler(dataManagementService *service.DataManagementService) *DataManagementHandler {
return &DataManagementHandler{dataManagementService: dataManagementService}
}
type dataManagementService interface {
GetConfig(ctx context.Context) (service.DataManagementConfig, error)
UpdateConfig(ctx context.Context, cfg service.DataManagementConfig) (service.DataManagementConfig, error)
ValidateS3(ctx context.Context, cfg service.DataManagementS3Config) (service.DataManagementTestS3Result, error)
CreateBackupJob(ctx context.Context, input service.DataManagementCreateBackupJobInput) (service.DataManagementBackupJob, error)
ListSourceProfiles(ctx context.Context, sourceType string) ([]service.DataManagementSourceProfile, error)
CreateSourceProfile(ctx context.Context, input service.DataManagementCreateSourceProfileInput) (service.DataManagementSourceProfile, error)
UpdateSourceProfile(ctx context.Context, input service.DataManagementUpdateSourceProfileInput) (service.DataManagementSourceProfile, error)
DeleteSourceProfile(ctx context.Context, sourceType, profileID string) error
SetActiveSourceProfile(ctx context.Context, sourceType, profileID string) (service.DataManagementSourceProfile, error)
ListS3Profiles(ctx context.Context) ([]service.DataManagementS3Profile, error)
CreateS3Profile(ctx context.Context, input service.DataManagementCreateS3ProfileInput) (service.DataManagementS3Profile, error)
UpdateS3Profile(ctx context.Context, input service.DataManagementUpdateS3ProfileInput) (service.DataManagementS3Profile, error)
DeleteS3Profile(ctx context.Context, profileID string) error
SetActiveS3Profile(ctx context.Context, profileID string) (service.DataManagementS3Profile, error)
ListBackupJobs(ctx context.Context, input service.DataManagementListBackupJobsInput) (service.DataManagementListBackupJobsResult, error)
GetBackupJob(ctx context.Context, jobID string) (service.DataManagementBackupJob, error)
EnsureAgentEnabled(ctx context.Context) error
GetAgentHealth(ctx context.Context) service.DataManagementAgentHealth
}
type TestS3ConnectionRequest struct {
Endpoint string `json:"endpoint"`
Region string `json:"region" binding:"required"`
Bucket string `json:"bucket" binding:"required"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
UseSSL bool `json:"use_ssl"`
}
type CreateBackupJobRequest struct {
BackupType string `json:"backup_type" binding:"required,oneof=postgres redis full"`
UploadToS3 bool `json:"upload_to_s3"`
S3ProfileID string `json:"s3_profile_id"`
PostgresID string `json:"postgres_profile_id"`
RedisID string `json:"redis_profile_id"`
IdempotencyKey string `json:"idempotency_key"`
}
type CreateSourceProfileRequest struct {
ProfileID string `json:"profile_id" binding:"required"`
Name string `json:"name" binding:"required"`
Config service.DataManagementSourceConfig `json:"config" binding:"required"`
SetActive bool `json:"set_active"`
}
type UpdateSourceProfileRequest struct {
Name string `json:"name" binding:"required"`
Config service.DataManagementSourceConfig `json:"config" binding:"required"`
}
type CreateS3ProfileRequest struct {
ProfileID string `json:"profile_id" binding:"required"`
Name string `json:"name" binding:"required"`
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
UseSSL bool `json:"use_ssl"`
SetActive bool `json:"set_active"`
}
type UpdateS3ProfileRequest struct {
Name string `json:"name" binding:"required"`
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
UseSSL bool `json:"use_ssl"`
}
func (h *DataManagementHandler) GetAgentHealth(c *gin.Context) {
health := h.getAgentHealth(c)
payload := gin.H{
"enabled": health.Enabled,
"reason": health.Reason,
"socket_path": health.SocketPath,
}
if health.Agent != nil {
payload["agent"] = gin.H{
"status": health.Agent.Status,
"version": health.Agent.Version,
"uptime_seconds": health.Agent.UptimeSeconds,
}
}
response.Success(c, payload)
}
func (h *DataManagementHandler) GetConfig(c *gin.Context) {
if !h.requireAgentEnabled(c) {
return
}
cfg, err := h.dataManagementService.GetConfig(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
func (h *DataManagementHandler) UpdateConfig(c *gin.Context) {
var req service.DataManagementConfig
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !h.requireAgentEnabled(c) {
return
}
cfg, err := h.dataManagementService.UpdateConfig(c.Request.Context(), req)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
func (h *DataManagementHandler) TestS3(c *gin.Context) {
var req TestS3ConnectionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !h.requireAgentEnabled(c) {
return
}
result, err := h.dataManagementService.ValidateS3(c.Request.Context(), service.DataManagementS3Config{
Enabled: true,
Endpoint: req.Endpoint,
Region: req.Region,
Bucket: req.Bucket,
AccessKeyID: req.AccessKeyID,
SecretAccessKey: req.SecretAccessKey,
Prefix: req.Prefix,
ForcePathStyle: req.ForcePathStyle,
UseSSL: req.UseSSL,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"ok": result.OK, "message": result.Message})
}
func (h *DataManagementHandler) CreateBackupJob(c *gin.Context) {
var req CreateBackupJobRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
req.IdempotencyKey = normalizeBackupIdempotencyKey(c.GetHeader("X-Idempotency-Key"), req.IdempotencyKey)
if !h.requireAgentEnabled(c) {
return
}
triggeredBy := "admin:unknown"
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
triggeredBy = "admin:" + strconv.FormatInt(subject.UserID, 10)
}
job, err := h.dataManagementService.CreateBackupJob(c.Request.Context(), service.DataManagementCreateBackupJobInput{
BackupType: req.BackupType,
UploadToS3: req.UploadToS3,
S3ProfileID: req.S3ProfileID,
PostgresID: req.PostgresID,
RedisID: req.RedisID,
TriggeredBy: triggeredBy,
IdempotencyKey: req.IdempotencyKey,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"job_id": job.JobID, "status": job.Status})
}
func (h *DataManagementHandler) ListSourceProfiles(c *gin.Context) {
sourceType := strings.TrimSpace(c.Param("source_type"))
if sourceType == "" {
response.BadRequest(c, "Invalid source_type")
return
}
if sourceType != "postgres" && sourceType != "redis" {
response.BadRequest(c, "source_type must be postgres or redis")
return
}
if !h.requireAgentEnabled(c) {
return
}
items, err := h.dataManagementService.ListSourceProfiles(c.Request.Context(), sourceType)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"items": items})
}
func (h *DataManagementHandler) CreateSourceProfile(c *gin.Context) {
sourceType := strings.TrimSpace(c.Param("source_type"))
if sourceType != "postgres" && sourceType != "redis" {
response.BadRequest(c, "source_type must be postgres or redis")
return
}
var req CreateSourceProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.CreateSourceProfile(c.Request.Context(), service.DataManagementCreateSourceProfileInput{
SourceType: sourceType,
ProfileID: req.ProfileID,
Name: req.Name,
Config: req.Config,
SetActive: req.SetActive,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) UpdateSourceProfile(c *gin.Context) {
sourceType := strings.TrimSpace(c.Param("source_type"))
if sourceType != "postgres" && sourceType != "redis" {
response.BadRequest(c, "source_type must be postgres or redis")
return
}
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
var req UpdateSourceProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.UpdateSourceProfile(c.Request.Context(), service.DataManagementUpdateSourceProfileInput{
SourceType: sourceType,
ProfileID: profileID,
Name: req.Name,
Config: req.Config,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) DeleteSourceProfile(c *gin.Context) {
sourceType := strings.TrimSpace(c.Param("source_type"))
if sourceType != "postgres" && sourceType != "redis" {
response.BadRequest(c, "source_type must be postgres or redis")
return
}
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
if !h.requireAgentEnabled(c) {
return
}
if err := h.dataManagementService.DeleteSourceProfile(c.Request.Context(), sourceType, profileID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"deleted": true})
}
func (h *DataManagementHandler) SetActiveSourceProfile(c *gin.Context) {
sourceType := strings.TrimSpace(c.Param("source_type"))
if sourceType != "postgres" && sourceType != "redis" {
response.BadRequest(c, "source_type must be postgres or redis")
return
}
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.SetActiveSourceProfile(c.Request.Context(), sourceType, profileID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) ListS3Profiles(c *gin.Context) {
if !h.requireAgentEnabled(c) {
return
}
items, err := h.dataManagementService.ListS3Profiles(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"items": items})
}
func (h *DataManagementHandler) CreateS3Profile(c *gin.Context) {
var req CreateS3ProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.CreateS3Profile(c.Request.Context(), service.DataManagementCreateS3ProfileInput{
ProfileID: req.ProfileID,
Name: req.Name,
SetActive: req.SetActive,
S3: service.DataManagementS3Config{
Enabled: req.Enabled,
Endpoint: req.Endpoint,
Region: req.Region,
Bucket: req.Bucket,
AccessKeyID: req.AccessKeyID,
SecretAccessKey: req.SecretAccessKey,
Prefix: req.Prefix,
ForcePathStyle: req.ForcePathStyle,
UseSSL: req.UseSSL,
},
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) UpdateS3Profile(c *gin.Context) {
var req UpdateS3ProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.UpdateS3Profile(c.Request.Context(), service.DataManagementUpdateS3ProfileInput{
ProfileID: profileID,
Name: req.Name,
S3: service.DataManagementS3Config{
Enabled: req.Enabled,
Endpoint: req.Endpoint,
Region: req.Region,
Bucket: req.Bucket,
AccessKeyID: req.AccessKeyID,
SecretAccessKey: req.SecretAccessKey,
Prefix: req.Prefix,
ForcePathStyle: req.ForcePathStyle,
UseSSL: req.UseSSL,
},
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) DeleteS3Profile(c *gin.Context) {
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
if !h.requireAgentEnabled(c) {
return
}
if err := h.dataManagementService.DeleteS3Profile(c.Request.Context(), profileID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"deleted": true})
}
func (h *DataManagementHandler) SetActiveS3Profile(c *gin.Context) {
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.SetActiveS3Profile(c.Request.Context(), profileID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) ListBackupJobs(c *gin.Context) {
if !h.requireAgentEnabled(c) {
return
}
pageSize := int32(20)
if raw := strings.TrimSpace(c.Query("page_size")); raw != "" {
v, err := strconv.Atoi(raw)
if err != nil || v <= 0 {
response.BadRequest(c, "Invalid page_size")
return
}
pageSize = int32(v)
}
result, err := h.dataManagementService.ListBackupJobs(c.Request.Context(), service.DataManagementListBackupJobsInput{
PageSize: pageSize,
PageToken: c.Query("page_token"),
Status: c.Query("status"),
BackupType: c.Query("backup_type"),
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
func (h *DataManagementHandler) GetBackupJob(c *gin.Context) {
jobID := strings.TrimSpace(c.Param("job_id"))
if jobID == "" {
response.BadRequest(c, "Invalid backup job ID")
return
}
if !h.requireAgentEnabled(c) {
return
}
job, err := h.dataManagementService.GetBackupJob(c.Request.Context(), jobID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, job)
}
func (h *DataManagementHandler) requireAgentEnabled(c *gin.Context) bool {
if h.dataManagementService == nil {
err := infraerrors.ServiceUnavailable(
service.DataManagementAgentUnavailableReason,
"data management agent service is not configured",
).WithMetadata(map[string]string{"socket_path": service.DefaultDataManagementAgentSocketPath})
response.ErrorFrom(c, err)
return false
}
if err := h.dataManagementService.EnsureAgentEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return false
}
return true
}
func (h *DataManagementHandler) getAgentHealth(c *gin.Context) service.DataManagementAgentHealth {
if h.dataManagementService == nil {
return service.DataManagementAgentHealth{
Enabled: false,
Reason: service.DataManagementAgentUnavailableReason,
SocketPath: service.DefaultDataManagementAgentSocketPath,
}
}
return h.dataManagementService.GetAgentHealth(c.Request.Context())
}
func normalizeBackupIdempotencyKey(headerValue, bodyValue string) string {
headerKey := strings.TrimSpace(headerValue)
if headerKey != "" {
return headerKey
}
return strings.TrimSpace(bodyValue)
}
package admin
import (
"encoding/json"
"net/http"
"net/http/httptest"
"path/filepath"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type apiEnvelope struct {
Code int `json:"code"`
Message string `json:"message"`
Reason string `json:"reason"`
Data json.RawMessage `json:"data"`
}
func TestDataManagementHandler_AgentHealthAlways200(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := service.NewDataManagementServiceWithOptions(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond)
h := NewDataManagementHandler(svc)
r := gin.New()
r.GET("/api/v1/admin/data-management/agent/health", h.GetAgentHealth)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/agent/health", nil)
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var envelope apiEnvelope
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &envelope))
require.Equal(t, 0, envelope.Code)
var data struct {
Enabled bool `json:"enabled"`
Reason string `json:"reason"`
SocketPath string `json:"socket_path"`
}
require.NoError(t, json.Unmarshal(envelope.Data, &data))
require.False(t, data.Enabled)
require.Equal(t, service.DataManagementDeprecatedReason, data.Reason)
require.Equal(t, svc.SocketPath(), data.SocketPath)
}
func TestDataManagementHandler_NonHealthRouteReturns503WhenDisabled(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := service.NewDataManagementServiceWithOptions(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond)
h := NewDataManagementHandler(svc)
r := gin.New()
r.GET("/api/v1/admin/data-management/config", h.GetConfig)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/config", nil)
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
var envelope apiEnvelope
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &envelope))
require.Equal(t, http.StatusServiceUnavailable, envelope.Code)
require.Equal(t, service.DataManagementDeprecatedReason, envelope.Reason)
}
func TestNormalizeBackupIdempotencyKey(t *testing.T) {
require.Equal(t, "from-header", normalizeBackupIdempotencyKey("from-header", "from-body"))
require.Equal(t, "from-body", normalizeBackupIdempotencyKey(" ", " from-body "))
require.Equal(t, "", normalizeBackupIdempotencyKey("", ""))
}
......@@ -61,7 +61,11 @@ func (h *GeminiOAuthHandler) GenerateAuthURL(c *gin.Context) {
if err != nil {
msg := err.Error()
// Treat missing/invalid OAuth client configuration as a user/config error.
if strings.Contains(msg, "OAuth client not configured") || strings.Contains(msg, "requires your own OAuth Client") {
if strings.Contains(msg, "OAuth client not configured") ||
strings.Contains(msg, "requires your own OAuth Client") ||
strings.Contains(msg, "requires a custom OAuth Client") ||
strings.Contains(msg, "GEMINI_CLI_OAUTH_CLIENT_SECRET_MISSING") ||
strings.Contains(msg, "built-in Gemini CLI OAuth client_secret is not configured") {
response.BadRequest(c, "Failed to generate auth URL: "+msg)
return
}
......
......@@ -27,7 +27,7 @@ func NewGroupHandler(adminService service.AdminService) *GroupHandler {
type CreateGroupRequest struct {
Name string `json:"name" binding:"required"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
......@@ -38,6 +38,10 @@ type CreateGroupRequest struct {
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
......@@ -47,6 +51,8 @@ type CreateGroupRequest struct {
MCPXMLInject *bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"`
// Sora 存储配额
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
// 从指定分组复制账号(创建后自动绑定)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
......@@ -55,7 +61,7 @@ type CreateGroupRequest struct {
type UpdateGroupRequest struct {
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
RateMultiplier *float64 `json:"rate_multiplier"`
IsExclusive *bool `json:"is_exclusive"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
......@@ -67,6 +73,10 @@ type UpdateGroupRequest struct {
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
ClaudeCodeOnly *bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
......@@ -76,6 +86,8 @@ type UpdateGroupRequest struct {
MCPXMLInject *bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string `json:"supported_model_scopes"`
// Sora 存储配额
SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
......@@ -179,6 +191,10 @@ func (h *GroupHandler) Create(c *gin.Context) {
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
SoraImagePrice360: req.SoraImagePrice360,
SoraImagePrice540: req.SoraImagePrice540,
SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
......@@ -186,6 +202,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
ModelRoutingEnabled: req.ModelRoutingEnabled,
MCPXMLInject: req.MCPXMLInject,
SupportedModelScopes: req.SupportedModelScopes,
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {
......@@ -225,6 +242,10 @@ func (h *GroupHandler) Update(c *gin.Context) {
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
SoraImagePrice360: req.SoraImagePrice360,
SoraImagePrice540: req.SoraImagePrice540,
SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
......@@ -232,6 +253,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
ModelRoutingEnabled: req.ModelRoutingEnabled,
MCPXMLInject: req.MCPXMLInject,
SupportedModelScopes: req.SupportedModelScopes,
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {
......
package admin
import "sort"
func normalizeInt64IDList(ids []int64) []int64 {
if len(ids) == 0 {
return nil
}
out := make([]int64, 0, len(ids))
seen := make(map[int64]struct{}, len(ids))
for _, id := range ids {
if id <= 0 {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
out = append(out, id)
}
sort.Slice(out, func(i, j int) bool { return out[i] < out[j] })
return out
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment