"backend/internal/git@web.lueluesay.top:chenxi/sub2api.git" did not exist on "c4615a122436a65741dc148a107b682f266d493c"
Unverified Commit 6bccb8a8 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge branch 'main' into feature/antigravity-user-agent-configurable

parents 1fc6ef3d 3de1e0e4
//go:build unit
package admin
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// failingAdminService 嵌入 stubAdminService,可配置 UpdateAccount 在指定 ID 时失败。
type failingAdminService struct {
*stubAdminService
failOnAccountID int64
updateCallCount atomic.Int64
}
func (f *failingAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
f.updateCallCount.Add(1)
if id == f.failOnAccountID {
return nil, errors.New("database error")
}
return f.stubAdminService.UpdateAccount(ctx, id, input)
}
func setupAccountHandlerWithService(adminSvc service.AdminService) (*gin.Engine, *AccountHandler) {
gin.SetMode(gin.TestMode)
router := gin.New()
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
router.POST("/api/v1/admin/accounts/batch-update-credentials", handler.BatchUpdateCredentials)
return router, handler
}
func TestBatchUpdateCredentials_AllSuccess(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
AccountIDs: []int64{1, 2, 3},
Field: "account_uuid",
Value: "test-uuid",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code, "全部成功时应返回 200")
require.Equal(t, int64(3), svc.updateCallCount.Load(), "应调用 3 次 UpdateAccount")
}
func TestBatchUpdateCredentials_PartialFailure(t *testing.T) {
// 让第 2 个账号(ID=2)更新时失败
svc := &failingAdminService{
stubAdminService: newStubAdminService(),
failOnAccountID: 2,
}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
AccountIDs: []int64{1, 2, 3},
Field: "org_uuid",
Value: "test-org",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
// 实现采用"部分成功"模式:总是返回 200 + 成功/失败明细
require.Equal(t, http.StatusOK, w.Code, "批量更新返回 200 + 成功/失败明细")
var resp map[string]any
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
data := resp["data"].(map[string]any)
require.Equal(t, float64(2), data["success"], "应有 2 个成功")
require.Equal(t, float64(1), data["failed"], "应有 1 个失败")
// 所有 3 个账号都会被尝试更新(非 fail-fast)
require.Equal(t, int64(3), svc.updateCallCount.Load(),
"应调用 3 次 UpdateAccount(逐个尝试,失败后继续)")
}
func TestBatchUpdateCredentials_FirstAccountNotFound(t *testing.T) {
// GetAccount 在 stubAdminService 中总是成功的,需要创建一个 GetAccount 会失败的 stub
svc := &getAccountFailingService{
stubAdminService: newStubAdminService(),
failOnAccountID: 1,
}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
AccountIDs: []int64{1, 2, 3},
Field: "account_uuid",
Value: "test",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusNotFound, w.Code, "第一阶段验证失败应返回 404")
}
// getAccountFailingService 模拟 GetAccount 在特定 ID 时返回 not found。
type getAccountFailingService struct {
*stubAdminService
failOnAccountID int64
}
func (f *getAccountFailingService) GetAccount(ctx context.Context, id int64) (*service.Account, error) {
if id == f.failOnAccountID {
return nil, errors.New("not found")
}
return f.stubAdminService.GetAccount(ctx, id)
}
func TestBatchUpdateCredentials_InterceptWarmupRequests_NonBool(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
// intercept_warmup_requests 传入非 bool 类型(string),应返回 400
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "intercept_warmup_requests",
"value": "not-a-bool",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusBadRequest, w.Code,
"intercept_warmup_requests 传入非 bool 值应返回 400")
}
func TestBatchUpdateCredentials_InterceptWarmupRequests_ValidBool(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "intercept_warmup_requests",
"value": true,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code,
"intercept_warmup_requests 传入合法 bool 值应返回 200")
}
func TestBatchUpdateCredentials_AccountUUID_NonString(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
// account_uuid 传入非 string 类型(number),应返回 400
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "account_uuid",
"value": 12345,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusBadRequest, w.Code,
"account_uuid 传入非 string 值应返回 400")
}
func TestBatchUpdateCredentials_AccountUUID_NullValue(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
// account_uuid 传入 null(设置为空),应正常通过
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "account_uuid",
"value": nil,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code,
"account_uuid 传入 null 应返回 200")
}
...@@ -379,7 +379,7 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) { ...@@ -379,7 +379,7 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
return return
} }
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs) stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs, time.Time{}, time.Time{})
if err != nil { if err != nil {
response.Error(c, 500, "Failed to get user usage stats") response.Error(c, 500, "Failed to get user usage stats")
return return
...@@ -407,7 +407,7 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) { ...@@ -407,7 +407,7 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
return return
} }
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs) stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs, time.Time{}, time.Time{})
if err != nil { if err != nil {
response.Error(c, 500, "Failed to get API key usage stats") response.Error(c, 500, "Failed to get API key usage stats")
return return
......
...@@ -27,7 +27,7 @@ func NewGroupHandler(adminService service.AdminService) *GroupHandler { ...@@ -27,7 +27,7 @@ func NewGroupHandler(adminService service.AdminService) *GroupHandler {
type CreateGroupRequest struct { type CreateGroupRequest struct {
Name string `json:"name" binding:"required"` Name string `json:"name" binding:"required"`
Description string `json:"description"` Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"` Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
RateMultiplier float64 `json:"rate_multiplier"` RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"` IsExclusive bool `json:"is_exclusive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"` SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
...@@ -38,6 +38,10 @@ type CreateGroupRequest struct { ...@@ -38,6 +38,10 @@ type CreateGroupRequest struct {
ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"` ImagePrice4K *float64 `json:"image_price_4k"`
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
ClaudeCodeOnly bool `json:"claude_code_only"` ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"` FallbackGroupID *int64 `json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
...@@ -55,7 +59,7 @@ type CreateGroupRequest struct { ...@@ -55,7 +59,7 @@ type CreateGroupRequest struct {
type UpdateGroupRequest struct { type UpdateGroupRequest struct {
Name string `json:"name"` Name string `json:"name"`
Description string `json:"description"` Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"` Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
RateMultiplier *float64 `json:"rate_multiplier"` RateMultiplier *float64 `json:"rate_multiplier"`
IsExclusive *bool `json:"is_exclusive"` IsExclusive *bool `json:"is_exclusive"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"` Status string `json:"status" binding:"omitempty,oneof=active inactive"`
...@@ -67,6 +71,10 @@ type UpdateGroupRequest struct { ...@@ -67,6 +71,10 @@ type UpdateGroupRequest struct {
ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"` ImagePrice4K *float64 `json:"image_price_4k"`
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
ClaudeCodeOnly *bool `json:"claude_code_only"` ClaudeCodeOnly *bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"` FallbackGroupID *int64 `json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
...@@ -179,6 +187,10 @@ func (h *GroupHandler) Create(c *gin.Context) { ...@@ -179,6 +187,10 @@ func (h *GroupHandler) Create(c *gin.Context) {
ImagePrice1K: req.ImagePrice1K, ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K, ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K, ImagePrice4K: req.ImagePrice4K,
SoraImagePrice360: req.SoraImagePrice360,
SoraImagePrice540: req.SoraImagePrice540,
SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: req.ClaudeCodeOnly, ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID, FallbackGroupID: req.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest, FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
...@@ -225,6 +237,10 @@ func (h *GroupHandler) Update(c *gin.Context) { ...@@ -225,6 +237,10 @@ func (h *GroupHandler) Update(c *gin.Context) {
ImagePrice1K: req.ImagePrice1K, ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K, ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K, ImagePrice4K: req.ImagePrice4K,
SoraImagePrice360: req.SoraImagePrice360,
SoraImagePrice540: req.SoraImagePrice540,
SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: req.ClaudeCodeOnly, ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID, FallbackGroupID: req.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest, FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
......
package admin
import (
"context"
"strconv"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type idempotencyStoreUnavailableMode int
const (
idempotencyStoreUnavailableFailClose idempotencyStoreUnavailableMode = iota
idempotencyStoreUnavailableFailOpen
)
func executeAdminIdempotent(
c *gin.Context,
scope string,
payload any,
ttl time.Duration,
execute func(context.Context) (any, error),
) (*service.IdempotencyExecuteResult, error) {
coordinator := service.DefaultIdempotencyCoordinator()
if coordinator == nil {
data, err := execute(c.Request.Context())
if err != nil {
return nil, err
}
return &service.IdempotencyExecuteResult{Data: data}, nil
}
actorScope := "admin:0"
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
actorScope = "admin:" + strconv.FormatInt(subject.UserID, 10)
}
return coordinator.Execute(c.Request.Context(), service.IdempotencyExecuteOptions{
Scope: scope,
ActorScope: actorScope,
Method: c.Request.Method,
Route: c.FullPath(),
IdempotencyKey: c.GetHeader("Idempotency-Key"),
Payload: payload,
RequireKey: true,
TTL: ttl,
}, execute)
}
func executeAdminIdempotentJSON(
c *gin.Context,
scope string,
payload any,
ttl time.Duration,
execute func(context.Context) (any, error),
) {
executeAdminIdempotentJSONWithMode(c, scope, payload, ttl, idempotencyStoreUnavailableFailClose, execute)
}
func executeAdminIdempotentJSONFailOpenOnStoreUnavailable(
c *gin.Context,
scope string,
payload any,
ttl time.Duration,
execute func(context.Context) (any, error),
) {
executeAdminIdempotentJSONWithMode(c, scope, payload, ttl, idempotencyStoreUnavailableFailOpen, execute)
}
func executeAdminIdempotentJSONWithMode(
c *gin.Context,
scope string,
payload any,
ttl time.Duration,
mode idempotencyStoreUnavailableMode,
execute func(context.Context) (any, error),
) {
result, err := executeAdminIdempotent(c, scope, payload, ttl, execute)
if err != nil {
if infraerrors.Code(err) == infraerrors.Code(service.ErrIdempotencyStoreUnavail) {
strategy := "fail_close"
if mode == idempotencyStoreUnavailableFailOpen {
strategy = "fail_open"
}
service.RecordIdempotencyStoreUnavailable(c.FullPath(), scope, "handler_"+strategy)
logger.LegacyPrintf("handler.idempotency", "[Idempotency] store unavailable: method=%s route=%s scope=%s strategy=%s", c.Request.Method, c.FullPath(), scope, strategy)
if mode == idempotencyStoreUnavailableFailOpen {
data, fallbackErr := execute(c.Request.Context())
if fallbackErr != nil {
response.ErrorFrom(c, fallbackErr)
return
}
c.Header("X-Idempotency-Degraded", "store-unavailable")
response.Success(c, data)
return
}
}
if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
response.ErrorFrom(c, err)
return
}
if result != nil && result.Replayed {
c.Header("X-Idempotency-Replayed", "true")
}
response.Success(c, result.Data)
}
package admin
import (
"bytes"
"context"
"errors"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type storeUnavailableRepoStub struct{}
func (storeUnavailableRepoStub) CreateProcessing(context.Context, *service.IdempotencyRecord) (bool, error) {
return false, errors.New("store unavailable")
}
func (storeUnavailableRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*service.IdempotencyRecord, error) {
return nil, errors.New("store unavailable")
}
func (storeUnavailableRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
return false, errors.New("store unavailable")
}
func (storeUnavailableRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
return false, errors.New("store unavailable")
}
func (storeUnavailableRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
return errors.New("store unavailable")
}
func (storeUnavailableRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
return errors.New("store unavailable")
}
func (storeUnavailableRepoStub) DeleteExpired(context.Context, time.Time, int) (int64, error) {
return 0, errors.New("store unavailable")
}
func TestExecuteAdminIdempotentJSONFailCloseOnStoreUnavailable(t *testing.T) {
gin.SetMode(gin.TestMode)
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(storeUnavailableRepoStub{}, service.DefaultIdempotencyConfig()))
t.Cleanup(func() {
service.SetDefaultIdempotencyCoordinator(nil)
})
var executed int
router := gin.New()
router.POST("/idempotent", func(c *gin.Context) {
executeAdminIdempotentJSON(c, "admin.test.high", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
executed++
return gin.H{"ok": true}, nil
})
})
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Idempotency-Key", "test-key-1")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
require.Equal(t, 0, executed, "fail-close should block business execution when idempotency store is unavailable")
}
func TestExecuteAdminIdempotentJSONFailOpenOnStoreUnavailable(t *testing.T) {
gin.SetMode(gin.TestMode)
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(storeUnavailableRepoStub{}, service.DefaultIdempotencyConfig()))
t.Cleanup(func() {
service.SetDefaultIdempotencyCoordinator(nil)
})
var executed int
router := gin.New()
router.POST("/idempotent", func(c *gin.Context) {
executeAdminIdempotentJSONFailOpenOnStoreUnavailable(c, "admin.test.medium", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
executed++
return gin.H{"ok": true}, nil
})
})
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Idempotency-Key", "test-key-2")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "store-unavailable", rec.Header().Get("X-Idempotency-Degraded"))
require.Equal(t, 1, executed, "fail-open strategy should allow semantic idempotent path to continue")
}
type memoryIdempotencyRepoStub struct {
mu sync.Mutex
nextID int64
data map[string]*service.IdempotencyRecord
}
func newMemoryIdempotencyRepoStub() *memoryIdempotencyRepoStub {
return &memoryIdempotencyRepoStub{
nextID: 1,
data: make(map[string]*service.IdempotencyRecord),
}
}
func (r *memoryIdempotencyRepoStub) key(scope, keyHash string) string {
return scope + "|" + keyHash
}
func (r *memoryIdempotencyRepoStub) clone(in *service.IdempotencyRecord) *service.IdempotencyRecord {
if in == nil {
return nil
}
out := *in
if in.LockedUntil != nil {
v := *in.LockedUntil
out.LockedUntil = &v
}
if in.ResponseBody != nil {
v := *in.ResponseBody
out.ResponseBody = &v
}
if in.ResponseStatus != nil {
v := *in.ResponseStatus
out.ResponseStatus = &v
}
if in.ErrorReason != nil {
v := *in.ErrorReason
out.ErrorReason = &v
}
return &out
}
func (r *memoryIdempotencyRepoStub) CreateProcessing(_ context.Context, record *service.IdempotencyRecord) (bool, error) {
r.mu.Lock()
defer r.mu.Unlock()
k := r.key(record.Scope, record.IdempotencyKeyHash)
if _, ok := r.data[k]; ok {
return false, nil
}
cp := r.clone(record)
cp.ID = r.nextID
r.nextID++
r.data[k] = cp
record.ID = cp.ID
return true, nil
}
func (r *memoryIdempotencyRepoStub) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) {
r.mu.Lock()
defer r.mu.Unlock()
return r.clone(r.data[r.key(scope, keyHash)]), nil
}
func (r *memoryIdempotencyRepoStub) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
if rec.Status != fromStatus {
return false, nil
}
if rec.LockedUntil != nil && rec.LockedUntil.After(now) {
return false, nil
}
rec.Status = service.IdempotencyStatusProcessing
rec.LockedUntil = &newLockedUntil
rec.ExpiresAt = newExpiresAt
rec.ErrorReason = nil
return true, nil
}
return false, nil
}
func (r *memoryIdempotencyRepoStub) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
if rec.Status != service.IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint {
return false, nil
}
rec.LockedUntil = &newLockedUntil
rec.ExpiresAt = newExpiresAt
return true, nil
}
return false, nil
}
func (r *memoryIdempotencyRepoStub) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
rec.Status = service.IdempotencyStatusSucceeded
rec.LockedUntil = nil
rec.ExpiresAt = expiresAt
rec.ResponseStatus = &responseStatus
rec.ResponseBody = &responseBody
rec.ErrorReason = nil
return nil
}
return nil
}
func (r *memoryIdempotencyRepoStub) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
rec.Status = service.IdempotencyStatusFailedRetryable
rec.LockedUntil = &lockedUntil
rec.ExpiresAt = expiresAt
rec.ErrorReason = &errorReason
return nil
}
return nil
}
func (r *memoryIdempotencyRepoStub) DeleteExpired(_ context.Context, _ time.Time, _ int) (int64, error) {
return 0, nil
}
func TestExecuteAdminIdempotentJSONConcurrentRetryOnlyOneSideEffect(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := newMemoryIdempotencyRepoStub()
cfg := service.DefaultIdempotencyConfig()
cfg.ProcessingTimeout = 2 * time.Second
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(repo, cfg))
t.Cleanup(func() {
service.SetDefaultIdempotencyCoordinator(nil)
})
var executed atomic.Int32
router := gin.New()
router.POST("/idempotent", func(c *gin.Context) {
executeAdminIdempotentJSON(c, "admin.test.concurrent", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
executed.Add(1)
time.Sleep(120 * time.Millisecond)
return gin.H{"ok": true}, nil
})
})
call := func() (int, http.Header) {
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Idempotency-Key", "same-key")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
return rec.Code, rec.Header()
}
var status1, status2 int
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
status1, _ = call()
}()
go func() {
defer wg.Done()
status2, _ = call()
}()
wg.Wait()
require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status1)
require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status2)
require.Equal(t, int32(1), executed.Load(), "same idempotency key should execute side-effect only once")
status3, headers3 := call()
require.Equal(t, http.StatusOK, status3)
require.Equal(t, "true", headers3.Get("X-Idempotency-Replayed"))
require.Equal(t, int32(1), executed.Load())
}
...@@ -2,6 +2,7 @@ package admin ...@@ -2,6 +2,7 @@ package admin
import ( import (
"strconv" "strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
...@@ -16,6 +17,13 @@ type OpenAIOAuthHandler struct { ...@@ -16,6 +17,13 @@ type OpenAIOAuthHandler struct {
adminService service.AdminService adminService service.AdminService
} }
func oauthPlatformFromPath(c *gin.Context) string {
if strings.Contains(c.FullPath(), "/admin/sora/") {
return service.PlatformSora
}
return service.PlatformOpenAI
}
// NewOpenAIOAuthHandler creates a new OpenAI OAuth handler // NewOpenAIOAuthHandler creates a new OpenAI OAuth handler
func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler { func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler {
return &OpenAIOAuthHandler{ return &OpenAIOAuthHandler{
...@@ -52,6 +60,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) { ...@@ -52,6 +60,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
type OpenAIExchangeCodeRequest struct { type OpenAIExchangeCodeRequest struct {
SessionID string `json:"session_id" binding:"required"` SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"` Code string `json:"code" binding:"required"`
State string `json:"state" binding:"required"`
RedirectURI string `json:"redirect_uri"` RedirectURI string `json:"redirect_uri"`
ProxyID *int64 `json:"proxy_id"` ProxyID *int64 `json:"proxy_id"`
} }
...@@ -68,6 +77,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) { ...@@ -68,6 +77,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{ tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
SessionID: req.SessionID, SessionID: req.SessionID,
Code: req.Code, Code: req.Code,
State: req.State,
RedirectURI: req.RedirectURI, RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID, ProxyID: req.ProxyID,
}) })
...@@ -81,18 +91,29 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) { ...@@ -81,18 +91,29 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
// OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token // OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token
type OpenAIRefreshTokenRequest struct { type OpenAIRefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" binding:"required"` RefreshToken string `json:"refresh_token"`
RT string `json:"rt"`
ClientID string `json:"client_id"`
ProxyID *int64 `json:"proxy_id"` ProxyID *int64 `json:"proxy_id"`
} }
// RefreshToken refreshes an OpenAI OAuth token // RefreshToken refreshes an OpenAI OAuth token
// POST /api/v1/admin/openai/refresh-token // POST /api/v1/admin/openai/refresh-token
// POST /api/v1/admin/sora/rt2at
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
var req OpenAIRefreshTokenRequest var req OpenAIRefreshTokenRequest
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error()) response.BadRequest(c, "Invalid request: "+err.Error())
return return
} }
refreshToken := strings.TrimSpace(req.RefreshToken)
if refreshToken == "" {
refreshToken = strings.TrimSpace(req.RT)
}
if refreshToken == "" {
response.BadRequest(c, "refresh_token is required")
return
}
var proxyURL string var proxyURL string
if req.ProxyID != nil { if req.ProxyID != nil {
...@@ -102,7 +123,7 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { ...@@ -102,7 +123,7 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
} }
} }
tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL) tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, strings.TrimSpace(req.ClientID))
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -111,8 +132,39 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { ...@@ -111,8 +132,39 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
response.Success(c, tokenInfo) response.Success(c, tokenInfo)
} }
// RefreshAccountToken refreshes token for a specific OpenAI account // ExchangeSoraSessionToken exchanges Sora session token to access token
// POST /api/v1/admin/sora/st2at
func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) {
var req struct {
SessionToken string `json:"session_token"`
ST string `json:"st"`
ProxyID *int64 `json:"proxy_id"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
sessionToken := strings.TrimSpace(req.SessionToken)
if sessionToken == "" {
sessionToken = strings.TrimSpace(req.ST)
}
if sessionToken == "" {
response.BadRequest(c, "session_token is required")
return
}
tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, tokenInfo)
}
// RefreshAccountToken refreshes token for a specific OpenAI/Sora account
// POST /api/v1/admin/openai/accounts/:id/refresh // POST /api/v1/admin/openai/accounts/:id/refresh
// POST /api/v1/admin/sora/accounts/:id/refresh
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil { if err != nil {
...@@ -127,9 +179,9 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { ...@@ -127,9 +179,9 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
return return
} }
// Ensure account is OpenAI platform platform := oauthPlatformFromPath(c)
if !account.IsOpenAI() { if account.Platform != platform {
response.BadRequest(c, "Account is not an OpenAI account") response.BadRequest(c, "Account platform does not match OAuth endpoint")
return return
} }
...@@ -167,12 +219,14 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { ...@@ -167,12 +219,14 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
response.Success(c, dto.AccountFromService(updatedAccount)) response.Success(c, dto.AccountFromService(updatedAccount))
} }
// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info // CreateAccountFromOAuth creates a new OpenAI/Sora OAuth account from token info
// POST /api/v1/admin/openai/create-from-oauth // POST /api/v1/admin/openai/create-from-oauth
// POST /api/v1/admin/sora/create-from-oauth
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
var req struct { var req struct {
SessionID string `json:"session_id" binding:"required"` SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"` Code string `json:"code" binding:"required"`
State string `json:"state" binding:"required"`
RedirectURI string `json:"redirect_uri"` RedirectURI string `json:"redirect_uri"`
ProxyID *int64 `json:"proxy_id"` ProxyID *int64 `json:"proxy_id"`
Name string `json:"name"` Name string `json:"name"`
...@@ -189,6 +243,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { ...@@ -189,6 +243,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{ tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
SessionID: req.SessionID, SessionID: req.SessionID,
Code: req.Code, Code: req.Code,
State: req.State,
RedirectURI: req.RedirectURI, RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID, ProxyID: req.ProxyID,
}) })
...@@ -200,19 +255,25 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { ...@@ -200,19 +255,25 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
// Build credentials from token info // Build credentials from token info
credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo) credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
platform := oauthPlatformFromPath(c)
// Use email as default name if not provided // Use email as default name if not provided
name := req.Name name := req.Name
if name == "" && tokenInfo.Email != "" { if name == "" && tokenInfo.Email != "" {
name = tokenInfo.Email name = tokenInfo.Email
} }
if name == "" { if name == "" {
name = "OpenAI OAuth Account" if platform == service.PlatformSora {
name = "Sora OAuth Account"
} else {
name = "OpenAI OAuth Account"
}
} }
// Create account // Create account
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{ account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
Name: name, Name: name,
Platform: "openai", Platform: platform,
Type: "oauth", Type: "oauth",
Credentials: credentials, Credentials: credentials,
ProxyID: req.ProxyID, ProxyID: req.ProxyID,
......
package admin package admin
import ( import (
"fmt"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
...@@ -218,6 +219,115 @@ func (h *OpsHandler) GetDashboardErrorDistribution(c *gin.Context) { ...@@ -218,6 +219,115 @@ func (h *OpsHandler) GetDashboardErrorDistribution(c *gin.Context) {
response.Success(c, data) response.Success(c, data)
} }
// GetDashboardOpenAITokenStats returns OpenAI token efficiency stats grouped by model.
// GET /api/v1/admin/ops/dashboard/openai-token-stats
func (h *OpsHandler) GetDashboardOpenAITokenStats(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
filter, err := parseOpsOpenAITokenStatsFilter(c)
if err != nil {
response.BadRequest(c, err.Error())
return
}
data, err := h.opsService.GetOpenAITokenStats(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, data)
}
func parseOpsOpenAITokenStatsFilter(c *gin.Context) (*service.OpsOpenAITokenStatsFilter, error) {
if c == nil {
return nil, fmt.Errorf("invalid request")
}
timeRange := strings.TrimSpace(c.Query("time_range"))
if timeRange == "" {
timeRange = "30d"
}
dur, ok := parseOpsOpenAITokenStatsDuration(timeRange)
if !ok {
return nil, fmt.Errorf("invalid time_range")
}
end := time.Now().UTC()
start := end.Add(-dur)
filter := &service.OpsOpenAITokenStatsFilter{
TimeRange: timeRange,
StartTime: start,
EndTime: end,
Platform: strings.TrimSpace(c.Query("platform")),
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
return nil, fmt.Errorf("invalid group_id")
}
filter.GroupID = &id
}
topNRaw := strings.TrimSpace(c.Query("top_n"))
pageRaw := strings.TrimSpace(c.Query("page"))
pageSizeRaw := strings.TrimSpace(c.Query("page_size"))
if topNRaw != "" && (pageRaw != "" || pageSizeRaw != "") {
return nil, fmt.Errorf("invalid query: top_n cannot be used with page/page_size")
}
if topNRaw != "" {
topN, err := strconv.Atoi(topNRaw)
if err != nil || topN < 1 || topN > 100 {
return nil, fmt.Errorf("invalid top_n")
}
filter.TopN = topN
return filter, nil
}
filter.Page = 1
filter.PageSize = 20
if pageRaw != "" {
page, err := strconv.Atoi(pageRaw)
if err != nil || page < 1 {
return nil, fmt.Errorf("invalid page")
}
filter.Page = page
}
if pageSizeRaw != "" {
pageSize, err := strconv.Atoi(pageSizeRaw)
if err != nil || pageSize < 1 || pageSize > 100 {
return nil, fmt.Errorf("invalid page_size")
}
filter.PageSize = pageSize
}
return filter, nil
}
func parseOpsOpenAITokenStatsDuration(v string) (time.Duration, bool) {
switch strings.TrimSpace(v) {
case "30m":
return 30 * time.Minute, true
case "1h":
return time.Hour, true
case "1d":
return 24 * time.Hour, true
case "15d":
return 15 * 24 * time.Hour, true
case "30d":
return 30 * 24 * time.Hour, true
default:
return 0, false
}
}
func pickThroughputBucketSeconds(window time.Duration) int { func pickThroughputBucketSeconds(window time.Duration) int {
// Keep buckets predictable and avoid huge responses. // Keep buckets predictable and avoid huge responses.
switch { switch {
......
package admin
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type testSettingRepo struct {
values map[string]string
}
func newTestSettingRepo() *testSettingRepo {
return &testSettingRepo{values: map[string]string{}}
}
func (s *testSettingRepo) Get(ctx context.Context, key string) (*service.Setting, error) {
v, err := s.GetValue(ctx, key)
if err != nil {
return nil, err
}
return &service.Setting{Key: key, Value: v}, nil
}
func (s *testSettingRepo) GetValue(ctx context.Context, key string) (string, error) {
v, ok := s.values[key]
if !ok {
return "", service.ErrSettingNotFound
}
return v, nil
}
func (s *testSettingRepo) Set(ctx context.Context, key, value string) error {
s.values[key] = value
return nil
}
func (s *testSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, k := range keys {
if v, ok := s.values[k]; ok {
out[k] = v
}
}
return out, nil
}
func (s *testSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error {
for k, v := range settings {
s.values[k] = v
}
return nil
}
func (s *testSettingRepo) GetAll(ctx context.Context) (map[string]string, error) {
out := make(map[string]string, len(s.values))
for k, v := range s.values {
out[k] = v
}
return out, nil
}
func (s *testSettingRepo) Delete(ctx context.Context, key string) error {
delete(s.values, key)
return nil
}
func newOpsRuntimeRouter(handler *OpsHandler, withUser bool) *gin.Engine {
gin.SetMode(gin.TestMode)
r := gin.New()
if withUser {
r.Use(func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 7})
c.Next()
})
}
r.GET("/runtime/logging", handler.GetRuntimeLogConfig)
r.PUT("/runtime/logging", handler.UpdateRuntimeLogConfig)
r.POST("/runtime/logging/reset", handler.ResetRuntimeLogConfig)
return r
}
func newRuntimeOpsService(t *testing.T) *service.OpsService {
t.Helper()
if err := logger.Init(logger.InitOptions{
Level: "info",
Format: "json",
ServiceName: "sub2api",
Environment: "test",
Output: logger.OutputOptions{
ToStdout: false,
ToFile: false,
},
}); err != nil {
t.Fatalf("init logger: %v", err)
}
settingRepo := newTestSettingRepo()
cfg := &config.Config{
Ops: config.OpsConfig{Enabled: true},
Log: config.LogConfig{
Level: "info",
Caller: true,
StacktraceLevel: "error",
Sampling: config.LogSamplingConfig{
Enabled: false,
Initial: 100,
Thereafter: 100,
},
},
}
return service.NewOpsService(nil, settingRepo, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
}
func TestOpsRuntimeLoggingHandler_GetConfig(t *testing.T) {
h := NewOpsHandler(newRuntimeOpsService(t))
r := newOpsRuntimeRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/runtime/logging", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status=%d, want 200", w.Code)
}
}
func TestOpsRuntimeLoggingHandler_UpdateUnauthorized(t *testing.T) {
h := NewOpsHandler(newRuntimeOpsService(t))
r := newOpsRuntimeRouter(h, false)
body := `{"level":"debug","enable_sampling":false,"sampling_initial":100,"sampling_thereafter":100,"caller":true,"stacktrace_level":"error","retention_days":30}`
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/runtime/logging", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("status=%d, want 401", w.Code)
}
}
func TestOpsRuntimeLoggingHandler_UpdateAndResetSuccess(t *testing.T) {
h := NewOpsHandler(newRuntimeOpsService(t))
r := newOpsRuntimeRouter(h, true)
payload := map[string]any{
"level": "debug",
"enable_sampling": false,
"sampling_initial": 100,
"sampling_thereafter": 100,
"caller": true,
"stacktrace_level": "error",
"retention_days": 30,
}
raw, _ := json.Marshal(payload)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/runtime/logging", bytes.NewReader(raw))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("update status=%d, want 200, body=%s", w.Code, w.Body.String())
}
w = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/runtime/logging/reset", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("reset status=%d, want 200, body=%s", w.Code, w.Body.String())
}
}
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"net/http" "net/http"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
...@@ -101,6 +102,84 @@ func (h *OpsHandler) UpdateAlertRuntimeSettings(c *gin.Context) { ...@@ -101,6 +102,84 @@ func (h *OpsHandler) UpdateAlertRuntimeSettings(c *gin.Context) {
response.Success(c, updated) response.Success(c, updated)
} }
// GetRuntimeLogConfig returns runtime log config (DB-backed).
// GET /api/v1/admin/ops/runtime/logging
func (h *OpsHandler) GetRuntimeLogConfig(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
cfg, err := h.opsService.GetRuntimeLogConfig(c.Request.Context())
if err != nil {
response.Error(c, http.StatusInternalServerError, "Failed to get runtime log config")
return
}
response.Success(c, cfg)
}
// UpdateRuntimeLogConfig updates runtime log config and applies changes immediately.
// PUT /api/v1/admin/ops/runtime/logging
func (h *OpsHandler) UpdateRuntimeLogConfig(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
var req service.OpsRuntimeLogConfig
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
updated, err := h.opsService.UpdateRuntimeLogConfig(c.Request.Context(), &req, subject.UserID)
if err != nil {
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, updated)
}
// ResetRuntimeLogConfig removes runtime override and falls back to env/yaml baseline.
// POST /api/v1/admin/ops/runtime/logging/reset
func (h *OpsHandler) ResetRuntimeLogConfig(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
updated, err := h.opsService.ResetRuntimeLogConfig(c.Request.Context(), subject.UserID)
if err != nil {
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, updated)
}
// GetAdvancedSettings returns Ops advanced settings (DB-backed). // GetAdvancedSettings returns Ops advanced settings (DB-backed).
// GET /api/v1/admin/ops/advanced-settings // GET /api/v1/admin/ops/advanced-settings
func (h *OpsHandler) GetAdvancedSettings(c *gin.Context) { func (h *OpsHandler) GetAdvancedSettings(c *gin.Context) {
......
package admin
import (
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type opsSystemLogCleanupRequest struct {
StartTime string `json:"start_time"`
EndTime string `json:"end_time"`
Level string `json:"level"`
Component string `json:"component"`
RequestID string `json:"request_id"`
ClientRequestID string `json:"client_request_id"`
UserID *int64 `json:"user_id"`
AccountID *int64 `json:"account_id"`
Platform string `json:"platform"`
Model string `json:"model"`
Query string `json:"q"`
}
// ListSystemLogs returns indexed system logs.
// GET /api/v1/admin/ops/system-logs
func (h *OpsHandler) ListSystemLogs(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
page, pageSize := response.ParsePagination(c)
if pageSize > 200 {
pageSize = 200
}
start, end, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsSystemLogFilter{
Page: page,
PageSize: pageSize,
StartTime: &start,
EndTime: &end,
Level: strings.TrimSpace(c.Query("level")),
Component: strings.TrimSpace(c.Query("component")),
RequestID: strings.TrimSpace(c.Query("request_id")),
ClientRequestID: strings.TrimSpace(c.Query("client_request_id")),
Platform: strings.TrimSpace(c.Query("platform")),
Model: strings.TrimSpace(c.Query("model")),
Query: strings.TrimSpace(c.Query("q")),
}
if v := strings.TrimSpace(c.Query("user_id")); v != "" {
id, parseErr := strconv.ParseInt(v, 10, 64)
if parseErr != nil || id <= 0 {
response.BadRequest(c, "Invalid user_id")
return
}
filter.UserID = &id
}
if v := strings.TrimSpace(c.Query("account_id")); v != "" {
id, parseErr := strconv.ParseInt(v, 10, 64)
if parseErr != nil || id <= 0 {
response.BadRequest(c, "Invalid account_id")
return
}
filter.AccountID = &id
}
result, err := h.opsService.ListSystemLogs(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, result.Logs, int64(result.Total), result.Page, result.PageSize)
}
// CleanupSystemLogs deletes indexed system logs by filter.
// POST /api/v1/admin/ops/system-logs/cleanup
func (h *OpsHandler) CleanupSystemLogs(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
var req opsSystemLogCleanupRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
parseTS := func(raw string) (*time.Time, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil, nil
}
if t, err := time.Parse(time.RFC3339Nano, raw); err == nil {
return &t, nil
}
t, err := time.Parse(time.RFC3339, raw)
if err != nil {
return nil, err
}
return &t, nil
}
start, err := parseTS(req.StartTime)
if err != nil {
response.BadRequest(c, "Invalid start_time")
return
}
end, err := parseTS(req.EndTime)
if err != nil {
response.BadRequest(c, "Invalid end_time")
return
}
filter := &service.OpsSystemLogCleanupFilter{
StartTime: start,
EndTime: end,
Level: strings.TrimSpace(req.Level),
Component: strings.TrimSpace(req.Component),
RequestID: strings.TrimSpace(req.RequestID),
ClientRequestID: strings.TrimSpace(req.ClientRequestID),
UserID: req.UserID,
AccountID: req.AccountID,
Platform: strings.TrimSpace(req.Platform),
Model: strings.TrimSpace(req.Model),
Query: strings.TrimSpace(req.Query),
}
deleted, err := h.opsService.CleanupSystemLogs(c.Request.Context(), filter, subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"deleted": deleted})
}
// GetSystemLogIngestionHealth returns sink health metrics.
// GET /api/v1/admin/ops/system-logs/health
func (h *OpsHandler) GetSystemLogIngestionHealth(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, h.opsService.GetSystemLogSinkHealth())
}
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type responseEnvelope struct {
Code int `json:"code"`
Message string `json:"message"`
Data json.RawMessage `json:"data"`
}
func newOpsSystemLogTestRouter(handler *OpsHandler, withUser bool) *gin.Engine {
gin.SetMode(gin.TestMode)
r := gin.New()
if withUser {
r.Use(func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 99})
c.Next()
})
}
r.GET("/logs", handler.ListSystemLogs)
r.POST("/logs/cleanup", handler.CleanupSystemLogs)
r.GET("/logs/health", handler.GetSystemLogIngestionHealth)
return r
}
func TestOpsSystemLogHandler_ListUnavailable(t *testing.T) {
h := NewOpsHandler(nil)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("status=%d, want 503", w.Code)
}
}
func TestOpsSystemLogHandler_ListInvalidUserID(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs?user_id=abc", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_ListInvalidAccountID(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs?account_id=-1", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_ListMonitoringDisabled(t *testing.T) {
svc := service.NewOpsService(nil, nil, &config.Config{
Ops: config.OpsConfig{Enabled: false},
}, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("status=%d, want 404", w.Code)
}
}
func TestOpsSystemLogHandler_ListSuccess(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs?time_range=30m&page=1&page_size=20", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status=%d, want 200", w.Code)
}
var resp responseEnvelope
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("unmarshal response: %v", err)
}
if resp.Code != 0 {
t.Fatalf("unexpected response code: %+v", resp)
}
}
func TestOpsSystemLogHandler_CleanupUnauthorized(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("status=%d, want 401", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupInvalidPayload(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{bad-json`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupInvalidTime(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"start_time":"bad","request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupInvalidEndTime(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"end_time":"bad","request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupServiceUnavailable(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("status=%d, want 503", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupMonitoringDisabled(t *testing.T) {
svc := service.NewOpsService(nil, nil, &config.Config{
Ops: config.OpsConfig{Enabled: false},
}, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("status=%d, want 404", w.Code)
}
}
func TestOpsSystemLogHandler_Health(t *testing.T) {
sink := service.NewOpsSystemLogSink(nil)
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, sink)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs/health", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status=%d, want 200", w.Code)
}
}
func TestOpsSystemLogHandler_HealthUnavailableAndMonitoringDisabled(t *testing.T) {
h := NewOpsHandler(nil)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs/health", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("status=%d, want 503", w.Code)
}
svc := service.NewOpsService(nil, nil, &config.Config{
Ops: config.OpsConfig{Enabled: false},
}, nil, nil, nil, nil, nil, nil, nil, nil)
h = NewOpsHandler(svc)
r = newOpsSystemLogTestRouter(h, false)
w = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/logs/health", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("status=%d, want 404", w.Code)
}
}
...@@ -3,7 +3,6 @@ package admin ...@@ -3,7 +3,6 @@ package admin
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"log"
"math" "math"
"net" "net"
"net/http" "net/http"
...@@ -16,6 +15,7 @@ import ( ...@@ -16,6 +15,7 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
...@@ -252,7 +252,7 @@ func (c *opsWSQPSCache) refresh(parentCtx context.Context) { ...@@ -252,7 +252,7 @@ func (c *opsWSQPSCache) refresh(parentCtx context.Context) {
stats, err := opsService.GetWindowStats(ctx, now.Add(-c.requestCountWindow), now) stats, err := opsService.GetWindowStats(ctx, now.Add(-c.requestCountWindow), now)
if err != nil || stats == nil { if err != nil || stats == nil {
if err != nil { if err != nil {
log.Printf("[OpsWS] refresh: get window stats failed: %v", err) logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] refresh: get window stats failed: %v", err)
} }
return return
} }
...@@ -278,7 +278,7 @@ func (c *opsWSQPSCache) refresh(parentCtx context.Context) { ...@@ -278,7 +278,7 @@ func (c *opsWSQPSCache) refresh(parentCtx context.Context) {
msg, err := json.Marshal(payload) msg, err := json.Marshal(payload)
if err != nil { if err != nil {
log.Printf("[OpsWS] refresh: marshal payload failed: %v", err) logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] refresh: marshal payload failed: %v", err)
return return
} }
...@@ -338,7 +338,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) { ...@@ -338,7 +338,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
// Reserve a global slot before upgrading the connection to keep the limit strict. // Reserve a global slot before upgrading the connection to keep the limit strict.
if !tryAcquireOpsWSTotalSlot(opsWSLimits.MaxConns) { if !tryAcquireOpsWSTotalSlot(opsWSLimits.MaxConns) {
log.Printf("[OpsWS] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns) logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns)
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"}) c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
return return
} }
...@@ -350,7 +350,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) { ...@@ -350,7 +350,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
if opsWSLimits.MaxConnsPerIP > 0 && clientIP != "" { if opsWSLimits.MaxConnsPerIP > 0 && clientIP != "" {
if !tryAcquireOpsWSIPSlot(clientIP, opsWSLimits.MaxConnsPerIP) { if !tryAcquireOpsWSIPSlot(clientIP, opsWSLimits.MaxConnsPerIP) {
log.Printf("[OpsWS] per-ip connection limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP) logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] per-ip connection limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP)
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"}) c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
return return
} }
...@@ -359,7 +359,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) { ...@@ -359,7 +359,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil { if err != nil {
log.Printf("[OpsWS] upgrade failed: %v", err) logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] upgrade failed: %v", err)
return return
} }
...@@ -452,7 +452,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) { ...@@ -452,7 +452,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
conn.SetReadLimit(qpsWSMaxReadBytes) conn.SetReadLimit(qpsWSMaxReadBytes)
if err := conn.SetReadDeadline(time.Now().Add(qpsWSPongWait)); err != nil { if err := conn.SetReadDeadline(time.Now().Add(qpsWSPongWait)); err != nil {
log.Printf("[OpsWS] set read deadline failed: %v", err) logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] set read deadline failed: %v", err)
return return
} }
conn.SetPongHandler(func(string) error { conn.SetPongHandler(func(string) error {
...@@ -471,7 +471,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) { ...@@ -471,7 +471,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
_, _, err := conn.ReadMessage() _, _, err := conn.ReadMessage()
if err != nil { if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) { if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
log.Printf("[OpsWS] read failed: %v", err) logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] read failed: %v", err)
} }
return return
} }
...@@ -508,7 +508,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) { ...@@ -508,7 +508,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
continue continue
} }
if err := writeWithTimeout(websocket.TextMessage, msg); err != nil { if err := writeWithTimeout(websocket.TextMessage, msg); err != nil {
log.Printf("[OpsWS] write failed: %v", err) logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] write failed: %v", err)
cancel() cancel()
closeConn() closeConn()
wg.Wait() wg.Wait()
...@@ -517,7 +517,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) { ...@@ -517,7 +517,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
case <-pingTicker.C: case <-pingTicker.C:
if err := writeWithTimeout(websocket.PingMessage, nil); err != nil { if err := writeWithTimeout(websocket.PingMessage, nil); err != nil {
log.Printf("[OpsWS] ping failed: %v", err) logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] ping failed: %v", err)
cancel() cancel()
closeConn() closeConn()
wg.Wait() wg.Wait()
...@@ -666,14 +666,14 @@ func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig { ...@@ -666,14 +666,14 @@ func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig {
if parsed, err := strconv.ParseBool(v); err == nil { if parsed, err := strconv.ParseBool(v); err == nil {
cfg.TrustProxy = parsed cfg.TrustProxy = parsed
} else { } else {
log.Printf("[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy) logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy)
} }
} }
if raw := strings.TrimSpace(os.Getenv(envOpsWSTrustedProxies)); raw != "" { if raw := strings.TrimSpace(os.Getenv(envOpsWSTrustedProxies)); raw != "" {
prefixes, invalid := parseTrustedProxyList(raw) prefixes, invalid := parseTrustedProxyList(raw)
if len(invalid) > 0 { if len(invalid) > 0 {
log.Printf("[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", ")) logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", "))
} }
cfg.TrustedProxies = prefixes cfg.TrustedProxies = prefixes
} }
...@@ -684,7 +684,7 @@ func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig { ...@@ -684,7 +684,7 @@ func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig {
case OriginPolicyStrict, OriginPolicyPermissive: case OriginPolicyStrict, OriginPolicyPermissive:
cfg.OriginPolicy = normalized cfg.OriginPolicy = normalized
default: default:
log.Printf("[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, OriginPolicyPermissive, cfg.OriginPolicy) logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, OriginPolicyPermissive, cfg.OriginPolicy)
} }
} }
...@@ -701,14 +701,14 @@ func loadOpsWSRuntimeLimitsFromEnv() opsWSRuntimeLimits { ...@@ -701,14 +701,14 @@ func loadOpsWSRuntimeLimitsFromEnv() opsWSRuntimeLimits {
if parsed, err := strconv.Atoi(v); err == nil && parsed > 0 { if parsed, err := strconv.Atoi(v); err == nil && parsed > 0 {
cfg.MaxConns = int32(parsed) cfg.MaxConns = int32(parsed)
} else { } else {
log.Printf("[OpsWS] invalid %s=%q (expected int>0); using default=%d", envOpsWSMaxConns, v, cfg.MaxConns) logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected int>0); using default=%d", envOpsWSMaxConns, v, cfg.MaxConns)
} }
} }
if v := strings.TrimSpace(os.Getenv(envOpsWSMaxConnsPerIP)); v != "" { if v := strings.TrimSpace(os.Getenv(envOpsWSMaxConnsPerIP)); v != "" {
if parsed, err := strconv.Atoi(v); err == nil && parsed >= 0 { if parsed, err := strconv.Atoi(v); err == nil && parsed >= 0 {
cfg.MaxConnsPerIP = int32(parsed) cfg.MaxConnsPerIP = int32(parsed)
} else { } else {
log.Printf("[OpsWS] invalid %s=%q (expected int>=0); using default=%d", envOpsWSMaxConnsPerIP, v, cfg.MaxConnsPerIP) logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected int>=0); using default=%d", envOpsWSMaxConnsPerIP, v, cfg.MaxConnsPerIP)
} }
} }
return cfg return cfg
......
package admin package admin
import ( import (
"context"
"strconv" "strconv"
"strings" "strings"
...@@ -130,20 +131,20 @@ func (h *ProxyHandler) Create(c *gin.Context) { ...@@ -130,20 +131,20 @@ func (h *ProxyHandler) Create(c *gin.Context) {
return return
} }
proxy, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{ executeAdminIdempotentJSON(c, "admin.proxies.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
Name: strings.TrimSpace(req.Name), proxy, err := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{
Protocol: strings.TrimSpace(req.Protocol), Name: strings.TrimSpace(req.Name),
Host: strings.TrimSpace(req.Host), Protocol: strings.TrimSpace(req.Protocol),
Port: req.Port, Host: strings.TrimSpace(req.Host),
Username: strings.TrimSpace(req.Username), Port: req.Port,
Password: strings.TrimSpace(req.Password), Username: strings.TrimSpace(req.Username),
Password: strings.TrimSpace(req.Password),
})
if err != nil {
return nil, err
}
return dto.ProxyFromService(proxy), nil
}) })
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.ProxyFromService(proxy))
} }
// Update handles updating a proxy // Update handles updating a proxy
...@@ -236,6 +237,24 @@ func (h *ProxyHandler) Test(c *gin.Context) { ...@@ -236,6 +237,24 @@ func (h *ProxyHandler) Test(c *gin.Context) {
response.Success(c, result) response.Success(c, result)
} }
// CheckQuality handles checking proxy quality across common AI targets.
// POST /api/v1/admin/proxies/:id/quality-check
func (h *ProxyHandler) CheckQuality(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid proxy ID")
return
}
result, err := h.adminService.CheckProxyQuality(c.Request.Context(), proxyID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// GetStats handles getting proxy statistics // GetStats handles getting proxy statistics
// GET /api/v1/admin/proxies/:id/stats // GET /api/v1/admin/proxies/:id/stats
func (h *ProxyHandler) GetStats(c *gin.Context) { func (h *ProxyHandler) GetStats(c *gin.Context) {
......
...@@ -2,6 +2,7 @@ package admin ...@@ -2,6 +2,7 @@ package admin
import ( import (
"bytes" "bytes"
"context"
"encoding/csv" "encoding/csv"
"fmt" "fmt"
"strconv" "strconv"
...@@ -88,23 +89,24 @@ func (h *RedeemHandler) Generate(c *gin.Context) { ...@@ -88,23 +89,24 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
return return
} }
codes, err := h.adminService.GenerateRedeemCodes(c.Request.Context(), &service.GenerateRedeemCodesInput{ executeAdminIdempotentJSON(c, "admin.redeem_codes.generate", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
Count: req.Count, codes, execErr := h.adminService.GenerateRedeemCodes(ctx, &service.GenerateRedeemCodesInput{
Type: req.Type, Count: req.Count,
Value: req.Value, Type: req.Type,
GroupID: req.GroupID, Value: req.Value,
ValidityDays: req.ValidityDays, GroupID: req.GroupID,
}) ValidityDays: req.ValidityDays,
if err != nil { })
response.ErrorFrom(c, err) if execErr != nil {
return return nil, execErr
} }
out := make([]dto.AdminRedeemCode, 0, len(codes)) out := make([]dto.AdminRedeemCode, 0, len(codes))
for i := range codes { for i := range codes {
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i])) out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
} }
response.Success(c, out) return out, nil
})
} }
// Delete handles deleting a redeem code // Delete handles deleting a redeem code
......
//go:build unit
package admin
import (
"testing"
"github.com/stretchr/testify/require"
)
// truncateSearchByRune 模拟 user_handler.go 中的 search 截断逻辑
func truncateSearchByRune(search string, maxRunes int) string {
if runes := []rune(search); len(runes) > maxRunes {
return string(runes[:maxRunes])
}
return search
}
func TestTruncateSearchByRune(t *testing.T) {
tests := []struct {
name string
input string
maxRunes int
wantLen int // 期望的 rune 长度
}{
{
name: "纯中文超长",
input: string(make([]rune, 150)),
maxRunes: 100,
wantLen: 100,
},
{
name: "纯 ASCII 超长",
input: string(make([]byte, 150)),
maxRunes: 100,
wantLen: 100,
},
{
name: "空字符串",
input: "",
maxRunes: 100,
wantLen: 0,
},
{
name: "恰好 100 个字符",
input: string(make([]rune, 100)),
maxRunes: 100,
wantLen: 100,
},
{
name: "不足 100 字符不截断",
input: "hello世界",
maxRunes: 100,
wantLen: 7,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := truncateSearchByRune(tc.input, tc.maxRunes)
require.Equal(t, tc.wantLen, len([]rune(result)))
})
}
}
func TestTruncateSearchByRune_PreservesMultibyte(t *testing.T) {
// 101 个中文字符,截断到 100 个后应该仍然是有效 UTF-8
input := ""
for i := 0; i < 101; i++ {
input += "中"
}
result := truncateSearchByRune(input, 100)
require.Equal(t, 100, len([]rune(result)))
// 验证截断结果是有效的 UTF-8(每个中文字符 3 字节)
require.Equal(t, 300, len(result))
}
func TestTruncateSearchByRune_MixedASCIIAndMultibyte(t *testing.T) {
// 50 个 ASCII + 51 个中文 = 101 个 rune
input := ""
for i := 0; i < 50; i++ {
input += "a"
}
for i := 0; i < 51; i++ {
input += "中"
}
result := truncateSearchByRune(input, 100)
runes := []rune(result)
require.Equal(t, 100, len(runes))
// 前 50 个应该是 'a',后 50 个应该是 '中'
require.Equal(t, 'a', runes[0])
require.Equal(t, 'a', runes[49])
require.Equal(t, '中', runes[50])
require.Equal(t, '中', runes[99])
}
package admin package admin
import ( import (
"context"
"strconv" "strconv"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
...@@ -199,13 +200,20 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) { ...@@ -199,13 +200,20 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
return return
} }
subscription, err := h.subscriptionService.ExtendSubscription(c.Request.Context(), subscriptionID, req.Days) idempotencyPayload := struct {
if err != nil { SubscriptionID int64 `json:"subscription_id"`
response.ErrorFrom(c, err) Body AdjustSubscriptionRequest `json:"body"`
return }{
SubscriptionID: subscriptionID,
Body: req,
} }
executeAdminIdempotentJSON(c, "admin.subscriptions.extend", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription)) subscription, execErr := h.subscriptionService.ExtendSubscription(ctx, subscriptionID, req.Days)
if execErr != nil {
return nil, execErr
}
return dto.UserSubscriptionFromServiceAdmin(subscription), nil
})
} }
// Revoke handles revoking a subscription // Revoke handles revoking a subscription
......
package admin package admin
import ( import (
"context"
"net/http" "net/http"
"strconv"
"strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/sysutil" "github.com/Wei-Shaw/sub2api/internal/pkg/sysutil"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -14,12 +18,14 @@ import ( ...@@ -14,12 +18,14 @@ import (
// SystemHandler handles system-related operations // SystemHandler handles system-related operations
type SystemHandler struct { type SystemHandler struct {
updateSvc *service.UpdateService updateSvc *service.UpdateService
lockSvc *service.SystemOperationLockService
} }
// NewSystemHandler creates a new SystemHandler // NewSystemHandler creates a new SystemHandler
func NewSystemHandler(updateSvc *service.UpdateService) *SystemHandler { func NewSystemHandler(updateSvc *service.UpdateService, lockSvc *service.SystemOperationLockService) *SystemHandler {
return &SystemHandler{ return &SystemHandler{
updateSvc: updateSvc, updateSvc: updateSvc,
lockSvc: lockSvc,
} }
} }
...@@ -47,41 +53,125 @@ func (h *SystemHandler) CheckUpdates(c *gin.Context) { ...@@ -47,41 +53,125 @@ func (h *SystemHandler) CheckUpdates(c *gin.Context) {
// PerformUpdate downloads and applies the update // PerformUpdate downloads and applies the update
// POST /api/v1/admin/system/update // POST /api/v1/admin/system/update
func (h *SystemHandler) PerformUpdate(c *gin.Context) { func (h *SystemHandler) PerformUpdate(c *gin.Context) {
if err := h.updateSvc.PerformUpdate(c.Request.Context()); err != nil { operationID := buildSystemOperationID(c, "update")
response.Error(c, http.StatusInternalServerError, err.Error()) payload := gin.H{"operation_id": operationID}
return executeAdminIdempotentJSON(c, "admin.system.update", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) {
} lock, release, err := h.acquireSystemLock(ctx, operationID)
response.Success(c, gin.H{ if err != nil {
"message": "Update completed. Please restart the service.", return nil, err
"need_restart": true, }
var releaseReason string
succeeded := false
defer func() {
release(releaseReason, succeeded)
}()
if err := h.updateSvc.PerformUpdate(ctx); err != nil {
releaseReason = "SYSTEM_UPDATE_FAILED"
return nil, err
}
succeeded = true
return gin.H{
"message": "Update completed. Please restart the service.",
"need_restart": true,
"operation_id": lock.OperationID(),
}, nil
}) })
} }
// Rollback restores the previous version // Rollback restores the previous version
// POST /api/v1/admin/system/rollback // POST /api/v1/admin/system/rollback
func (h *SystemHandler) Rollback(c *gin.Context) { func (h *SystemHandler) Rollback(c *gin.Context) {
if err := h.updateSvc.Rollback(); err != nil { operationID := buildSystemOperationID(c, "rollback")
response.Error(c, http.StatusInternalServerError, err.Error()) payload := gin.H{"operation_id": operationID}
return executeAdminIdempotentJSON(c, "admin.system.rollback", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) {
} lock, release, err := h.acquireSystemLock(ctx, operationID)
response.Success(c, gin.H{ if err != nil {
"message": "Rollback completed. Please restart the service.", return nil, err
"need_restart": true, }
var releaseReason string
succeeded := false
defer func() {
release(releaseReason, succeeded)
}()
if err := h.updateSvc.Rollback(); err != nil {
releaseReason = "SYSTEM_ROLLBACK_FAILED"
return nil, err
}
succeeded = true
return gin.H{
"message": "Rollback completed. Please restart the service.",
"need_restart": true,
"operation_id": lock.OperationID(),
}, nil
}) })
} }
// RestartService restarts the systemd service // RestartService restarts the systemd service
// POST /api/v1/admin/system/restart // POST /api/v1/admin/system/restart
func (h *SystemHandler) RestartService(c *gin.Context) { func (h *SystemHandler) RestartService(c *gin.Context) {
// Schedule service restart in background after sending response operationID := buildSystemOperationID(c, "restart")
// This ensures the client receives the success response before the service restarts payload := gin.H{"operation_id": operationID}
go func() { executeAdminIdempotentJSON(c, "admin.system.restart", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) {
// Wait a moment to ensure the response is sent lock, release, err := h.acquireSystemLock(ctx, operationID)
time.Sleep(500 * time.Millisecond) if err != nil {
sysutil.RestartServiceAsync() return nil, err
}() }
succeeded := false
defer func() {
release("", succeeded)
}()
response.Success(c, gin.H{ // Schedule service restart in background after sending response
"message": "Service restart initiated", // This ensures the client receives the success response before the service restarts
go func() {
// Wait a moment to ensure the response is sent
time.Sleep(500 * time.Millisecond)
sysutil.RestartServiceAsync()
}()
succeeded = true
return gin.H{
"message": "Service restart initiated",
"operation_id": lock.OperationID(),
}, nil
}) })
} }
func (h *SystemHandler) acquireSystemLock(
ctx context.Context,
operationID string,
) (*service.SystemOperationLock, func(string, bool), error) {
if h.lockSvc == nil {
return nil, nil, service.ErrIdempotencyStoreUnavail
}
lock, err := h.lockSvc.Acquire(ctx, operationID)
if err != nil {
return nil, nil, err
}
release := func(reason string, succeeded bool) {
releaseCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_ = h.lockSvc.Release(releaseCtx, lock, succeeded, reason)
}
return lock, release, nil
}
func buildSystemOperationID(c *gin.Context, operation string) string {
key := strings.TrimSpace(c.GetHeader("Idempotency-Key"))
if key == "" {
return "sysop-" + operation + "-" + strconv.FormatInt(time.Now().UnixNano(), 36)
}
actorScope := "admin:0"
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
actorScope = "admin:" + strconv.FormatInt(subject.UserID, 10)
}
seed := operation + "|" + actorScope + "|" + c.FullPath() + "|" + key
hash := service.HashIdempotencyKey(seed)
if len(hash) > 24 {
hash = hash[:24]
}
return "sysop-" + hash
}
package admin package admin
import ( import (
"log" "context"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
...@@ -378,11 +379,11 @@ func (h *UsageHandler) ListCleanupTasks(c *gin.Context) { ...@@ -378,11 +379,11 @@ func (h *UsageHandler) ListCleanupTasks(c *gin.Context) {
operator = subject.UserID operator = subject.UserID
} }
page, pageSize := response.ParsePagination(c) page, pageSize := response.ParsePagination(c)
log.Printf("[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize) logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize)
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
tasks, result, err := h.cleanupService.ListTasks(c.Request.Context(), params) tasks, result, err := h.cleanupService.ListTasks(c.Request.Context(), params)
if err != nil { if err != nil {
log.Printf("[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err) logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err)
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
...@@ -390,7 +391,7 @@ func (h *UsageHandler) ListCleanupTasks(c *gin.Context) { ...@@ -390,7 +391,7 @@ func (h *UsageHandler) ListCleanupTasks(c *gin.Context) {
for i := range tasks { for i := range tasks {
out = append(out, *dto.UsageCleanupTaskFromService(&tasks[i])) out = append(out, *dto.UsageCleanupTaskFromService(&tasks[i]))
} }
log.Printf("[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize) logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize)
response.Paginated(c, out, result.Total, page, pageSize) response.Paginated(c, out, result.Total, page, pageSize)
} }
...@@ -472,29 +473,36 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) { ...@@ -472,29 +473,36 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
billingType = *filters.BillingType billingType = *filters.BillingType
} }
log.Printf("[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q", idempotencyPayload := struct {
subject.UserID, OperatorID int64 `json:"operator_id"`
filters.StartTime.Format(time.RFC3339), Body CreateUsageCleanupTaskRequest `json:"body"`
filters.EndTime.Format(time.RFC3339), }{
userID, OperatorID: subject.UserID,
apiKeyID, Body: req,
accountID, }
groupID, executeAdminIdempotentJSON(c, "admin.usage.cleanup_tasks.create", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
model, logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q",
stream, subject.UserID,
billingType, filters.StartTime.Format(time.RFC3339),
req.Timezone, filters.EndTime.Format(time.RFC3339),
) userID,
apiKeyID,
task, err := h.cleanupService.CreateTask(c.Request.Context(), filters, subject.UserID) accountID,
if err != nil { groupID,
log.Printf("[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err) model,
response.ErrorFrom(c, err) stream,
return billingType,
} req.Timezone,
)
log.Printf("[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status)
response.Success(c, dto.UsageCleanupTaskFromService(task)) task, err := h.cleanupService.CreateTask(ctx, filters, subject.UserID)
if err != nil {
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err)
return nil, err
}
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status)
return dto.UsageCleanupTaskFromService(task), nil
})
} }
// CancelCleanupTask handles canceling a usage cleanup task // CancelCleanupTask handles canceling a usage cleanup task
...@@ -515,12 +523,12 @@ func (h *UsageHandler) CancelCleanupTask(c *gin.Context) { ...@@ -515,12 +523,12 @@ func (h *UsageHandler) CancelCleanupTask(c *gin.Context) {
response.BadRequest(c, "Invalid task id") response.BadRequest(c, "Invalid task id")
return return
} }
log.Printf("[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID) logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID)
if err := h.cleanupService.CancelTask(c.Request.Context(), taskID, subject.UserID); err != nil { if err := h.cleanupService.CancelTask(c.Request.Context(), taskID, subject.UserID); err != nil {
log.Printf("[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err) logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err)
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
log.Printf("[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID) logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID)
response.Success(c, gin.H{"id": taskID, "status": service.UsageCleanupStatusCanceled}) response.Success(c, gin.H{"id": taskID, "status": service.UsageCleanupStatusCanceled})
} }
package admin package admin
import ( import (
"context"
"strconv" "strconv"
"strings" "strings"
...@@ -78,8 +79,8 @@ func (h *UserHandler) List(c *gin.Context) { ...@@ -78,8 +79,8 @@ func (h *UserHandler) List(c *gin.Context) {
search := c.Query("search") search := c.Query("search")
// 标准化和验证 search 参数 // 标准化和验证 search 参数
search = strings.TrimSpace(search) search = strings.TrimSpace(search)
if len(search) > 100 { if runes := []rune(search); len(runes) > 100 {
search = search[:100] search = string(runes[:100])
} }
filters := service.UserListFilters{ filters := service.UserListFilters{
...@@ -257,13 +258,20 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) { ...@@ -257,13 +258,20 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
return return
} }
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation, req.Notes) idempotencyPayload := struct {
if err != nil { UserID int64 `json:"user_id"`
response.ErrorFrom(c, err) Body UpdateBalanceRequest `json:"body"`
return }{
UserID: userID,
Body: req,
} }
executeAdminIdempotentJSON(c, "admin.users.balance.update", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
response.Success(c, dto.UserFromServiceAdmin(user)) user, execErr := h.adminService.UpdateUserBalance(ctx, userID, req.Balance, req.Operation, req.Notes)
if execErr != nil {
return nil, execErr
}
return dto.UserFromServiceAdmin(user), nil
})
} }
// GetUserAPIKeys handles getting user's API keys // GetUserAPIKeys handles getting user's API keys
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
package handler package handler
import ( import (
"context"
"strconv" "strconv"
"time" "time"
...@@ -130,13 +131,14 @@ func (h *APIKeyHandler) Create(c *gin.Context) { ...@@ -130,13 +131,14 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
if req.Quota != nil { if req.Quota != nil {
svcReq.Quota = *req.Quota svcReq.Quota = *req.Quota
} }
key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.APIKeyFromService(key)) executeUserIdempotentJSON(c, "user.api_keys.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
key, err := h.apiKeyService.Create(ctx, subject.UserID, svcReq)
if err != nil {
return nil, err
}
return dto.APIKeyFromService(key), nil
})
} }
// Update handles updating an API key // Update handles updating an API key
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment