Commit 3d79773b authored by kyx236's avatar kyx236
Browse files

Merge branch 'main' of https://github.com/james-6-23/sub2api

parents 6aa8cbbf 742e73c9
//go:build unit
package admin
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestNormalizeInt64IDList(t *testing.T) {
tests := []struct {
name string
in []int64
want []int64
}{
{"nil input", nil, nil},
{"empty input", []int64{}, nil},
{"single element", []int64{5}, []int64{5}},
{"already sorted unique", []int64{1, 2, 3}, []int64{1, 2, 3}},
{"duplicates removed", []int64{3, 1, 3, 2, 1}, []int64{1, 2, 3}},
{"zero filtered", []int64{0, 1, 2}, []int64{1, 2}},
{"negative filtered", []int64{-5, -1, 3}, []int64{3}},
{"all invalid", []int64{0, -1, -2}, []int64{}},
{"sorted output", []int64{9, 3, 7, 1}, []int64{1, 3, 7, 9}},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := normalizeInt64IDList(tc.in)
if tc.want == nil {
require.Nil(t, got)
} else {
require.Equal(t, tc.want, got)
}
})
}
}
func TestBuildAccountTodayStatsBatchCacheKey(t *testing.T) {
tests := []struct {
name string
ids []int64
want string
}{
{"empty", nil, "accounts_today_stats_empty"},
{"single", []int64{42}, "accounts_today_stats:42"},
{"multiple", []int64{1, 2, 3}, "accounts_today_stats:1,2,3"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := buildAccountTodayStatsBatchCacheKey(tc.ids)
require.Equal(t, tc.want, got)
})
}
}
package admin
import (
"context"
"strconv"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type idempotencyStoreUnavailableMode int
const (
idempotencyStoreUnavailableFailClose idempotencyStoreUnavailableMode = iota
idempotencyStoreUnavailableFailOpen
)
func executeAdminIdempotent(
c *gin.Context,
scope string,
payload any,
ttl time.Duration,
execute func(context.Context) (any, error),
) (*service.IdempotencyExecuteResult, error) {
coordinator := service.DefaultIdempotencyCoordinator()
if coordinator == nil {
data, err := execute(c.Request.Context())
if err != nil {
return nil, err
}
return &service.IdempotencyExecuteResult{Data: data}, nil
}
actorScope := "admin:0"
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
actorScope = "admin:" + strconv.FormatInt(subject.UserID, 10)
}
return coordinator.Execute(c.Request.Context(), service.IdempotencyExecuteOptions{
Scope: scope,
ActorScope: actorScope,
Method: c.Request.Method,
Route: c.FullPath(),
IdempotencyKey: c.GetHeader("Idempotency-Key"),
Payload: payload,
RequireKey: true,
TTL: ttl,
}, execute)
}
func executeAdminIdempotentJSON(
c *gin.Context,
scope string,
payload any,
ttl time.Duration,
execute func(context.Context) (any, error),
) {
executeAdminIdempotentJSONWithMode(c, scope, payload, ttl, idempotencyStoreUnavailableFailClose, execute)
}
func executeAdminIdempotentJSONFailOpenOnStoreUnavailable(
c *gin.Context,
scope string,
payload any,
ttl time.Duration,
execute func(context.Context) (any, error),
) {
executeAdminIdempotentJSONWithMode(c, scope, payload, ttl, idempotencyStoreUnavailableFailOpen, execute)
}
func executeAdminIdempotentJSONWithMode(
c *gin.Context,
scope string,
payload any,
ttl time.Duration,
mode idempotencyStoreUnavailableMode,
execute func(context.Context) (any, error),
) {
result, err := executeAdminIdempotent(c, scope, payload, ttl, execute)
if err != nil {
if infraerrors.Code(err) == infraerrors.Code(service.ErrIdempotencyStoreUnavail) {
strategy := "fail_close"
if mode == idempotencyStoreUnavailableFailOpen {
strategy = "fail_open"
}
service.RecordIdempotencyStoreUnavailable(c.FullPath(), scope, "handler_"+strategy)
logger.LegacyPrintf("handler.idempotency", "[Idempotency] store unavailable: method=%s route=%s scope=%s strategy=%s", c.Request.Method, c.FullPath(), scope, strategy)
if mode == idempotencyStoreUnavailableFailOpen {
data, fallbackErr := execute(c.Request.Context())
if fallbackErr != nil {
response.ErrorFrom(c, fallbackErr)
return
}
c.Header("X-Idempotency-Degraded", "store-unavailable")
response.Success(c, data)
return
}
}
if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
response.ErrorFrom(c, err)
return
}
if result != nil && result.Replayed {
c.Header("X-Idempotency-Replayed", "true")
}
response.Success(c, result.Data)
}
package admin
import (
"bytes"
"context"
"errors"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type storeUnavailableRepoStub struct{}
func (storeUnavailableRepoStub) CreateProcessing(context.Context, *service.IdempotencyRecord) (bool, error) {
return false, errors.New("store unavailable")
}
func (storeUnavailableRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*service.IdempotencyRecord, error) {
return nil, errors.New("store unavailable")
}
func (storeUnavailableRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
return false, errors.New("store unavailable")
}
func (storeUnavailableRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
return false, errors.New("store unavailable")
}
func (storeUnavailableRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
return errors.New("store unavailable")
}
func (storeUnavailableRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
return errors.New("store unavailable")
}
func (storeUnavailableRepoStub) DeleteExpired(context.Context, time.Time, int) (int64, error) {
return 0, errors.New("store unavailable")
}
func TestExecuteAdminIdempotentJSONFailCloseOnStoreUnavailable(t *testing.T) {
gin.SetMode(gin.TestMode)
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(storeUnavailableRepoStub{}, service.DefaultIdempotencyConfig()))
t.Cleanup(func() {
service.SetDefaultIdempotencyCoordinator(nil)
})
var executed int
router := gin.New()
router.POST("/idempotent", func(c *gin.Context) {
executeAdminIdempotentJSON(c, "admin.test.high", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
executed++
return gin.H{"ok": true}, nil
})
})
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Idempotency-Key", "test-key-1")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
require.Equal(t, 0, executed, "fail-close should block business execution when idempotency store is unavailable")
}
func TestExecuteAdminIdempotentJSONFailOpenOnStoreUnavailable(t *testing.T) {
gin.SetMode(gin.TestMode)
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(storeUnavailableRepoStub{}, service.DefaultIdempotencyConfig()))
t.Cleanup(func() {
service.SetDefaultIdempotencyCoordinator(nil)
})
var executed int
router := gin.New()
router.POST("/idempotent", func(c *gin.Context) {
executeAdminIdempotentJSONFailOpenOnStoreUnavailable(c, "admin.test.medium", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
executed++
return gin.H{"ok": true}, nil
})
})
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Idempotency-Key", "test-key-2")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "store-unavailable", rec.Header().Get("X-Idempotency-Degraded"))
require.Equal(t, 1, executed, "fail-open strategy should allow semantic idempotent path to continue")
}
type memoryIdempotencyRepoStub struct {
mu sync.Mutex
nextID int64
data map[string]*service.IdempotencyRecord
}
func newMemoryIdempotencyRepoStub() *memoryIdempotencyRepoStub {
return &memoryIdempotencyRepoStub{
nextID: 1,
data: make(map[string]*service.IdempotencyRecord),
}
}
func (r *memoryIdempotencyRepoStub) key(scope, keyHash string) string {
return scope + "|" + keyHash
}
func (r *memoryIdempotencyRepoStub) clone(in *service.IdempotencyRecord) *service.IdempotencyRecord {
if in == nil {
return nil
}
out := *in
if in.LockedUntil != nil {
v := *in.LockedUntil
out.LockedUntil = &v
}
if in.ResponseBody != nil {
v := *in.ResponseBody
out.ResponseBody = &v
}
if in.ResponseStatus != nil {
v := *in.ResponseStatus
out.ResponseStatus = &v
}
if in.ErrorReason != nil {
v := *in.ErrorReason
out.ErrorReason = &v
}
return &out
}
func (r *memoryIdempotencyRepoStub) CreateProcessing(_ context.Context, record *service.IdempotencyRecord) (bool, error) {
r.mu.Lock()
defer r.mu.Unlock()
k := r.key(record.Scope, record.IdempotencyKeyHash)
if _, ok := r.data[k]; ok {
return false, nil
}
cp := r.clone(record)
cp.ID = r.nextID
r.nextID++
r.data[k] = cp
record.ID = cp.ID
return true, nil
}
func (r *memoryIdempotencyRepoStub) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) {
r.mu.Lock()
defer r.mu.Unlock()
return r.clone(r.data[r.key(scope, keyHash)]), nil
}
func (r *memoryIdempotencyRepoStub) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
if rec.Status != fromStatus {
return false, nil
}
if rec.LockedUntil != nil && rec.LockedUntil.After(now) {
return false, nil
}
rec.Status = service.IdempotencyStatusProcessing
rec.LockedUntil = &newLockedUntil
rec.ExpiresAt = newExpiresAt
rec.ErrorReason = nil
return true, nil
}
return false, nil
}
func (r *memoryIdempotencyRepoStub) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
if rec.Status != service.IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint {
return false, nil
}
rec.LockedUntil = &newLockedUntil
rec.ExpiresAt = newExpiresAt
return true, nil
}
return false, nil
}
func (r *memoryIdempotencyRepoStub) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
rec.Status = service.IdempotencyStatusSucceeded
rec.LockedUntil = nil
rec.ExpiresAt = expiresAt
rec.ResponseStatus = &responseStatus
rec.ResponseBody = &responseBody
rec.ErrorReason = nil
return nil
}
return nil
}
func (r *memoryIdempotencyRepoStub) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
rec.Status = service.IdempotencyStatusFailedRetryable
rec.LockedUntil = &lockedUntil
rec.ExpiresAt = expiresAt
rec.ErrorReason = &errorReason
return nil
}
return nil
}
func (r *memoryIdempotencyRepoStub) DeleteExpired(_ context.Context, _ time.Time, _ int) (int64, error) {
return 0, nil
}
func TestExecuteAdminIdempotentJSONConcurrentRetryOnlyOneSideEffect(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := newMemoryIdempotencyRepoStub()
cfg := service.DefaultIdempotencyConfig()
cfg.ProcessingTimeout = 2 * time.Second
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(repo, cfg))
t.Cleanup(func() {
service.SetDefaultIdempotencyCoordinator(nil)
})
var executed atomic.Int32
router := gin.New()
router.POST("/idempotent", func(c *gin.Context) {
executeAdminIdempotentJSON(c, "admin.test.concurrent", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
executed.Add(1)
time.Sleep(120 * time.Millisecond)
return gin.H{"ok": true}, nil
})
})
call := func() (int, http.Header) {
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Idempotency-Key", "same-key")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
return rec.Code, rec.Header()
}
var status1, status2 int
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
status1, _ = call()
}()
go func() {
defer wg.Done()
status2, _ = call()
}()
wg.Wait()
require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status1)
require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status2)
require.Equal(t, int32(1), executed.Load(), "same idempotency key should execute side-effect only once")
status3, headers3 := call()
require.Equal(t, http.StatusOK, status3)
require.Equal(t, "true", headers3.Get("X-Idempotency-Replayed"))
require.Equal(t, int32(1), executed.Load())
}
...@@ -2,8 +2,10 @@ package admin ...@@ -2,8 +2,10 @@ package admin
import ( import (
"strconv" "strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -16,6 +18,13 @@ type OpenAIOAuthHandler struct { ...@@ -16,6 +18,13 @@ type OpenAIOAuthHandler struct {
adminService service.AdminService adminService service.AdminService
} }
func oauthPlatformFromPath(c *gin.Context) string {
if strings.Contains(c.FullPath(), "/admin/sora/") {
return service.PlatformSora
}
return service.PlatformOpenAI
}
// NewOpenAIOAuthHandler creates a new OpenAI OAuth handler // NewOpenAIOAuthHandler creates a new OpenAI OAuth handler
func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler { func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler {
return &OpenAIOAuthHandler{ return &OpenAIOAuthHandler{
...@@ -39,7 +48,12 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) { ...@@ -39,7 +48,12 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
req = OpenAIGenerateAuthURLRequest{} req = OpenAIGenerateAuthURLRequest{}
} }
result, err := h.openaiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI) result, err := h.openaiOAuthService.GenerateAuthURL(
c.Request.Context(),
req.ProxyID,
req.RedirectURI,
oauthPlatformFromPath(c),
)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -52,6 +66,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) { ...@@ -52,6 +66,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
type OpenAIExchangeCodeRequest struct { type OpenAIExchangeCodeRequest struct {
SessionID string `json:"session_id" binding:"required"` SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"` Code string `json:"code" binding:"required"`
State string `json:"state" binding:"required"`
RedirectURI string `json:"redirect_uri"` RedirectURI string `json:"redirect_uri"`
ProxyID *int64 `json:"proxy_id"` ProxyID *int64 `json:"proxy_id"`
} }
...@@ -68,6 +83,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) { ...@@ -68,6 +83,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{ tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
SessionID: req.SessionID, SessionID: req.SessionID,
Code: req.Code, Code: req.Code,
State: req.State,
RedirectURI: req.RedirectURI, RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID, ProxyID: req.ProxyID,
}) })
...@@ -81,18 +97,29 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) { ...@@ -81,18 +97,29 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
// OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token // OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token
type OpenAIRefreshTokenRequest struct { type OpenAIRefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" binding:"required"` RefreshToken string `json:"refresh_token"`
RT string `json:"rt"`
ClientID string `json:"client_id"`
ProxyID *int64 `json:"proxy_id"` ProxyID *int64 `json:"proxy_id"`
} }
// RefreshToken refreshes an OpenAI OAuth token // RefreshToken refreshes an OpenAI OAuth token
// POST /api/v1/admin/openai/refresh-token // POST /api/v1/admin/openai/refresh-token
// POST /api/v1/admin/sora/rt2at
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
var req OpenAIRefreshTokenRequest var req OpenAIRefreshTokenRequest
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error()) response.BadRequest(c, "Invalid request: "+err.Error())
return return
} }
refreshToken := strings.TrimSpace(req.RefreshToken)
if refreshToken == "" {
refreshToken = strings.TrimSpace(req.RT)
}
if refreshToken == "" {
response.BadRequest(c, "refresh_token is required")
return
}
var proxyURL string var proxyURL string
if req.ProxyID != nil { if req.ProxyID != nil {
...@@ -102,7 +129,14 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { ...@@ -102,7 +129,14 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
} }
} }
tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL) // 未指定 client_id 时,根据请求路径平台自动设置默认值,避免 repository 层盲猜
clientID := strings.TrimSpace(req.ClientID)
if clientID == "" {
platform := oauthPlatformFromPath(c)
clientID, _ = openai.OAuthClientConfigByPlatform(platform)
}
tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, clientID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -111,8 +145,39 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { ...@@ -111,8 +145,39 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
response.Success(c, tokenInfo) response.Success(c, tokenInfo)
} }
// RefreshAccountToken refreshes token for a specific OpenAI account // ExchangeSoraSessionToken exchanges Sora session token to access token
// POST /api/v1/admin/sora/st2at
func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) {
var req struct {
SessionToken string `json:"session_token"`
ST string `json:"st"`
ProxyID *int64 `json:"proxy_id"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
sessionToken := strings.TrimSpace(req.SessionToken)
if sessionToken == "" {
sessionToken = strings.TrimSpace(req.ST)
}
if sessionToken == "" {
response.BadRequest(c, "session_token is required")
return
}
tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, tokenInfo)
}
// RefreshAccountToken refreshes token for a specific OpenAI/Sora account
// POST /api/v1/admin/openai/accounts/:id/refresh // POST /api/v1/admin/openai/accounts/:id/refresh
// POST /api/v1/admin/sora/accounts/:id/refresh
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil { if err != nil {
...@@ -127,9 +192,9 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { ...@@ -127,9 +192,9 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
return return
} }
// Ensure account is OpenAI platform platform := oauthPlatformFromPath(c)
if !account.IsOpenAI() { if account.Platform != platform {
response.BadRequest(c, "Account is not an OpenAI account") response.BadRequest(c, "Account platform does not match OAuth endpoint")
return return
} }
...@@ -167,12 +232,14 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { ...@@ -167,12 +232,14 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
response.Success(c, dto.AccountFromService(updatedAccount)) response.Success(c, dto.AccountFromService(updatedAccount))
} }
// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info // CreateAccountFromOAuth creates a new OpenAI/Sora OAuth account from token info
// POST /api/v1/admin/openai/create-from-oauth // POST /api/v1/admin/openai/create-from-oauth
// POST /api/v1/admin/sora/create-from-oauth
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
var req struct { var req struct {
SessionID string `json:"session_id" binding:"required"` SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"` Code string `json:"code" binding:"required"`
State string `json:"state" binding:"required"`
RedirectURI string `json:"redirect_uri"` RedirectURI string `json:"redirect_uri"`
ProxyID *int64 `json:"proxy_id"` ProxyID *int64 `json:"proxy_id"`
Name string `json:"name"` Name string `json:"name"`
...@@ -189,6 +256,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { ...@@ -189,6 +256,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{ tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
SessionID: req.SessionID, SessionID: req.SessionID,
Code: req.Code, Code: req.Code,
State: req.State,
RedirectURI: req.RedirectURI, RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID, ProxyID: req.ProxyID,
}) })
...@@ -200,19 +268,25 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { ...@@ -200,19 +268,25 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
// Build credentials from token info // Build credentials from token info
credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo) credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
platform := oauthPlatformFromPath(c)
// Use email as default name if not provided // Use email as default name if not provided
name := req.Name name := req.Name
if name == "" && tokenInfo.Email != "" { if name == "" && tokenInfo.Email != "" {
name = tokenInfo.Email name = tokenInfo.Email
} }
if name == "" { if name == "" {
name = "OpenAI OAuth Account" if platform == service.PlatformSora {
name = "Sora OAuth Account"
} else {
name = "OpenAI OAuth Account"
}
} }
// Create account // Create account
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{ account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
Name: name, Name: name,
Platform: "openai", Platform: platform,
Type: "oauth", Type: "oauth",
Credentials: credentials, Credentials: credentials,
ProxyID: req.ProxyID, ProxyID: req.ProxyID,
......
package admin package admin
import ( import (
"fmt"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
...@@ -218,6 +219,115 @@ func (h *OpsHandler) GetDashboardErrorDistribution(c *gin.Context) { ...@@ -218,6 +219,115 @@ func (h *OpsHandler) GetDashboardErrorDistribution(c *gin.Context) {
response.Success(c, data) response.Success(c, data)
} }
// GetDashboardOpenAITokenStats returns OpenAI token efficiency stats grouped by model.
// GET /api/v1/admin/ops/dashboard/openai-token-stats
func (h *OpsHandler) GetDashboardOpenAITokenStats(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
filter, err := parseOpsOpenAITokenStatsFilter(c)
if err != nil {
response.BadRequest(c, err.Error())
return
}
data, err := h.opsService.GetOpenAITokenStats(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, data)
}
func parseOpsOpenAITokenStatsFilter(c *gin.Context) (*service.OpsOpenAITokenStatsFilter, error) {
if c == nil {
return nil, fmt.Errorf("invalid request")
}
timeRange := strings.TrimSpace(c.Query("time_range"))
if timeRange == "" {
timeRange = "30d"
}
dur, ok := parseOpsOpenAITokenStatsDuration(timeRange)
if !ok {
return nil, fmt.Errorf("invalid time_range")
}
end := time.Now().UTC()
start := end.Add(-dur)
filter := &service.OpsOpenAITokenStatsFilter{
TimeRange: timeRange,
StartTime: start,
EndTime: end,
Platform: strings.TrimSpace(c.Query("platform")),
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
return nil, fmt.Errorf("invalid group_id")
}
filter.GroupID = &id
}
topNRaw := strings.TrimSpace(c.Query("top_n"))
pageRaw := strings.TrimSpace(c.Query("page"))
pageSizeRaw := strings.TrimSpace(c.Query("page_size"))
if topNRaw != "" && (pageRaw != "" || pageSizeRaw != "") {
return nil, fmt.Errorf("invalid query: top_n cannot be used with page/page_size")
}
if topNRaw != "" {
topN, err := strconv.Atoi(topNRaw)
if err != nil || topN < 1 || topN > 100 {
return nil, fmt.Errorf("invalid top_n")
}
filter.TopN = topN
return filter, nil
}
filter.Page = 1
filter.PageSize = 20
if pageRaw != "" {
page, err := strconv.Atoi(pageRaw)
if err != nil || page < 1 {
return nil, fmt.Errorf("invalid page")
}
filter.Page = page
}
if pageSizeRaw != "" {
pageSize, err := strconv.Atoi(pageSizeRaw)
if err != nil || pageSize < 1 || pageSize > 100 {
return nil, fmt.Errorf("invalid page_size")
}
filter.PageSize = pageSize
}
return filter, nil
}
func parseOpsOpenAITokenStatsDuration(v string) (time.Duration, bool) {
switch strings.TrimSpace(v) {
case "30m":
return 30 * time.Minute, true
case "1h":
return time.Hour, true
case "1d":
return 24 * time.Hour, true
case "15d":
return 15 * 24 * time.Hour, true
case "30d":
return 30 * 24 * time.Hour, true
default:
return 0, false
}
}
func pickThroughputBucketSeconds(window time.Duration) int { func pickThroughputBucketSeconds(window time.Duration) int {
// Keep buckets predictable and avoid huge responses. // Keep buckets predictable and avoid huge responses.
switch { switch {
......
package admin
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type testSettingRepo struct {
values map[string]string
}
func newTestSettingRepo() *testSettingRepo {
return &testSettingRepo{values: map[string]string{}}
}
func (s *testSettingRepo) Get(ctx context.Context, key string) (*service.Setting, error) {
v, err := s.GetValue(ctx, key)
if err != nil {
return nil, err
}
return &service.Setting{Key: key, Value: v}, nil
}
func (s *testSettingRepo) GetValue(ctx context.Context, key string) (string, error) {
v, ok := s.values[key]
if !ok {
return "", service.ErrSettingNotFound
}
return v, nil
}
func (s *testSettingRepo) Set(ctx context.Context, key, value string) error {
s.values[key] = value
return nil
}
func (s *testSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, k := range keys {
if v, ok := s.values[k]; ok {
out[k] = v
}
}
return out, nil
}
func (s *testSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error {
for k, v := range settings {
s.values[k] = v
}
return nil
}
func (s *testSettingRepo) GetAll(ctx context.Context) (map[string]string, error) {
out := make(map[string]string, len(s.values))
for k, v := range s.values {
out[k] = v
}
return out, nil
}
func (s *testSettingRepo) Delete(ctx context.Context, key string) error {
delete(s.values, key)
return nil
}
func newOpsRuntimeRouter(handler *OpsHandler, withUser bool) *gin.Engine {
gin.SetMode(gin.TestMode)
r := gin.New()
if withUser {
r.Use(func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 7})
c.Next()
})
}
r.GET("/runtime/logging", handler.GetRuntimeLogConfig)
r.PUT("/runtime/logging", handler.UpdateRuntimeLogConfig)
r.POST("/runtime/logging/reset", handler.ResetRuntimeLogConfig)
return r
}
func newRuntimeOpsService(t *testing.T) *service.OpsService {
t.Helper()
if err := logger.Init(logger.InitOptions{
Level: "info",
Format: "json",
ServiceName: "sub2api",
Environment: "test",
Output: logger.OutputOptions{
ToStdout: false,
ToFile: false,
},
}); err != nil {
t.Fatalf("init logger: %v", err)
}
settingRepo := newTestSettingRepo()
cfg := &config.Config{
Ops: config.OpsConfig{Enabled: true},
Log: config.LogConfig{
Level: "info",
Caller: true,
StacktraceLevel: "error",
Sampling: config.LogSamplingConfig{
Enabled: false,
Initial: 100,
Thereafter: 100,
},
},
}
return service.NewOpsService(nil, settingRepo, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
}
func TestOpsRuntimeLoggingHandler_GetConfig(t *testing.T) {
h := NewOpsHandler(newRuntimeOpsService(t))
r := newOpsRuntimeRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/runtime/logging", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status=%d, want 200", w.Code)
}
}
func TestOpsRuntimeLoggingHandler_UpdateUnauthorized(t *testing.T) {
h := NewOpsHandler(newRuntimeOpsService(t))
r := newOpsRuntimeRouter(h, false)
body := `{"level":"debug","enable_sampling":false,"sampling_initial":100,"sampling_thereafter":100,"caller":true,"stacktrace_level":"error","retention_days":30}`
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/runtime/logging", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("status=%d, want 401", w.Code)
}
}
func TestOpsRuntimeLoggingHandler_UpdateAndResetSuccess(t *testing.T) {
h := NewOpsHandler(newRuntimeOpsService(t))
r := newOpsRuntimeRouter(h, true)
payload := map[string]any{
"level": "debug",
"enable_sampling": false,
"sampling_initial": 100,
"sampling_thereafter": 100,
"caller": true,
"stacktrace_level": "error",
"retention_days": 30,
}
raw, _ := json.Marshal(payload)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/runtime/logging", bytes.NewReader(raw))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("update status=%d, want 200, body=%s", w.Code, w.Body.String())
}
w = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/runtime/logging/reset", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("reset status=%d, want 200, body=%s", w.Code, w.Body.String())
}
}
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"net/http" "net/http"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
...@@ -101,6 +102,84 @@ func (h *OpsHandler) UpdateAlertRuntimeSettings(c *gin.Context) { ...@@ -101,6 +102,84 @@ func (h *OpsHandler) UpdateAlertRuntimeSettings(c *gin.Context) {
response.Success(c, updated) response.Success(c, updated)
} }
// GetRuntimeLogConfig returns runtime log config (DB-backed).
// GET /api/v1/admin/ops/runtime/logging
func (h *OpsHandler) GetRuntimeLogConfig(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
cfg, err := h.opsService.GetRuntimeLogConfig(c.Request.Context())
if err != nil {
response.Error(c, http.StatusInternalServerError, "Failed to get runtime log config")
return
}
response.Success(c, cfg)
}
// UpdateRuntimeLogConfig updates runtime log config and applies changes immediately.
// PUT /api/v1/admin/ops/runtime/logging
func (h *OpsHandler) UpdateRuntimeLogConfig(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
var req service.OpsRuntimeLogConfig
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
updated, err := h.opsService.UpdateRuntimeLogConfig(c.Request.Context(), &req, subject.UserID)
if err != nil {
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, updated)
}
// ResetRuntimeLogConfig removes runtime override and falls back to env/yaml baseline.
// POST /api/v1/admin/ops/runtime/logging/reset
func (h *OpsHandler) ResetRuntimeLogConfig(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
updated, err := h.opsService.ResetRuntimeLogConfig(c.Request.Context(), subject.UserID)
if err != nil {
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, updated)
}
// GetAdvancedSettings returns Ops advanced settings (DB-backed). // GetAdvancedSettings returns Ops advanced settings (DB-backed).
// GET /api/v1/admin/ops/advanced-settings // GET /api/v1/admin/ops/advanced-settings
func (h *OpsHandler) GetAdvancedSettings(c *gin.Context) { func (h *OpsHandler) GetAdvancedSettings(c *gin.Context) {
......
package admin
import (
"encoding/json"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"golang.org/x/sync/errgroup"
)
var opsDashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second)
type opsDashboardSnapshotV2Response struct {
GeneratedAt string `json:"generated_at"`
Overview *service.OpsDashboardOverview `json:"overview"`
ThroughputTrend *service.OpsThroughputTrendResponse `json:"throughput_trend"`
ErrorTrend *service.OpsErrorTrendResponse `json:"error_trend"`
}
type opsDashboardSnapshotV2CacheKey struct {
StartTime string `json:"start_time"`
EndTime string `json:"end_time"`
Platform string `json:"platform"`
GroupID *int64 `json:"group_id"`
QueryMode service.OpsQueryMode `json:"mode"`
BucketSecond int `json:"bucket_second"`
}
// GetDashboardSnapshotV2 returns ops dashboard core snapshot in one request.
// GET /api/v1/admin/ops/dashboard/snapshot-v2
func (h *OpsHandler) GetDashboardSnapshotV2(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
startTime, endTime, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsDashboardFilter{
StartTime: startTime,
EndTime: endTime,
Platform: strings.TrimSpace(c.Query("platform")),
QueryMode: parseOpsQueryMode(c),
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid group_id")
return
}
filter.GroupID = &id
}
bucketSeconds := pickThroughputBucketSeconds(endTime.Sub(startTime))
keyRaw, _ := json.Marshal(opsDashboardSnapshotV2CacheKey{
StartTime: startTime.UTC().Format(time.RFC3339),
EndTime: endTime.UTC().Format(time.RFC3339),
Platform: filter.Platform,
GroupID: filter.GroupID,
QueryMode: filter.QueryMode,
BucketSecond: bucketSeconds,
})
cacheKey := string(keyRaw)
if cached, ok := opsDashboardSnapshotV2Cache.Get(cacheKey); ok {
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) {
c.Status(http.StatusNotModified)
return
}
}
c.Header("X-Snapshot-Cache", "hit")
response.Success(c, cached.Payload)
return
}
var (
overview *service.OpsDashboardOverview
trend *service.OpsThroughputTrendResponse
errTrend *service.OpsErrorTrendResponse
)
g, gctx := errgroup.WithContext(c.Request.Context())
g.Go(func() error {
f := *filter
result, err := h.opsService.GetDashboardOverview(gctx, &f)
if err != nil {
return err
}
overview = result
return nil
})
g.Go(func() error {
f := *filter
result, err := h.opsService.GetThroughputTrend(gctx, &f, bucketSeconds)
if err != nil {
return err
}
trend = result
return nil
})
g.Go(func() error {
f := *filter
result, err := h.opsService.GetErrorTrend(gctx, &f, bucketSeconds)
if err != nil {
return err
}
errTrend = result
return nil
})
if err := g.Wait(); err != nil {
response.ErrorFrom(c, err)
return
}
resp := &opsDashboardSnapshotV2Response{
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
Overview: overview,
ThroughputTrend: trend,
ErrorTrend: errTrend,
}
cached := opsDashboardSnapshotV2Cache.Set(cacheKey, resp)
if cached.ETag != "" {
c.Header("ETag", cached.ETag)
c.Header("Vary", "If-None-Match")
}
c.Header("X-Snapshot-Cache", "miss")
response.Success(c, resp)
}
package admin
import (
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type opsSystemLogCleanupRequest struct {
StartTime string `json:"start_time"`
EndTime string `json:"end_time"`
Level string `json:"level"`
Component string `json:"component"`
RequestID string `json:"request_id"`
ClientRequestID string `json:"client_request_id"`
UserID *int64 `json:"user_id"`
AccountID *int64 `json:"account_id"`
Platform string `json:"platform"`
Model string `json:"model"`
Query string `json:"q"`
}
// ListSystemLogs returns indexed system logs.
// GET /api/v1/admin/ops/system-logs
func (h *OpsHandler) ListSystemLogs(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
page, pageSize := response.ParsePagination(c)
if pageSize > 200 {
pageSize = 200
}
start, end, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsSystemLogFilter{
Page: page,
PageSize: pageSize,
StartTime: &start,
EndTime: &end,
Level: strings.TrimSpace(c.Query("level")),
Component: strings.TrimSpace(c.Query("component")),
RequestID: strings.TrimSpace(c.Query("request_id")),
ClientRequestID: strings.TrimSpace(c.Query("client_request_id")),
Platform: strings.TrimSpace(c.Query("platform")),
Model: strings.TrimSpace(c.Query("model")),
Query: strings.TrimSpace(c.Query("q")),
}
if v := strings.TrimSpace(c.Query("user_id")); v != "" {
id, parseErr := strconv.ParseInt(v, 10, 64)
if parseErr != nil || id <= 0 {
response.BadRequest(c, "Invalid user_id")
return
}
filter.UserID = &id
}
if v := strings.TrimSpace(c.Query("account_id")); v != "" {
id, parseErr := strconv.ParseInt(v, 10, 64)
if parseErr != nil || id <= 0 {
response.BadRequest(c, "Invalid account_id")
return
}
filter.AccountID = &id
}
result, err := h.opsService.ListSystemLogs(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, result.Logs, int64(result.Total), result.Page, result.PageSize)
}
// CleanupSystemLogs deletes indexed system logs by filter.
// POST /api/v1/admin/ops/system-logs/cleanup
func (h *OpsHandler) CleanupSystemLogs(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
var req opsSystemLogCleanupRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
parseTS := func(raw string) (*time.Time, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil, nil
}
if t, err := time.Parse(time.RFC3339Nano, raw); err == nil {
return &t, nil
}
t, err := time.Parse(time.RFC3339, raw)
if err != nil {
return nil, err
}
return &t, nil
}
start, err := parseTS(req.StartTime)
if err != nil {
response.BadRequest(c, "Invalid start_time")
return
}
end, err := parseTS(req.EndTime)
if err != nil {
response.BadRequest(c, "Invalid end_time")
return
}
filter := &service.OpsSystemLogCleanupFilter{
StartTime: start,
EndTime: end,
Level: strings.TrimSpace(req.Level),
Component: strings.TrimSpace(req.Component),
RequestID: strings.TrimSpace(req.RequestID),
ClientRequestID: strings.TrimSpace(req.ClientRequestID),
UserID: req.UserID,
AccountID: req.AccountID,
Platform: strings.TrimSpace(req.Platform),
Model: strings.TrimSpace(req.Model),
Query: strings.TrimSpace(req.Query),
}
deleted, err := h.opsService.CleanupSystemLogs(c.Request.Context(), filter, subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"deleted": deleted})
}
// GetSystemLogIngestionHealth returns sink health metrics.
// GET /api/v1/admin/ops/system-logs/health
func (h *OpsHandler) GetSystemLogIngestionHealth(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, h.opsService.GetSystemLogSinkHealth())
}
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type responseEnvelope struct {
Code int `json:"code"`
Message string `json:"message"`
Data json.RawMessage `json:"data"`
}
func newOpsSystemLogTestRouter(handler *OpsHandler, withUser bool) *gin.Engine {
gin.SetMode(gin.TestMode)
r := gin.New()
if withUser {
r.Use(func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 99})
c.Next()
})
}
r.GET("/logs", handler.ListSystemLogs)
r.POST("/logs/cleanup", handler.CleanupSystemLogs)
r.GET("/logs/health", handler.GetSystemLogIngestionHealth)
return r
}
func TestOpsSystemLogHandler_ListUnavailable(t *testing.T) {
h := NewOpsHandler(nil)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("status=%d, want 503", w.Code)
}
}
func TestOpsSystemLogHandler_ListInvalidUserID(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs?user_id=abc", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_ListInvalidAccountID(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs?account_id=-1", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_ListMonitoringDisabled(t *testing.T) {
svc := service.NewOpsService(nil, nil, &config.Config{
Ops: config.OpsConfig{Enabled: false},
}, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("status=%d, want 404", w.Code)
}
}
func TestOpsSystemLogHandler_ListSuccess(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs?time_range=30m&page=1&page_size=20", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status=%d, want 200", w.Code)
}
var resp responseEnvelope
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("unmarshal response: %v", err)
}
if resp.Code != 0 {
t.Fatalf("unexpected response code: %+v", resp)
}
}
func TestOpsSystemLogHandler_CleanupUnauthorized(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("status=%d, want 401", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupInvalidPayload(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{bad-json`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupInvalidTime(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"start_time":"bad","request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupInvalidEndTime(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"end_time":"bad","request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupServiceUnavailable(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("status=%d, want 503", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupMonitoringDisabled(t *testing.T) {
svc := service.NewOpsService(nil, nil, &config.Config{
Ops: config.OpsConfig{Enabled: false},
}, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("status=%d, want 404", w.Code)
}
}
func TestOpsSystemLogHandler_Health(t *testing.T) {
sink := service.NewOpsSystemLogSink(nil)
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, sink)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs/health", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status=%d, want 200", w.Code)
}
}
func TestOpsSystemLogHandler_HealthUnavailableAndMonitoringDisabled(t *testing.T) {
h := NewOpsHandler(nil)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs/health", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("status=%d, want 503", w.Code)
}
svc := service.NewOpsService(nil, nil, &config.Config{
Ops: config.OpsConfig{Enabled: false},
}, nil, nil, nil, nil, nil, nil, nil, nil)
h = NewOpsHandler(svc)
r = newOpsSystemLogTestRouter(h, false)
w = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/logs/health", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("status=%d, want 404", w.Code)
}
}
...@@ -3,7 +3,6 @@ package admin ...@@ -3,7 +3,6 @@ package admin
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"log"
"math" "math"
"net" "net"
"net/http" "net/http"
...@@ -16,6 +15,7 @@ import ( ...@@ -16,6 +15,7 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
...@@ -62,7 +62,8 @@ const ( ...@@ -62,7 +62,8 @@ const (
) )
var wsConnCount atomic.Int32 var wsConnCount atomic.Int32
var wsConnCountByIP sync.Map // map[string]*atomic.Int32 var wsConnCountByIPMu sync.Mutex
var wsConnCountByIP = make(map[string]int32)
const qpsWSIdleStopDelay = 30 * time.Second const qpsWSIdleStopDelay = 30 * time.Second
...@@ -252,7 +253,7 @@ func (c *opsWSQPSCache) refresh(parentCtx context.Context) { ...@@ -252,7 +253,7 @@ func (c *opsWSQPSCache) refresh(parentCtx context.Context) {
stats, err := opsService.GetWindowStats(ctx, now.Add(-c.requestCountWindow), now) stats, err := opsService.GetWindowStats(ctx, now.Add(-c.requestCountWindow), now)
if err != nil || stats == nil { if err != nil || stats == nil {
if err != nil { if err != nil {
log.Printf("[OpsWS] refresh: get window stats failed: %v", err) logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] refresh: get window stats failed: %v", err)
} }
return return
} }
...@@ -278,7 +279,7 @@ func (c *opsWSQPSCache) refresh(parentCtx context.Context) { ...@@ -278,7 +279,7 @@ func (c *opsWSQPSCache) refresh(parentCtx context.Context) {
msg, err := json.Marshal(payload) msg, err := json.Marshal(payload)
if err != nil { if err != nil {
log.Printf("[OpsWS] refresh: marshal payload failed: %v", err) logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] refresh: marshal payload failed: %v", err)
return return
} }
...@@ -338,7 +339,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) { ...@@ -338,7 +339,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
// Reserve a global slot before upgrading the connection to keep the limit strict. // Reserve a global slot before upgrading the connection to keep the limit strict.
if !tryAcquireOpsWSTotalSlot(opsWSLimits.MaxConns) { if !tryAcquireOpsWSTotalSlot(opsWSLimits.MaxConns) {
log.Printf("[OpsWS] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns) logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns)
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"}) c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
return return
} }
...@@ -350,7 +351,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) { ...@@ -350,7 +351,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
if opsWSLimits.MaxConnsPerIP > 0 && clientIP != "" { if opsWSLimits.MaxConnsPerIP > 0 && clientIP != "" {
if !tryAcquireOpsWSIPSlot(clientIP, opsWSLimits.MaxConnsPerIP) { if !tryAcquireOpsWSIPSlot(clientIP, opsWSLimits.MaxConnsPerIP) {
log.Printf("[OpsWS] per-ip connection limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP) logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] per-ip connection limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP)
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"}) c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
return return
} }
...@@ -359,7 +360,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) { ...@@ -359,7 +360,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil { if err != nil {
log.Printf("[OpsWS] upgrade failed: %v", err) logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] upgrade failed: %v", err)
return return
} }
...@@ -389,42 +390,31 @@ func tryAcquireOpsWSIPSlot(clientIP string, limit int32) bool { ...@@ -389,42 +390,31 @@ func tryAcquireOpsWSIPSlot(clientIP string, limit int32) bool {
if strings.TrimSpace(clientIP) == "" || limit <= 0 { if strings.TrimSpace(clientIP) == "" || limit <= 0 {
return true return true
} }
wsConnCountByIPMu.Lock()
v, _ := wsConnCountByIP.LoadOrStore(clientIP, &atomic.Int32{}) defer wsConnCountByIPMu.Unlock()
counter, ok := v.(*atomic.Int32) current := wsConnCountByIP[clientIP]
if !ok { if current >= limit {
return false return false
} }
wsConnCountByIP[clientIP] = current + 1
for { return true
current := counter.Load()
if current >= limit {
return false
}
if counter.CompareAndSwap(current, current+1) {
return true
}
}
} }
func releaseOpsWSIPSlot(clientIP string) { func releaseOpsWSIPSlot(clientIP string) {
if strings.TrimSpace(clientIP) == "" { if strings.TrimSpace(clientIP) == "" {
return return
} }
wsConnCountByIPMu.Lock()
v, ok := wsConnCountByIP.Load(clientIP) defer wsConnCountByIPMu.Unlock()
current, ok := wsConnCountByIP[clientIP]
if !ok { if !ok {
return return
} }
counter, ok := v.(*atomic.Int32) if current <= 1 {
if !ok { delete(wsConnCountByIP, clientIP)
return return
} }
next := counter.Add(-1) wsConnCountByIP[clientIP] = current - 1
if next <= 0 {
// Best-effort cleanup; safe even if a new slot was acquired concurrently.
wsConnCountByIP.Delete(clientIP)
}
} }
func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) { func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
...@@ -452,7 +442,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) { ...@@ -452,7 +442,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
conn.SetReadLimit(qpsWSMaxReadBytes) conn.SetReadLimit(qpsWSMaxReadBytes)
if err := conn.SetReadDeadline(time.Now().Add(qpsWSPongWait)); err != nil { if err := conn.SetReadDeadline(time.Now().Add(qpsWSPongWait)); err != nil {
log.Printf("[OpsWS] set read deadline failed: %v", err) logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] set read deadline failed: %v", err)
return return
} }
conn.SetPongHandler(func(string) error { conn.SetPongHandler(func(string) error {
...@@ -471,7 +461,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) { ...@@ -471,7 +461,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
_, _, err := conn.ReadMessage() _, _, err := conn.ReadMessage()
if err != nil { if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) { if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
log.Printf("[OpsWS] read failed: %v", err) logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] read failed: %v", err)
} }
return return
} }
...@@ -508,7 +498,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) { ...@@ -508,7 +498,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
continue continue
} }
if err := writeWithTimeout(websocket.TextMessage, msg); err != nil { if err := writeWithTimeout(websocket.TextMessage, msg); err != nil {
log.Printf("[OpsWS] write failed: %v", err) logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] write failed: %v", err)
cancel() cancel()
closeConn() closeConn()
wg.Wait() wg.Wait()
...@@ -517,7 +507,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) { ...@@ -517,7 +507,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
case <-pingTicker.C: case <-pingTicker.C:
if err := writeWithTimeout(websocket.PingMessage, nil); err != nil { if err := writeWithTimeout(websocket.PingMessage, nil); err != nil {
log.Printf("[OpsWS] ping failed: %v", err) logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] ping failed: %v", err)
cancel() cancel()
closeConn() closeConn()
wg.Wait() wg.Wait()
...@@ -666,14 +656,14 @@ func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig { ...@@ -666,14 +656,14 @@ func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig {
if parsed, err := strconv.ParseBool(v); err == nil { if parsed, err := strconv.ParseBool(v); err == nil {
cfg.TrustProxy = parsed cfg.TrustProxy = parsed
} else { } else {
log.Printf("[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy) logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy)
} }
} }
if raw := strings.TrimSpace(os.Getenv(envOpsWSTrustedProxies)); raw != "" { if raw := strings.TrimSpace(os.Getenv(envOpsWSTrustedProxies)); raw != "" {
prefixes, invalid := parseTrustedProxyList(raw) prefixes, invalid := parseTrustedProxyList(raw)
if len(invalid) > 0 { if len(invalid) > 0 {
log.Printf("[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", ")) logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", "))
} }
cfg.TrustedProxies = prefixes cfg.TrustedProxies = prefixes
} }
...@@ -684,7 +674,7 @@ func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig { ...@@ -684,7 +674,7 @@ func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig {
case OriginPolicyStrict, OriginPolicyPermissive: case OriginPolicyStrict, OriginPolicyPermissive:
cfg.OriginPolicy = normalized cfg.OriginPolicy = normalized
default: default:
log.Printf("[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, OriginPolicyPermissive, cfg.OriginPolicy) logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, OriginPolicyPermissive, cfg.OriginPolicy)
} }
} }
...@@ -701,14 +691,14 @@ func loadOpsWSRuntimeLimitsFromEnv() opsWSRuntimeLimits { ...@@ -701,14 +691,14 @@ func loadOpsWSRuntimeLimitsFromEnv() opsWSRuntimeLimits {
if parsed, err := strconv.Atoi(v); err == nil && parsed > 0 { if parsed, err := strconv.Atoi(v); err == nil && parsed > 0 {
cfg.MaxConns = int32(parsed) cfg.MaxConns = int32(parsed)
} else { } else {
log.Printf("[OpsWS] invalid %s=%q (expected int>0); using default=%d", envOpsWSMaxConns, v, cfg.MaxConns) logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected int>0); using default=%d", envOpsWSMaxConns, v, cfg.MaxConns)
} }
} }
if v := strings.TrimSpace(os.Getenv(envOpsWSMaxConnsPerIP)); v != "" { if v := strings.TrimSpace(os.Getenv(envOpsWSMaxConnsPerIP)); v != "" {
if parsed, err := strconv.Atoi(v); err == nil && parsed >= 0 { if parsed, err := strconv.Atoi(v); err == nil && parsed >= 0 {
cfg.MaxConnsPerIP = int32(parsed) cfg.MaxConnsPerIP = int32(parsed)
} else { } else {
log.Printf("[OpsWS] invalid %s=%q (expected int>=0); using default=%d", envOpsWSMaxConnsPerIP, v, cfg.MaxConnsPerIP) logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected int>=0); using default=%d", envOpsWSMaxConnsPerIP, v, cfg.MaxConnsPerIP)
} }
} }
return cfg return cfg
......
package admin package admin
import ( import (
"context"
"strconv" "strconv"
"strings" "strings"
...@@ -63,9 +64,9 @@ func (h *ProxyHandler) List(c *gin.Context) { ...@@ -63,9 +64,9 @@ func (h *ProxyHandler) List(c *gin.Context) {
return return
} }
out := make([]dto.ProxyWithAccountCount, 0, len(proxies)) out := make([]dto.AdminProxyWithAccountCount, 0, len(proxies))
for i := range proxies { for i := range proxies {
out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i])) out = append(out, *dto.ProxyWithAccountCountFromServiceAdmin(&proxies[i]))
} }
response.Paginated(c, out, total, page, pageSize) response.Paginated(c, out, total, page, pageSize)
} }
...@@ -82,9 +83,9 @@ func (h *ProxyHandler) GetAll(c *gin.Context) { ...@@ -82,9 +83,9 @@ func (h *ProxyHandler) GetAll(c *gin.Context) {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
out := make([]dto.ProxyWithAccountCount, 0, len(proxies)) out := make([]dto.AdminProxyWithAccountCount, 0, len(proxies))
for i := range proxies { for i := range proxies {
out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i])) out = append(out, *dto.ProxyWithAccountCountFromServiceAdmin(&proxies[i]))
} }
response.Success(c, out) response.Success(c, out)
return return
...@@ -96,9 +97,9 @@ func (h *ProxyHandler) GetAll(c *gin.Context) { ...@@ -96,9 +97,9 @@ func (h *ProxyHandler) GetAll(c *gin.Context) {
return return
} }
out := make([]dto.Proxy, 0, len(proxies)) out := make([]dto.AdminProxy, 0, len(proxies))
for i := range proxies { for i := range proxies {
out = append(out, *dto.ProxyFromService(&proxies[i])) out = append(out, *dto.ProxyFromServiceAdmin(&proxies[i]))
} }
response.Success(c, out) response.Success(c, out)
} }
...@@ -118,7 +119,7 @@ func (h *ProxyHandler) GetByID(c *gin.Context) { ...@@ -118,7 +119,7 @@ func (h *ProxyHandler) GetByID(c *gin.Context) {
return return
} }
response.Success(c, dto.ProxyFromService(proxy)) response.Success(c, dto.ProxyFromServiceAdmin(proxy))
} }
// Create handles creating a new proxy // Create handles creating a new proxy
...@@ -130,20 +131,20 @@ func (h *ProxyHandler) Create(c *gin.Context) { ...@@ -130,20 +131,20 @@ func (h *ProxyHandler) Create(c *gin.Context) {
return return
} }
proxy, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{ executeAdminIdempotentJSON(c, "admin.proxies.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
Name: strings.TrimSpace(req.Name), proxy, err := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{
Protocol: strings.TrimSpace(req.Protocol), Name: strings.TrimSpace(req.Name),
Host: strings.TrimSpace(req.Host), Protocol: strings.TrimSpace(req.Protocol),
Port: req.Port, Host: strings.TrimSpace(req.Host),
Username: strings.TrimSpace(req.Username), Port: req.Port,
Password: strings.TrimSpace(req.Password), Username: strings.TrimSpace(req.Username),
Password: strings.TrimSpace(req.Password),
})
if err != nil {
return nil, err
}
return dto.ProxyFromServiceAdmin(proxy), nil
}) })
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.ProxyFromService(proxy))
} }
// Update handles updating a proxy // Update handles updating a proxy
...@@ -175,7 +176,7 @@ func (h *ProxyHandler) Update(c *gin.Context) { ...@@ -175,7 +176,7 @@ func (h *ProxyHandler) Update(c *gin.Context) {
return return
} }
response.Success(c, dto.ProxyFromService(proxy)) response.Success(c, dto.ProxyFromServiceAdmin(proxy))
} }
// Delete handles deleting a proxy // Delete handles deleting a proxy
...@@ -236,6 +237,24 @@ func (h *ProxyHandler) Test(c *gin.Context) { ...@@ -236,6 +237,24 @@ func (h *ProxyHandler) Test(c *gin.Context) {
response.Success(c, result) response.Success(c, result)
} }
// CheckQuality handles checking proxy quality across common AI targets.
// POST /api/v1/admin/proxies/:id/quality-check
func (h *ProxyHandler) CheckQuality(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid proxy ID")
return
}
result, err := h.adminService.CheckProxyQuality(c.Request.Context(), proxyID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// GetStats handles getting proxy statistics // GetStats handles getting proxy statistics
// GET /api/v1/admin/proxies/:id/stats // GET /api/v1/admin/proxies/:id/stats
func (h *ProxyHandler) GetStats(c *gin.Context) { func (h *ProxyHandler) GetStats(c *gin.Context) {
......
...@@ -2,12 +2,15 @@ package admin ...@@ -2,12 +2,15 @@ package admin
import ( import (
"bytes" "bytes"
"context"
"encoding/csv" "encoding/csv"
"errors"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -16,13 +19,15 @@ import ( ...@@ -16,13 +19,15 @@ import (
// RedeemHandler handles admin redeem code management // RedeemHandler handles admin redeem code management
type RedeemHandler struct { type RedeemHandler struct {
adminService service.AdminService adminService service.AdminService
redeemService *service.RedeemService
} }
// NewRedeemHandler creates a new admin redeem handler // NewRedeemHandler creates a new admin redeem handler
func NewRedeemHandler(adminService service.AdminService) *RedeemHandler { func NewRedeemHandler(adminService service.AdminService, redeemService *service.RedeemService) *RedeemHandler {
return &RedeemHandler{ return &RedeemHandler{
adminService: adminService, adminService: adminService,
redeemService: redeemService,
} }
} }
...@@ -35,6 +40,15 @@ type GenerateRedeemCodesRequest struct { ...@@ -35,6 +40,15 @@ type GenerateRedeemCodesRequest struct {
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // 订阅类型使用,默认30天,最大100年 ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // 订阅类型使用,默认30天,最大100年
} }
// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
type CreateAndRedeemCodeRequest struct {
Code string `json:"code" binding:"required,min=3,max=128"`
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
Value float64 `json:"value" binding:"required,gt=0"`
UserID int64 `json:"user_id" binding:"required,gt=0"`
Notes string `json:"notes"`
}
// List handles listing all redeem codes with pagination // List handles listing all redeem codes with pagination
// GET /api/v1/admin/redeem-codes // GET /api/v1/admin/redeem-codes
func (h *RedeemHandler) List(c *gin.Context) { func (h *RedeemHandler) List(c *gin.Context) {
...@@ -88,23 +102,99 @@ func (h *RedeemHandler) Generate(c *gin.Context) { ...@@ -88,23 +102,99 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
return return
} }
codes, err := h.adminService.GenerateRedeemCodes(c.Request.Context(), &service.GenerateRedeemCodesInput{ executeAdminIdempotentJSON(c, "admin.redeem_codes.generate", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
Count: req.Count, codes, execErr := h.adminService.GenerateRedeemCodes(ctx, &service.GenerateRedeemCodesInput{
Type: req.Type, Count: req.Count,
Value: req.Value, Type: req.Type,
GroupID: req.GroupID, Value: req.Value,
ValidityDays: req.ValidityDays, GroupID: req.GroupID,
ValidityDays: req.ValidityDays,
})
if execErr != nil {
return nil, execErr
}
out := make([]dto.AdminRedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
}
return out, nil
}) })
if err != nil { }
response.ErrorFrom(c, err)
// CreateAndRedeem creates a fixed redeem code and redeems it for a target user in one step.
// POST /api/v1/admin/redeem-codes/create-and-redeem
func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
if h.redeemService == nil {
response.InternalError(c, "redeem service not configured")
return return
} }
out := make([]dto.AdminRedeemCode, 0, len(codes)) var req CreateAndRedeemCodeRequest
for i := range codes { if err := c.ShouldBindJSON(&req); err != nil {
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i])) response.BadRequest(c, "Invalid request: "+err.Error())
return
}
req.Code = strings.TrimSpace(req.Code)
executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
existing, err := h.redeemService.GetByCode(ctx, req.Code)
if err == nil {
return h.resolveCreateAndRedeemExisting(ctx, existing, req.UserID)
}
if !errors.Is(err, service.ErrRedeemCodeNotFound) {
return nil, err
}
createErr := h.redeemService.CreateCode(ctx, &service.RedeemCode{
Code: req.Code,
Type: req.Type,
Value: req.Value,
Status: service.StatusUnused,
Notes: req.Notes,
})
if createErr != nil {
// Unique code race: if code now exists, use idempotent semantics by used_by.
existingAfterCreateErr, getErr := h.redeemService.GetByCode(ctx, req.Code)
if getErr == nil {
return h.resolveCreateAndRedeemExisting(ctx, existingAfterCreateErr, req.UserID)
}
return nil, createErr
}
redeemed, redeemErr := h.redeemService.Redeem(ctx, req.UserID, req.Code)
if redeemErr != nil {
return nil, redeemErr
}
return gin.H{"redeem_code": dto.RedeemCodeFromServiceAdmin(redeemed)}, nil
})
}
func (h *RedeemHandler) resolveCreateAndRedeemExisting(ctx context.Context, existing *service.RedeemCode, userID int64) (any, error) {
if existing == nil {
return nil, infraerrors.Conflict("REDEEM_CODE_CONFLICT", "redeem code conflict")
} }
response.Success(c, out)
// If previous run created the code but crashed before redeem, redeem it now.
if existing.CanUse() {
redeemed, err := h.redeemService.Redeem(ctx, userID, existing.Code)
if err == nil {
return gin.H{"redeem_code": dto.RedeemCodeFromServiceAdmin(redeemed)}, nil
}
if !errors.Is(err, service.ErrRedeemCodeUsed) {
return nil, err
}
latest, getErr := h.redeemService.GetByCode(ctx, existing.Code)
if getErr == nil {
existing = latest
}
}
if existing.UsedBy != nil && *existing.UsedBy == userID {
return gin.H{"redeem_code": dto.RedeemCodeFromServiceAdmin(existing)}, nil
}
return nil, infraerrors.Conflict("REDEEM_CODE_CONFLICT", "redeem code already used by another user")
} }
// Delete handles deleting a redeem code // Delete handles deleting a redeem code
......
//go:build unit
package admin
import (
"testing"
"github.com/stretchr/testify/require"
)
// truncateSearchByRune 模拟 user_handler.go 中的 search 截断逻辑
func truncateSearchByRune(search string, maxRunes int) string {
if runes := []rune(search); len(runes) > maxRunes {
return string(runes[:maxRunes])
}
return search
}
func TestTruncateSearchByRune(t *testing.T) {
tests := []struct {
name string
input string
maxRunes int
wantLen int // 期望的 rune 长度
}{
{
name: "纯中文超长",
input: string(make([]rune, 150)),
maxRunes: 100,
wantLen: 100,
},
{
name: "纯 ASCII 超长",
input: string(make([]byte, 150)),
maxRunes: 100,
wantLen: 100,
},
{
name: "空字符串",
input: "",
maxRunes: 100,
wantLen: 0,
},
{
name: "恰好 100 个字符",
input: string(make([]rune, 100)),
maxRunes: 100,
wantLen: 100,
},
{
name: "不足 100 字符不截断",
input: "hello世界",
maxRunes: 100,
wantLen: 7,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := truncateSearchByRune(tc.input, tc.maxRunes)
require.Equal(t, tc.wantLen, len([]rune(result)))
})
}
}
func TestTruncateSearchByRune_PreservesMultibyte(t *testing.T) {
// 101 个中文字符,截断到 100 个后应该仍然是有效 UTF-8
input := ""
for i := 0; i < 101; i++ {
input += "中"
}
result := truncateSearchByRune(input, 100)
require.Equal(t, 100, len([]rune(result)))
// 验证截断结果是有效的 UTF-8(每个中文字符 3 字节)
require.Equal(t, 300, len(result))
}
func TestTruncateSearchByRune_MixedASCIIAndMultibyte(t *testing.T) {
// 50 个 ASCII + 51 个中文 = 101 个 rune
input := ""
for i := 0; i < 50; i++ {
input += "a"
}
for i := 0; i < 51; i++ {
input += "中"
}
result := truncateSearchByRune(input, 100)
runes := []rune(result)
require.Equal(t, 100, len(runes))
// 前 50 个应该是 'a',后 50 个应该是 '中'
require.Equal(t, 'a', runes[0])
require.Equal(t, 'a', runes[49])
require.Equal(t, '中', runes[50])
require.Equal(t, '中', runes[99])
}
package admin package admin
import ( import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"log" "log"
"net/http"
"regexp"
"strings" "strings"
"time" "time"
...@@ -14,21 +20,38 @@ import ( ...@@ -14,21 +20,38 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// semverPattern 预编译 semver 格式校验正则
var semverPattern = regexp.MustCompile(`^\d+\.\d+\.\d+$`)
// menuItemIDPattern validates custom menu item IDs: alphanumeric, hyphens, underscores only.
var menuItemIDPattern = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
// generateMenuItemID generates a short random hex ID for a custom menu item.
func generateMenuItemID() (string, error) {
b := make([]byte, 8)
if _, err := rand.Read(b); err != nil {
return "", fmt.Errorf("generate menu item ID: %w", err)
}
return hex.EncodeToString(b), nil
}
// SettingHandler 系统设置处理器 // SettingHandler 系统设置处理器
type SettingHandler struct { type SettingHandler struct {
settingService *service.SettingService settingService *service.SettingService
emailService *service.EmailService emailService *service.EmailService
turnstileService *service.TurnstileService turnstileService *service.TurnstileService
opsService *service.OpsService opsService *service.OpsService
soraS3Storage *service.SoraS3Storage
} }
// NewSettingHandler 创建系统设置处理器 // NewSettingHandler 创建系统设置处理器
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService) *SettingHandler { func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, soraS3Storage *service.SoraS3Storage) *SettingHandler {
return &SettingHandler{ return &SettingHandler{
settingService: settingService, settingService: settingService,
emailService: emailService, emailService: emailService,
turnstileService: turnstileService, turnstileService: turnstileService,
opsService: opsService, opsService: opsService,
soraS3Storage: soraS3Storage,
} }
} }
...@@ -43,10 +66,18 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { ...@@ -43,10 +66,18 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
// Check if ops monitoring is enabled (respects config.ops.enabled) // Check if ops monitoring is enabled (respects config.ops.enabled)
opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context()) opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context())
defaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(settings.DefaultSubscriptions))
for _, sub := range settings.DefaultSubscriptions {
defaultSubscriptions = append(defaultSubscriptions, dto.DefaultSubscriptionSetting{
GroupID: sub.GroupID,
ValidityDays: sub.ValidityDays,
})
}
response.Success(c, dto.SystemSettings{ response.Success(c, dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled, RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: settings.PromoCodeEnabled, PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled, PasswordResetEnabled: settings.PasswordResetEnabled,
InvitationCodeEnabled: settings.InvitationCodeEnabled, InvitationCodeEnabled: settings.InvitationCodeEnabled,
...@@ -76,8 +107,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { ...@@ -76,8 +107,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
HideCcsImportButton: settings.HideCcsImportButton, HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
SoraClientEnabled: settings.SoraClientEnabled,
CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems),
DefaultConcurrency: settings.DefaultConcurrency, DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance, DefaultBalance: settings.DefaultBalance,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: settings.EnableModelFallback, EnableModelFallback: settings.EnableModelFallback,
FallbackModelAnthropic: settings.FallbackModelAnthropic, FallbackModelAnthropic: settings.FallbackModelAnthropic,
FallbackModelOpenAI: settings.FallbackModelOpenAI, FallbackModelOpenAI: settings.FallbackModelOpenAI,
...@@ -89,18 +123,21 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { ...@@ -89,18 +123,21 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled, OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled,
OpsQueryModeDefault: settings.OpsQueryModeDefault, OpsQueryModeDefault: settings.OpsQueryModeDefault,
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds, OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
}) })
} }
// UpdateSettingsRequest 更新设置请求 // UpdateSettingsRequest 更新设置请求
type UpdateSettingsRequest struct { type UpdateSettingsRequest struct {
// 注册设置 // 注册设置
RegistrationEnabled bool `json:"registration_enabled"` RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"` EmailVerifyEnabled bool `json:"email_verify_enabled"`
PromoCodeEnabled bool `json:"promo_code_enabled"` RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PasswordResetEnabled bool `json:"password_reset_enabled"` PromoCodeEnabled bool `json:"promo_code_enabled"`
InvitationCodeEnabled bool `json:"invitation_code_enabled"` PasswordResetEnabled bool `json:"password_reset_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 InvitationCodeEnabled bool `json:"invitation_code_enabled"`
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
// 邮件服务设置 // 邮件服务设置
SMTPHost string `json:"smtp_host"` SMTPHost string `json:"smtp_host"`
...@@ -123,20 +160,23 @@ type UpdateSettingsRequest struct { ...@@ -123,20 +160,23 @@ type UpdateSettingsRequest struct {
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
// OEM设置 // OEM设置
SiteName string `json:"site_name"` SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"` SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"` SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"` APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"` ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"` DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"` HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"` HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"` PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
SoraClientEnabled bool `json:"sora_client_enabled"`
CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"`
// 默认配置 // 默认配置
DefaultConcurrency int `json:"default_concurrency"` DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"` DefaultBalance float64 `json:"default_balance"`
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
// Model fallback configuration // Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"` EnableModelFallback bool `json:"enable_model_fallback"`
...@@ -154,6 +194,11 @@ type UpdateSettingsRequest struct { ...@@ -154,6 +194,11 @@ type UpdateSettingsRequest struct {
OpsRealtimeMonitoringEnabled *bool `json:"ops_realtime_monitoring_enabled"` OpsRealtimeMonitoringEnabled *bool `json:"ops_realtime_monitoring_enabled"`
OpsQueryModeDefault *string `json:"ops_query_mode_default"` OpsQueryModeDefault *string `json:"ops_query_mode_default"`
OpsMetricsIntervalSeconds *int `json:"ops_metrics_interval_seconds"` OpsMetricsIntervalSeconds *int `json:"ops_metrics_interval_seconds"`
MinClaudeCodeVersion string `json:"min_claude_code_version"`
// 分组隔离
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
} }
// UpdateSettings 更新系统设置 // UpdateSettings 更新系统设置
...@@ -181,6 +226,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -181,6 +226,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
if req.SMTPPort <= 0 { if req.SMTPPort <= 0 {
req.SMTPPort = 587 req.SMTPPort = 587
} }
req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions)
// Turnstile 参数验证 // Turnstile 参数验证
if req.TurnstileEnabled { if req.TurnstileEnabled {
...@@ -276,6 +322,84 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -276,6 +322,84 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
} }
} }
// 自定义菜单项验证
const (
maxCustomMenuItems = 20
maxMenuItemLabelLen = 50
maxMenuItemURLLen = 2048
maxMenuItemIconSVGLen = 10 * 1024 // 10KB
maxMenuItemIDLen = 32
)
customMenuJSON := previousSettings.CustomMenuItems
if req.CustomMenuItems != nil {
items := *req.CustomMenuItems
if len(items) > maxCustomMenuItems {
response.BadRequest(c, "Too many custom menu items (max 20)")
return
}
for i, item := range items {
if strings.TrimSpace(item.Label) == "" {
response.BadRequest(c, "Custom menu item label is required")
return
}
if len(item.Label) > maxMenuItemLabelLen {
response.BadRequest(c, "Custom menu item label is too long (max 50 characters)")
return
}
if strings.TrimSpace(item.URL) == "" {
response.BadRequest(c, "Custom menu item URL is required")
return
}
if len(item.URL) > maxMenuItemURLLen {
response.BadRequest(c, "Custom menu item URL is too long (max 2048 characters)")
return
}
if err := config.ValidateAbsoluteHTTPURL(strings.TrimSpace(item.URL)); err != nil {
response.BadRequest(c, "Custom menu item URL must be an absolute http(s) URL")
return
}
if item.Visibility != "user" && item.Visibility != "admin" {
response.BadRequest(c, "Custom menu item visibility must be 'user' or 'admin'")
return
}
if len(item.IconSVG) > maxMenuItemIconSVGLen {
response.BadRequest(c, "Custom menu item icon SVG is too large (max 10KB)")
return
}
// Auto-generate ID if missing
if strings.TrimSpace(item.ID) == "" {
id, err := generateMenuItemID()
if err != nil {
response.Error(c, http.StatusInternalServerError, "Failed to generate menu item ID")
return
}
items[i].ID = id
} else if len(item.ID) > maxMenuItemIDLen {
response.BadRequest(c, "Custom menu item ID is too long (max 32 characters)")
return
} else if !menuItemIDPattern.MatchString(item.ID) {
response.BadRequest(c, "Custom menu item ID contains invalid characters (only a-z, A-Z, 0-9, - and _ are allowed)")
return
}
}
// ID uniqueness check
seen := make(map[string]struct{}, len(items))
for _, item := range items {
if _, exists := seen[item.ID]; exists {
response.BadRequest(c, "Duplicate custom menu item ID: "+item.ID)
return
}
seen[item.ID] = struct{}{}
}
menuBytes, err := json.Marshal(items)
if err != nil {
response.BadRequest(c, "Failed to serialize custom menu items")
return
}
customMenuJSON = string(menuBytes)
}
// Ops metrics collector interval validation (seconds). // Ops metrics collector interval validation (seconds).
if req.OpsMetricsIntervalSeconds != nil { if req.OpsMetricsIntervalSeconds != nil {
v := *req.OpsMetricsIntervalSeconds v := *req.OpsMetricsIntervalSeconds
...@@ -287,47 +411,68 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -287,47 +411,68 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
} }
req.OpsMetricsIntervalSeconds = &v req.OpsMetricsIntervalSeconds = &v
} }
defaultSubscriptions := make([]service.DefaultSubscriptionSetting, 0, len(req.DefaultSubscriptions))
for _, sub := range req.DefaultSubscriptions {
defaultSubscriptions = append(defaultSubscriptions, service.DefaultSubscriptionSetting{
GroupID: sub.GroupID,
ValidityDays: sub.ValidityDays,
})
}
// 验证最低版本号格式(空字符串=禁用,或合法 semver)
if req.MinClaudeCodeVersion != "" {
if !semverPattern.MatchString(req.MinClaudeCodeVersion) {
response.Error(c, http.StatusBadRequest, "min_claude_code_version must be empty or a valid semver (e.g. 2.1.63)")
return
}
}
settings := &service.SystemSettings{ settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled, RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled, EmailVerifyEnabled: req.EmailVerifyEnabled,
PromoCodeEnabled: req.PromoCodeEnabled, RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
PasswordResetEnabled: req.PasswordResetEnabled, PromoCodeEnabled: req.PromoCodeEnabled,
InvitationCodeEnabled: req.InvitationCodeEnabled, PasswordResetEnabled: req.PasswordResetEnabled,
TotpEnabled: req.TotpEnabled, InvitationCodeEnabled: req.InvitationCodeEnabled,
SMTPHost: req.SMTPHost, TotpEnabled: req.TotpEnabled,
SMTPPort: req.SMTPPort, SMTPHost: req.SMTPHost,
SMTPUsername: req.SMTPUsername, SMTPPort: req.SMTPPort,
SMTPPassword: req.SMTPPassword, SMTPUsername: req.SMTPUsername,
SMTPFrom: req.SMTPFrom, SMTPPassword: req.SMTPPassword,
SMTPFromName: req.SMTPFromName, SMTPFrom: req.SMTPFrom,
SMTPUseTLS: req.SMTPUseTLS, SMTPFromName: req.SMTPFromName,
TurnstileEnabled: req.TurnstileEnabled, SMTPUseTLS: req.SMTPUseTLS,
TurnstileSiteKey: req.TurnstileSiteKey, TurnstileEnabled: req.TurnstileEnabled,
TurnstileSecretKey: req.TurnstileSecretKey, TurnstileSiteKey: req.TurnstileSiteKey,
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled, TurnstileSecretKey: req.TurnstileSecretKey,
LinuxDoConnectClientID: req.LinuxDoConnectClientID, LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret, LinuxDoConnectClientID: req.LinuxDoConnectClientID,
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL, LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
SiteName: req.SiteName, LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
SiteLogo: req.SiteLogo, SiteName: req.SiteName,
SiteSubtitle: req.SiteSubtitle, SiteLogo: req.SiteLogo,
APIBaseURL: req.APIBaseURL, SiteSubtitle: req.SiteSubtitle,
ContactInfo: req.ContactInfo, APIBaseURL: req.APIBaseURL,
DocURL: req.DocURL, ContactInfo: req.ContactInfo,
HomeContent: req.HomeContent, DocURL: req.DocURL,
HideCcsImportButton: req.HideCcsImportButton, HomeContent: req.HomeContent,
PurchaseSubscriptionEnabled: purchaseEnabled, HideCcsImportButton: req.HideCcsImportButton,
PurchaseSubscriptionURL: purchaseURL, PurchaseSubscriptionEnabled: purchaseEnabled,
DefaultConcurrency: req.DefaultConcurrency, PurchaseSubscriptionURL: purchaseURL,
DefaultBalance: req.DefaultBalance, SoraClientEnabled: req.SoraClientEnabled,
EnableModelFallback: req.EnableModelFallback, CustomMenuItems: customMenuJSON,
FallbackModelAnthropic: req.FallbackModelAnthropic, DefaultConcurrency: req.DefaultConcurrency,
FallbackModelOpenAI: req.FallbackModelOpenAI, DefaultBalance: req.DefaultBalance,
FallbackModelGemini: req.FallbackModelGemini, DefaultSubscriptions: defaultSubscriptions,
FallbackModelAntigravity: req.FallbackModelAntigravity, EnableModelFallback: req.EnableModelFallback,
EnableIdentityPatch: req.EnableIdentityPatch, FallbackModelAnthropic: req.FallbackModelAnthropic,
IdentityPatchPrompt: req.IdentityPatchPrompt, FallbackModelOpenAI: req.FallbackModelOpenAI,
FallbackModelGemini: req.FallbackModelGemini,
FallbackModelAntigravity: req.FallbackModelAntigravity,
EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt,
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
OpsMonitoringEnabled: func() bool { OpsMonitoringEnabled: func() bool {
if req.OpsMonitoringEnabled != nil { if req.OpsMonitoringEnabled != nil {
return *req.OpsMonitoringEnabled return *req.OpsMonitoringEnabled
...@@ -367,10 +512,18 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -367,10 +512,18 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
updatedDefaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(updatedSettings.DefaultSubscriptions))
for _, sub := range updatedSettings.DefaultSubscriptions {
updatedDefaultSubscriptions = append(updatedDefaultSubscriptions, dto.DefaultSubscriptionSetting{
GroupID: sub.GroupID,
ValidityDays: sub.ValidityDays,
})
}
response.Success(c, dto.SystemSettings{ response.Success(c, dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled, RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled, EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: updatedSettings.PromoCodeEnabled, PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
PasswordResetEnabled: updatedSettings.PasswordResetEnabled, PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled, InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
...@@ -400,8 +553,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -400,8 +553,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
HideCcsImportButton: updatedSettings.HideCcsImportButton, HideCcsImportButton: updatedSettings.HideCcsImportButton,
PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled, PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL, PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
SoraClientEnabled: updatedSettings.SoraClientEnabled,
CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems),
DefaultConcurrency: updatedSettings.DefaultConcurrency, DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance, DefaultBalance: updatedSettings.DefaultBalance,
DefaultSubscriptions: updatedDefaultSubscriptions,
EnableModelFallback: updatedSettings.EnableModelFallback, EnableModelFallback: updatedSettings.EnableModelFallback,
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic, FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI, FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
...@@ -413,6 +569,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ...@@ -413,6 +569,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled, OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled,
OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault, OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault,
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds, OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
}) })
} }
...@@ -444,6 +602,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, ...@@ -444,6 +602,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.EmailVerifyEnabled != after.EmailVerifyEnabled { if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
changed = append(changed, "email_verify_enabled") changed = append(changed, "email_verify_enabled")
} }
if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) {
changed = append(changed, "registration_email_suffix_whitelist")
}
if before.PasswordResetEnabled != after.PasswordResetEnabled { if before.PasswordResetEnabled != after.PasswordResetEnabled {
changed = append(changed, "password_reset_enabled") changed = append(changed, "password_reset_enabled")
} }
...@@ -522,6 +683,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, ...@@ -522,6 +683,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.DefaultBalance != after.DefaultBalance { if before.DefaultBalance != after.DefaultBalance {
changed = append(changed, "default_balance") changed = append(changed, "default_balance")
} }
if !equalDefaultSubscriptions(before.DefaultSubscriptions, after.DefaultSubscriptions) {
changed = append(changed, "default_subscriptions")
}
if before.EnableModelFallback != after.EnableModelFallback { if before.EnableModelFallback != after.EnableModelFallback {
changed = append(changed, "enable_model_fallback") changed = append(changed, "enable_model_fallback")
} }
...@@ -555,9 +719,65 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, ...@@ -555,9 +719,65 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.OpsMetricsIntervalSeconds != after.OpsMetricsIntervalSeconds { if before.OpsMetricsIntervalSeconds != after.OpsMetricsIntervalSeconds {
changed = append(changed, "ops_metrics_interval_seconds") changed = append(changed, "ops_metrics_interval_seconds")
} }
if before.MinClaudeCodeVersion != after.MinClaudeCodeVersion {
changed = append(changed, "min_claude_code_version")
}
if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling {
changed = append(changed, "allow_ungrouped_key_scheduling")
}
if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled {
changed = append(changed, "purchase_subscription_enabled")
}
if before.PurchaseSubscriptionURL != after.PurchaseSubscriptionURL {
changed = append(changed, "purchase_subscription_url")
}
if before.CustomMenuItems != after.CustomMenuItems {
changed = append(changed, "custom_menu_items")
}
return changed return changed
} }
func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto.DefaultSubscriptionSetting {
if len(input) == 0 {
return nil
}
normalized := make([]dto.DefaultSubscriptionSetting, 0, len(input))
for _, item := range input {
if item.GroupID <= 0 || item.ValidityDays <= 0 {
continue
}
if item.ValidityDays > service.MaxValidityDays {
item.ValidityDays = service.MaxValidityDays
}
normalized = append(normalized, item)
}
return normalized
}
func equalStringSlice(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i].GroupID != b[i].GroupID || a[i].ValidityDays != b[i].ValidityDays {
return false
}
}
return true
}
// TestSMTPRequest 测试SMTP连接请求 // TestSMTPRequest 测试SMTP连接请求
type TestSMTPRequest struct { type TestSMTPRequest struct {
SMTPHost string `json:"smtp_host" binding:"required"` SMTPHost string `json:"smtp_host" binding:"required"`
...@@ -750,6 +970,384 @@ func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) { ...@@ -750,6 +970,384 @@ func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) {
}) })
} }
func toSoraS3SettingsDTO(settings *service.SoraS3Settings) dto.SoraS3Settings {
if settings == nil {
return dto.SoraS3Settings{}
}
return dto.SoraS3Settings{
Enabled: settings.Enabled,
Endpoint: settings.Endpoint,
Region: settings.Region,
Bucket: settings.Bucket,
AccessKeyID: settings.AccessKeyID,
SecretAccessKeyConfigured: settings.SecretAccessKeyConfigured,
Prefix: settings.Prefix,
ForcePathStyle: settings.ForcePathStyle,
CDNURL: settings.CDNURL,
DefaultStorageQuotaBytes: settings.DefaultStorageQuotaBytes,
}
}
func toSoraS3ProfileDTO(profile service.SoraS3Profile) dto.SoraS3Profile {
return dto.SoraS3Profile{
ProfileID: profile.ProfileID,
Name: profile.Name,
IsActive: profile.IsActive,
Enabled: profile.Enabled,
Endpoint: profile.Endpoint,
Region: profile.Region,
Bucket: profile.Bucket,
AccessKeyID: profile.AccessKeyID,
SecretAccessKeyConfigured: profile.SecretAccessKeyConfigured,
Prefix: profile.Prefix,
ForcePathStyle: profile.ForcePathStyle,
CDNURL: profile.CDNURL,
DefaultStorageQuotaBytes: profile.DefaultStorageQuotaBytes,
UpdatedAt: profile.UpdatedAt,
}
}
func validateSoraS3RequiredWhenEnabled(enabled bool, endpoint, bucket, accessKeyID, secretAccessKey string, hasStoredSecret bool) error {
if !enabled {
return nil
}
if strings.TrimSpace(endpoint) == "" {
return fmt.Errorf("S3 Endpoint is required when enabled")
}
if strings.TrimSpace(bucket) == "" {
return fmt.Errorf("S3 Bucket is required when enabled")
}
if strings.TrimSpace(accessKeyID) == "" {
return fmt.Errorf("S3 Access Key ID is required when enabled")
}
if strings.TrimSpace(secretAccessKey) != "" || hasStoredSecret {
return nil
}
return fmt.Errorf("S3 Secret Access Key is required when enabled")
}
func findSoraS3ProfileByID(items []service.SoraS3Profile, profileID string) *service.SoraS3Profile {
for idx := range items {
if items[idx].ProfileID == profileID {
return &items[idx]
}
}
return nil
}
// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置接口)
// GET /api/v1/admin/settings/sora-s3
func (h *SettingHandler) GetSoraS3Settings(c *gin.Context) {
settings, err := h.settingService.GetSoraS3Settings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, toSoraS3SettingsDTO(settings))
}
// ListSoraS3Profiles 获取 Sora S3 多配置
// GET /api/v1/admin/settings/sora-s3/profiles
func (h *SettingHandler) ListSoraS3Profiles(c *gin.Context) {
result, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
items := make([]dto.SoraS3Profile, 0, len(result.Items))
for idx := range result.Items {
items = append(items, toSoraS3ProfileDTO(result.Items[idx]))
}
response.Success(c, dto.ListSoraS3ProfilesResponse{
ActiveProfileID: result.ActiveProfileID,
Items: items,
})
}
// UpdateSoraS3SettingsRequest 更新/测试 Sora S3 配置请求(兼容旧接口)
type UpdateSoraS3SettingsRequest struct {
ProfileID string `json:"profile_id"`
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
CDNURL string `json:"cdn_url"`
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
}
type CreateSoraS3ProfileRequest struct {
ProfileID string `json:"profile_id"`
Name string `json:"name"`
SetActive bool `json:"set_active"`
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
CDNURL string `json:"cdn_url"`
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
}
type UpdateSoraS3ProfileRequest struct {
Name string `json:"name"`
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
CDNURL string `json:"cdn_url"`
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
}
// CreateSoraS3Profile 创建 Sora S3 配置
// POST /api/v1/admin/settings/sora-s3/profiles
func (h *SettingHandler) CreateSoraS3Profile(c *gin.Context) {
var req CreateSoraS3ProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if req.DefaultStorageQuotaBytes < 0 {
req.DefaultStorageQuotaBytes = 0
}
if strings.TrimSpace(req.Name) == "" {
response.BadRequest(c, "Name is required")
return
}
if strings.TrimSpace(req.ProfileID) == "" {
response.BadRequest(c, "Profile ID is required")
return
}
if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, false); err != nil {
response.BadRequest(c, err.Error())
return
}
created, err := h.settingService.CreateSoraS3Profile(c.Request.Context(), &service.SoraS3Profile{
ProfileID: req.ProfileID,
Name: req.Name,
Enabled: req.Enabled,
Endpoint: req.Endpoint,
Region: req.Region,
Bucket: req.Bucket,
AccessKeyID: req.AccessKeyID,
SecretAccessKey: req.SecretAccessKey,
Prefix: req.Prefix,
ForcePathStyle: req.ForcePathStyle,
CDNURL: req.CDNURL,
DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
}, req.SetActive)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, toSoraS3ProfileDTO(*created))
}
// UpdateSoraS3Profile 更新 Sora S3 配置
// PUT /api/v1/admin/settings/sora-s3/profiles/:profile_id
func (h *SettingHandler) UpdateSoraS3Profile(c *gin.Context) {
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Profile ID is required")
return
}
var req UpdateSoraS3ProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if req.DefaultStorageQuotaBytes < 0 {
req.DefaultStorageQuotaBytes = 0
}
if strings.TrimSpace(req.Name) == "" {
response.BadRequest(c, "Name is required")
return
}
existingList, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
existing := findSoraS3ProfileByID(existingList.Items, profileID)
if existing == nil {
response.ErrorFrom(c, service.ErrSoraS3ProfileNotFound)
return
}
if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil {
response.BadRequest(c, err.Error())
return
}
updated, updateErr := h.settingService.UpdateSoraS3Profile(c.Request.Context(), profileID, &service.SoraS3Profile{
Name: req.Name,
Enabled: req.Enabled,
Endpoint: req.Endpoint,
Region: req.Region,
Bucket: req.Bucket,
AccessKeyID: req.AccessKeyID,
SecretAccessKey: req.SecretAccessKey,
Prefix: req.Prefix,
ForcePathStyle: req.ForcePathStyle,
CDNURL: req.CDNURL,
DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
})
if updateErr != nil {
response.ErrorFrom(c, updateErr)
return
}
response.Success(c, toSoraS3ProfileDTO(*updated))
}
// DeleteSoraS3Profile 删除 Sora S3 配置
// DELETE /api/v1/admin/settings/sora-s3/profiles/:profile_id
func (h *SettingHandler) DeleteSoraS3Profile(c *gin.Context) {
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Profile ID is required")
return
}
if err := h.settingService.DeleteSoraS3Profile(c.Request.Context(), profileID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"deleted": true})
}
// SetActiveSoraS3Profile 切换激活 Sora S3 配置
// POST /api/v1/admin/settings/sora-s3/profiles/:profile_id/activate
func (h *SettingHandler) SetActiveSoraS3Profile(c *gin.Context) {
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Profile ID is required")
return
}
active, err := h.settingService.SetActiveSoraS3Profile(c.Request.Context(), profileID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, toSoraS3ProfileDTO(*active))
}
// UpdateSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置接口)
// PUT /api/v1/admin/settings/sora-s3
func (h *SettingHandler) UpdateSoraS3Settings(c *gin.Context) {
var req UpdateSoraS3SettingsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
existing, err := h.settingService.GetSoraS3Settings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
if req.DefaultStorageQuotaBytes < 0 {
req.DefaultStorageQuotaBytes = 0
}
if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil {
response.BadRequest(c, err.Error())
return
}
settings := &service.SoraS3Settings{
Enabled: req.Enabled,
Endpoint: req.Endpoint,
Region: req.Region,
Bucket: req.Bucket,
AccessKeyID: req.AccessKeyID,
SecretAccessKey: req.SecretAccessKey,
Prefix: req.Prefix,
ForcePathStyle: req.ForcePathStyle,
CDNURL: req.CDNURL,
DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
}
if err := h.settingService.SetSoraS3Settings(c.Request.Context(), settings); err != nil {
response.ErrorFrom(c, err)
return
}
updatedSettings, err := h.settingService.GetSoraS3Settings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, toSoraS3SettingsDTO(updatedSettings))
}
// TestSoraS3Connection 测试 Sora S3 连接(HeadBucket)
// POST /api/v1/admin/settings/sora-s3/test
func (h *SettingHandler) TestSoraS3Connection(c *gin.Context) {
if h.soraS3Storage == nil {
response.Error(c, 500, "S3 存储服务未初始化")
return
}
var req UpdateSoraS3SettingsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !req.Enabled {
response.BadRequest(c, "S3 未启用,无法测试连接")
return
}
if req.SecretAccessKey == "" {
if req.ProfileID != "" {
profiles, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
if err == nil {
profile := findSoraS3ProfileByID(profiles.Items, req.ProfileID)
if profile != nil {
req.SecretAccessKey = profile.SecretAccessKey
}
}
}
if req.SecretAccessKey == "" {
existing, err := h.settingService.GetSoraS3Settings(c.Request.Context())
if err == nil {
req.SecretAccessKey = existing.SecretAccessKey
}
}
}
testCfg := &service.SoraS3Settings{
Enabled: true,
Endpoint: req.Endpoint,
Region: req.Region,
Bucket: req.Bucket,
AccessKeyID: req.AccessKeyID,
SecretAccessKey: req.SecretAccessKey,
Prefix: req.Prefix,
ForcePathStyle: req.ForcePathStyle,
CDNURL: req.CDNURL,
}
if err := h.soraS3Storage.TestConnectionWithSettings(c.Request.Context(), testCfg); err != nil {
response.Error(c, 400, "S3 连接测试失败: "+err.Error())
return
}
response.Success(c, gin.H{"message": "S3 连接成功"})
}
// UpdateStreamTimeoutSettingsRequest 更新流超时配置请求 // UpdateStreamTimeoutSettingsRequest 更新流超时配置请求
type UpdateStreamTimeoutSettingsRequest struct { type UpdateStreamTimeoutSettingsRequest struct {
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
......
package admin
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"strings"
"sync"
"time"
)
type snapshotCacheEntry struct {
ETag string
Payload any
ExpiresAt time.Time
}
type snapshotCache struct {
mu sync.RWMutex
ttl time.Duration
items map[string]snapshotCacheEntry
}
func newSnapshotCache(ttl time.Duration) *snapshotCache {
if ttl <= 0 {
ttl = 30 * time.Second
}
return &snapshotCache{
ttl: ttl,
items: make(map[string]snapshotCacheEntry),
}
}
func (c *snapshotCache) Get(key string) (snapshotCacheEntry, bool) {
if c == nil || key == "" {
return snapshotCacheEntry{}, false
}
now := time.Now()
c.mu.RLock()
entry, ok := c.items[key]
c.mu.RUnlock()
if !ok {
return snapshotCacheEntry{}, false
}
if now.After(entry.ExpiresAt) {
c.mu.Lock()
delete(c.items, key)
c.mu.Unlock()
return snapshotCacheEntry{}, false
}
return entry, true
}
func (c *snapshotCache) Set(key string, payload any) snapshotCacheEntry {
if c == nil {
return snapshotCacheEntry{}
}
entry := snapshotCacheEntry{
ETag: buildETagFromAny(payload),
Payload: payload,
ExpiresAt: time.Now().Add(c.ttl),
}
if key == "" {
return entry
}
c.mu.Lock()
c.items[key] = entry
c.mu.Unlock()
return entry
}
func buildETagFromAny(payload any) string {
raw, err := json.Marshal(payload)
if err != nil {
return ""
}
sum := sha256.Sum256(raw)
return "\"" + hex.EncodeToString(sum[:]) + "\""
}
func parseBoolQueryWithDefault(raw string, def bool) bool {
value := strings.TrimSpace(strings.ToLower(raw))
if value == "" {
return def
}
switch value {
case "1", "true", "yes", "on":
return true
case "0", "false", "no", "off":
return false
default:
return def
}
}
//go:build unit
package admin
import (
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestSnapshotCache_SetAndGet(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
entry := c.Set("key1", map[string]string{"hello": "world"})
require.NotEmpty(t, entry.ETag)
require.NotNil(t, entry.Payload)
got, ok := c.Get("key1")
require.True(t, ok)
require.Equal(t, entry.ETag, got.ETag)
}
func TestSnapshotCache_Expiration(t *testing.T) {
c := newSnapshotCache(1 * time.Millisecond)
c.Set("key1", "value")
time.Sleep(5 * time.Millisecond)
_, ok := c.Get("key1")
require.False(t, ok, "expired entry should not be returned")
}
func TestSnapshotCache_GetEmptyKey(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
_, ok := c.Get("")
require.False(t, ok)
}
func TestSnapshotCache_GetMiss(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
_, ok := c.Get("nonexistent")
require.False(t, ok)
}
func TestSnapshotCache_NilReceiver(t *testing.T) {
var c *snapshotCache
_, ok := c.Get("key")
require.False(t, ok)
entry := c.Set("key", "value")
require.Empty(t, entry.ETag)
}
func TestSnapshotCache_SetEmptyKey(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
// Set with empty key should return entry but not store it
entry := c.Set("", "value")
require.NotEmpty(t, entry.ETag)
_, ok := c.Get("")
require.False(t, ok)
}
func TestSnapshotCache_DefaultTTL(t *testing.T) {
c := newSnapshotCache(0)
require.Equal(t, 30*time.Second, c.ttl)
c2 := newSnapshotCache(-1 * time.Second)
require.Equal(t, 30*time.Second, c2.ttl)
}
func TestSnapshotCache_ETagDeterministic(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
payload := map[string]int{"a": 1, "b": 2}
entry1 := c.Set("k1", payload)
entry2 := c.Set("k2", payload)
require.Equal(t, entry1.ETag, entry2.ETag, "same payload should produce same ETag")
}
func TestSnapshotCache_ETagFormat(t *testing.T) {
c := newSnapshotCache(5 * time.Second)
entry := c.Set("k", "test")
// ETag should be quoted hex string: "abcdef..."
require.True(t, len(entry.ETag) > 2)
require.Equal(t, byte('"'), entry.ETag[0])
require.Equal(t, byte('"'), entry.ETag[len(entry.ETag)-1])
}
func TestBuildETagFromAny_UnmarshalablePayload(t *testing.T) {
// channels are not JSON-serializable
etag := buildETagFromAny(make(chan int))
require.Empty(t, etag)
}
func TestParseBoolQueryWithDefault(t *testing.T) {
tests := []struct {
name string
raw string
def bool
want bool
}{
{"empty returns default true", "", true, true},
{"empty returns default false", "", false, false},
{"1", "1", false, true},
{"true", "true", false, true},
{"TRUE", "TRUE", false, true},
{"yes", "yes", false, true},
{"on", "on", false, true},
{"0", "0", true, false},
{"false", "false", true, false},
{"FALSE", "FALSE", true, false},
{"no", "no", true, false},
{"off", "off", true, false},
{"whitespace trimmed", " true ", false, true},
{"unknown returns default true", "maybe", true, true},
{"unknown returns default false", "maybe", false, false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := parseBoolQueryWithDefault(tc.raw, tc.def)
require.Equal(t, tc.want, got)
})
}
}
package admin package admin
import ( import (
"context"
"strconv" "strconv"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
...@@ -199,13 +200,20 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) { ...@@ -199,13 +200,20 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
return return
} }
subscription, err := h.subscriptionService.ExtendSubscription(c.Request.Context(), subscriptionID, req.Days) idempotencyPayload := struct {
if err != nil { SubscriptionID int64 `json:"subscription_id"`
response.ErrorFrom(c, err) Body AdjustSubscriptionRequest `json:"body"`
return }{
SubscriptionID: subscriptionID,
Body: req,
} }
executeAdminIdempotentJSON(c, "admin.subscriptions.extend", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription)) subscription, execErr := h.subscriptionService.ExtendSubscription(ctx, subscriptionID, req.Days)
if execErr != nil {
return nil, execErr
}
return dto.UserSubscriptionFromServiceAdmin(subscription), nil
})
} }
// Revoke handles revoking a subscription // Revoke handles revoking a subscription
......
package admin package admin
import ( import (
"context"
"net/http" "net/http"
"strconv"
"strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/sysutil" "github.com/Wei-Shaw/sub2api/internal/pkg/sysutil"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
...@@ -14,12 +18,14 @@ import ( ...@@ -14,12 +18,14 @@ import (
// SystemHandler handles system-related operations // SystemHandler handles system-related operations
type SystemHandler struct { type SystemHandler struct {
updateSvc *service.UpdateService updateSvc *service.UpdateService
lockSvc *service.SystemOperationLockService
} }
// NewSystemHandler creates a new SystemHandler // NewSystemHandler creates a new SystemHandler
func NewSystemHandler(updateSvc *service.UpdateService) *SystemHandler { func NewSystemHandler(updateSvc *service.UpdateService, lockSvc *service.SystemOperationLockService) *SystemHandler {
return &SystemHandler{ return &SystemHandler{
updateSvc: updateSvc, updateSvc: updateSvc,
lockSvc: lockSvc,
} }
} }
...@@ -47,41 +53,125 @@ func (h *SystemHandler) CheckUpdates(c *gin.Context) { ...@@ -47,41 +53,125 @@ func (h *SystemHandler) CheckUpdates(c *gin.Context) {
// PerformUpdate downloads and applies the update // PerformUpdate downloads and applies the update
// POST /api/v1/admin/system/update // POST /api/v1/admin/system/update
func (h *SystemHandler) PerformUpdate(c *gin.Context) { func (h *SystemHandler) PerformUpdate(c *gin.Context) {
if err := h.updateSvc.PerformUpdate(c.Request.Context()); err != nil { operationID := buildSystemOperationID(c, "update")
response.Error(c, http.StatusInternalServerError, err.Error()) payload := gin.H{"operation_id": operationID}
return executeAdminIdempotentJSON(c, "admin.system.update", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) {
} lock, release, err := h.acquireSystemLock(ctx, operationID)
response.Success(c, gin.H{ if err != nil {
"message": "Update completed. Please restart the service.", return nil, err
"need_restart": true, }
var releaseReason string
succeeded := false
defer func() {
release(releaseReason, succeeded)
}()
if err := h.updateSvc.PerformUpdate(ctx); err != nil {
releaseReason = "SYSTEM_UPDATE_FAILED"
return nil, err
}
succeeded = true
return gin.H{
"message": "Update completed. Please restart the service.",
"need_restart": true,
"operation_id": lock.OperationID(),
}, nil
}) })
} }
// Rollback restores the previous version // Rollback restores the previous version
// POST /api/v1/admin/system/rollback // POST /api/v1/admin/system/rollback
func (h *SystemHandler) Rollback(c *gin.Context) { func (h *SystemHandler) Rollback(c *gin.Context) {
if err := h.updateSvc.Rollback(); err != nil { operationID := buildSystemOperationID(c, "rollback")
response.Error(c, http.StatusInternalServerError, err.Error()) payload := gin.H{"operation_id": operationID}
return executeAdminIdempotentJSON(c, "admin.system.rollback", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) {
} lock, release, err := h.acquireSystemLock(ctx, operationID)
response.Success(c, gin.H{ if err != nil {
"message": "Rollback completed. Please restart the service.", return nil, err
"need_restart": true, }
var releaseReason string
succeeded := false
defer func() {
release(releaseReason, succeeded)
}()
if err := h.updateSvc.Rollback(); err != nil {
releaseReason = "SYSTEM_ROLLBACK_FAILED"
return nil, err
}
succeeded = true
return gin.H{
"message": "Rollback completed. Please restart the service.",
"need_restart": true,
"operation_id": lock.OperationID(),
}, nil
}) })
} }
// RestartService restarts the systemd service // RestartService restarts the systemd service
// POST /api/v1/admin/system/restart // POST /api/v1/admin/system/restart
func (h *SystemHandler) RestartService(c *gin.Context) { func (h *SystemHandler) RestartService(c *gin.Context) {
// Schedule service restart in background after sending response operationID := buildSystemOperationID(c, "restart")
// This ensures the client receives the success response before the service restarts payload := gin.H{"operation_id": operationID}
go func() { executeAdminIdempotentJSON(c, "admin.system.restart", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) {
// Wait a moment to ensure the response is sent lock, release, err := h.acquireSystemLock(ctx, operationID)
time.Sleep(500 * time.Millisecond) if err != nil {
sysutil.RestartServiceAsync() return nil, err
}() }
succeeded := false
defer func() {
release("", succeeded)
}()
response.Success(c, gin.H{ // Schedule service restart in background after sending response
"message": "Service restart initiated", // This ensures the client receives the success response before the service restarts
go func() {
// Wait a moment to ensure the response is sent
time.Sleep(500 * time.Millisecond)
sysutil.RestartServiceAsync()
}()
succeeded = true
return gin.H{
"message": "Service restart initiated",
"operation_id": lock.OperationID(),
}, nil
}) })
} }
func (h *SystemHandler) acquireSystemLock(
ctx context.Context,
operationID string,
) (*service.SystemOperationLock, func(string, bool), error) {
if h.lockSvc == nil {
return nil, nil, service.ErrIdempotencyStoreUnavail
}
lock, err := h.lockSvc.Acquire(ctx, operationID)
if err != nil {
return nil, nil, err
}
release := func(reason string, succeeded bool) {
releaseCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_ = h.lockSvc.Release(releaseCtx, lock, succeeded, reason)
}
return lock, release, nil
}
func buildSystemOperationID(c *gin.Context, operation string) string {
key := strings.TrimSpace(c.GetHeader("Idempotency-Key"))
if key == "" {
return "sysop-" + operation + "-" + strconv.FormatInt(time.Now().UnixNano(), 36)
}
actorScope := "admin:0"
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
actorScope = "admin:" + strconv.FormatInt(subject.UserID, 10)
}
seed := operation + "|" + actorScope + "|" + c.FullPath() + "|" + key
hash := service.HashIdempotencyKey(seed)
if len(hash) > 24 {
hash = hash[:24]
}
return "sysop-" + hash
}
...@@ -225,6 +225,92 @@ func TestUsageHandlerCreateCleanupTaskInvalidEndDate(t *testing.T) { ...@@ -225,6 +225,92 @@ func TestUsageHandlerCreateCleanupTaskInvalidEndDate(t *testing.T) {
require.Equal(t, http.StatusBadRequest, recorder.Code) require.Equal(t, http.StatusBadRequest, recorder.Code)
} }
func TestUsageHandlerCreateCleanupTaskInvalidRequestType(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 88)
payload := map[string]any{
"start_date": "2024-01-01",
"end_date": "2024-01-02",
"timezone": "UTC",
"request_type": "invalid",
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusBadRequest, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskRequestTypePriority(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 99)
payload := map[string]any{
"start_date": "2024-01-01",
"end_date": "2024-01-02",
"timezone": "UTC",
"request_type": "ws_v2",
"stream": false,
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
repo.mu.Lock()
defer repo.mu.Unlock()
require.Len(t, repo.created, 1)
created := repo.created[0]
require.NotNil(t, created.Filters.RequestType)
require.Equal(t, int16(service.RequestTypeWSV2), *created.Filters.RequestType)
require.Nil(t, created.Filters.Stream)
}
func TestUsageHandlerCreateCleanupTaskWithLegacyStream(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 99)
payload := map[string]any{
"start_date": "2024-01-01",
"end_date": "2024-01-02",
"timezone": "UTC",
"stream": true,
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
repo.mu.Lock()
defer repo.mu.Unlock()
require.Len(t, repo.created, 1)
created := repo.created[0]
require.Nil(t, created.Filters.RequestType)
require.NotNil(t, created.Filters.Stream)
require.True(t, *created.Filters.Stream)
}
func TestUsageHandlerCreateCleanupTaskSuccess(t *testing.T) { func TestUsageHandlerCreateCleanupTaskSuccess(t *testing.T) {
repo := &cleanupRepoStub{} repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment