Commit b9b4db3d authored by song's avatar song
Browse files

Merge upstream/main

parents 5a6f60a9 dae0d532
...@@ -16,7 +16,9 @@ type AdminHandlers struct { ...@@ -16,7 +16,9 @@ type AdminHandlers struct {
AntigravityOAuth *admin.AntigravityOAuthHandler AntigravityOAuth *admin.AntigravityOAuthHandler
Proxy *admin.ProxyHandler Proxy *admin.ProxyHandler
Redeem *admin.RedeemHandler Redeem *admin.RedeemHandler
Promo *admin.PromoHandler
Setting *admin.SettingHandler Setting *admin.SettingHandler
Ops *admin.OpsHandler
System *admin.SystemHandler System *admin.SystemHandler
Subscription *admin.SubscriptionHandler Subscription *admin.SubscriptionHandler
Usage *admin.UsageHandler Usage *admin.UsageHandler
......
...@@ -8,9 +8,11 @@ import ( ...@@ -8,9 +8,11 @@ import (
"io" "io"
"log" "log"
"net/http" "net/http"
"strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -81,6 +83,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -81,6 +83,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return return
} }
setOpsRequestContext(c, "", false, body)
// Parse request body to map for potential modification // Parse request body to map for potential modification
var reqBody map[string]any var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err != nil { if err := json.Unmarshal(body, &reqBody); err != nil {
...@@ -98,15 +102,41 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -98,15 +102,41 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return return
} }
// For non-Codex CLI requests, set default instructions
userAgent := c.GetHeader("User-Agent") userAgent := c.GetHeader("User-Agent")
if !openai.IsCodexCLIRequest(userAgent) { if !openai.IsCodexCLIRequest(userAgent) {
reqBody["instructions"] = openai.DefaultInstructions existingInstructions, _ := reqBody["instructions"].(string)
// Re-serialize body if strings.TrimSpace(existingInstructions) == "" {
body, err = json.Marshal(reqBody) if instructions := strings.TrimSpace(service.GetOpenCodeInstructions()); instructions != "" {
if err != nil { reqBody["instructions"] = instructions
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request") // Re-serialize body
return body, err = json.Marshal(reqBody)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
return
}
}
}
}
setOpsRequestContext(c, reqModel, reqStream, body)
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
// 要求 previous_response_id,或 input 内存在带 call_id 的 tool_call/function_call,
// 或带 id 且与 call_id 匹配的 item_reference。
if service.HasFunctionCallOutput(reqBody) {
previousResponseID, _ := reqBody["previous_response_id"].(string)
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
if service.HasFunctionCallOutputMissingCallID(reqBody) {
log.Printf("[OpenAI Handler] function_call_output 缺少 call_id: model=%s", reqModel)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
return
}
callIDs := service.FunctionCallOutputCallIDs(reqBody)
if !service.HasItemReferenceForCallIDs(reqBody, callIDs) {
log.Printf("[OpenAI Handler] function_call_output 缺少匹配的 item_reference: model=%s", reqModel)
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
return
}
} }
} }
...@@ -119,6 +149,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -119,6 +149,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// 0. Check if wait queue is full // 0. Check if wait queue is full
maxWait := service.CalculateMaxWait(subject.Concurrency) maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
waitCounted := false
if err != nil { if err != nil {
log.Printf("Increment wait count failed: %v", err) log.Printf("Increment wait count failed: %v", err)
// On error, allow request to proceed // On error, allow request to proceed
...@@ -126,8 +157,14 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -126,8 +157,14 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return return
} }
// Ensure wait count is decremented when function exits if err == nil && canWait {
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) waitCounted = true
}
defer func() {
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
}
}()
// 1. First acquire user concurrency slot // 1. First acquire user concurrency slot
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted) userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
...@@ -136,6 +173,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -136,6 +173,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.handleConcurrencyError(c, err, "user", streamStarted) h.handleConcurrencyError(c, err, "user", streamStarted)
return return
} }
// User slot acquired: no longer waiting.
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
waitCounted = false
}
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏 // 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc) userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil { if userReleaseFunc != nil {
...@@ -173,15 +215,16 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -173,15 +215,16 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
} }
account := selection.Account account := selection.Account
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name) log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
setOpsSelectedAccount(c, account.ID)
// 3. Acquire account concurrency slot // 3. Acquire account concurrency slot
accountReleaseFunc := selection.ReleaseFunc accountReleaseFunc := selection.ReleaseFunc
var accountWaitRelease func()
if !selection.Acquired { if !selection.Acquired {
if selection.WaitPlan == nil { if selection.WaitPlan == nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return return
} }
accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil { if err != nil {
log.Printf("Increment account wait count failed: %v", err) log.Printf("Increment account wait count failed: %v", err)
...@@ -189,12 +232,15 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -189,12 +232,15 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
log.Printf("Account wait queue full: account=%d", account.ID) log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return return
} else { }
// Only set release function if increment succeeded if err == nil && canWait {
accountWaitRelease = func() { accountWaitCounted = true
}
defer func() {
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
} }
} }()
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c, c,
...@@ -205,29 +251,26 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -205,29 +251,26 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
&streamStarted, &streamStarted,
) )
if err != nil { if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
log.Printf("Account concurrency acquire failed: %v", err) log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted) h.handleConcurrencyError(c, err, "account", streamStarted)
return return
} }
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil { if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err) log.Printf("Bind sticky session failed: %v", err)
} }
} }
// 账号槽位/等待计数需要在超时或断开时安全回收 // 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// Forward request // Forward request
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body) result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
} }
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil { if err != nil {
var failoverErr *service.UpstreamFailoverError var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) { if errors.As(err, &failoverErr) {
...@@ -247,8 +290,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -247,8 +290,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return return
} }
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
// Async record usage // Async record usage
go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua string) { go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua, ip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
...@@ -258,10 +305,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ...@@ -258,10 +305,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
Account: usedAccount, Account: usedAccount,
Subscription: subscription, Subscription: subscription,
UserAgent: ua, UserAgent: ua,
IPAddress: ip,
}); err != nil { }); err != nil {
log.Printf("Record usage failed: %v", err) log.Printf("Record usage failed: %v", err)
} }
}(result, account, userAgent) }(result, account, userAgent, clientIP)
return return
} }
} }
......
package handler
import (
"bytes"
"context"
"encoding/json"
"log"
"runtime"
"runtime/debug"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"unicode/utf8"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
const (
opsModelKey = "ops_model"
opsStreamKey = "ops_stream"
opsRequestBodyKey = "ops_request_body"
opsAccountIDKey = "ops_account_id"
)
const (
opsErrorLogTimeout = 5 * time.Second
opsErrorLogDrainTimeout = 10 * time.Second
opsErrorLogMinWorkerCount = 4
opsErrorLogMaxWorkerCount = 32
opsErrorLogQueueSizePerWorker = 128
opsErrorLogMinQueueSize = 256
opsErrorLogMaxQueueSize = 8192
)
type opsErrorLogJob struct {
ops *service.OpsService
entry *service.OpsInsertErrorLogInput
requestBody []byte
}
var (
opsErrorLogOnce sync.Once
opsErrorLogQueue chan opsErrorLogJob
opsErrorLogStopOnce sync.Once
opsErrorLogWorkersWg sync.WaitGroup
opsErrorLogMu sync.RWMutex
opsErrorLogStopping bool
opsErrorLogQueueLen atomic.Int64
opsErrorLogEnqueued atomic.Int64
opsErrorLogDropped atomic.Int64
opsErrorLogProcessed atomic.Int64
opsErrorLogLastDropLogAt atomic.Int64
opsErrorLogShutdownCh = make(chan struct{})
opsErrorLogShutdownOnce sync.Once
opsErrorLogDrained atomic.Bool
)
func startOpsErrorLogWorkers() {
opsErrorLogMu.Lock()
defer opsErrorLogMu.Unlock()
if opsErrorLogStopping {
return
}
workerCount, queueSize := opsErrorLogConfig()
opsErrorLogQueue = make(chan opsErrorLogJob, queueSize)
opsErrorLogQueueLen.Store(0)
opsErrorLogWorkersWg.Add(workerCount)
for i := 0; i < workerCount; i++ {
go func() {
defer opsErrorLogWorkersWg.Done()
for job := range opsErrorLogQueue {
opsErrorLogQueueLen.Add(-1)
if job.ops == nil || job.entry == nil {
continue
}
func() {
defer func() {
if r := recover(); r != nil {
log.Printf("[OpsErrorLogger] worker panic: %v\n%s", r, debug.Stack())
}
}()
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
_ = job.ops.RecordError(ctx, job.entry, job.requestBody)
cancel()
opsErrorLogProcessed.Add(1)
}()
}
}()
}
}
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput, requestBody []byte) {
if ops == nil || entry == nil {
return
}
select {
case <-opsErrorLogShutdownCh:
return
default:
}
opsErrorLogMu.RLock()
stopping := opsErrorLogStopping
opsErrorLogMu.RUnlock()
if stopping {
return
}
opsErrorLogOnce.Do(startOpsErrorLogWorkers)
opsErrorLogMu.RLock()
defer opsErrorLogMu.RUnlock()
if opsErrorLogStopping || opsErrorLogQueue == nil {
return
}
select {
case opsErrorLogQueue <- opsErrorLogJob{ops: ops, entry: entry, requestBody: requestBody}:
opsErrorLogQueueLen.Add(1)
opsErrorLogEnqueued.Add(1)
default:
// Queue is full; drop to avoid blocking request handling.
opsErrorLogDropped.Add(1)
maybeLogOpsErrorLogDrop()
}
}
func StopOpsErrorLogWorkers() bool {
opsErrorLogStopOnce.Do(func() {
opsErrorLogShutdownOnce.Do(func() {
close(opsErrorLogShutdownCh)
})
opsErrorLogDrained.Store(stopOpsErrorLogWorkers())
})
return opsErrorLogDrained.Load()
}
func stopOpsErrorLogWorkers() bool {
opsErrorLogMu.Lock()
opsErrorLogStopping = true
ch := opsErrorLogQueue
if ch != nil {
close(ch)
}
opsErrorLogQueue = nil
opsErrorLogMu.Unlock()
if ch == nil {
opsErrorLogQueueLen.Store(0)
return true
}
done := make(chan struct{})
go func() {
opsErrorLogWorkersWg.Wait()
close(done)
}()
select {
case <-done:
opsErrorLogQueueLen.Store(0)
return true
case <-time.After(opsErrorLogDrainTimeout):
return false
}
}
func OpsErrorLogQueueLength() int64 {
return opsErrorLogQueueLen.Load()
}
func OpsErrorLogQueueCapacity() int {
opsErrorLogMu.RLock()
ch := opsErrorLogQueue
opsErrorLogMu.RUnlock()
if ch == nil {
return 0
}
return cap(ch)
}
func OpsErrorLogDroppedTotal() int64 {
return opsErrorLogDropped.Load()
}
func OpsErrorLogEnqueuedTotal() int64 {
return opsErrorLogEnqueued.Load()
}
func OpsErrorLogProcessedTotal() int64 {
return opsErrorLogProcessed.Load()
}
func maybeLogOpsErrorLogDrop() {
now := time.Now().Unix()
for {
last := opsErrorLogLastDropLogAt.Load()
if last != 0 && now-last < 60 {
return
}
if opsErrorLogLastDropLogAt.CompareAndSwap(last, now) {
break
}
}
queued := opsErrorLogQueueLen.Load()
queueCap := OpsErrorLogQueueCapacity()
log.Printf(
"[OpsErrorLogger] queue is full; dropping logs (queued=%d cap=%d enqueued_total=%d dropped_total=%d processed_total=%d)",
queued,
queueCap,
opsErrorLogEnqueued.Load(),
opsErrorLogDropped.Load(),
opsErrorLogProcessed.Load(),
)
}
func opsErrorLogConfig() (workerCount int, queueSize int) {
workerCount = runtime.GOMAXPROCS(0) * 2
if workerCount < opsErrorLogMinWorkerCount {
workerCount = opsErrorLogMinWorkerCount
}
if workerCount > opsErrorLogMaxWorkerCount {
workerCount = opsErrorLogMaxWorkerCount
}
queueSize = workerCount * opsErrorLogQueueSizePerWorker
if queueSize < opsErrorLogMinQueueSize {
queueSize = opsErrorLogMinQueueSize
}
if queueSize > opsErrorLogMaxQueueSize {
queueSize = opsErrorLogMaxQueueSize
}
return workerCount, queueSize
}
func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody []byte) {
if c == nil {
return
}
c.Set(opsModelKey, model)
c.Set(opsStreamKey, stream)
if len(requestBody) > 0 {
c.Set(opsRequestBodyKey, requestBody)
}
}
func setOpsSelectedAccount(c *gin.Context, accountID int64) {
if c == nil || accountID <= 0 {
return
}
c.Set(opsAccountIDKey, accountID)
}
type opsCaptureWriter struct {
gin.ResponseWriter
limit int
buf bytes.Buffer
}
func (w *opsCaptureWriter) Write(b []byte) (int, error) {
if w.Status() >= 400 && w.limit > 0 && w.buf.Len() < w.limit {
remaining := w.limit - w.buf.Len()
if len(b) > remaining {
_, _ = w.buf.Write(b[:remaining])
} else {
_, _ = w.buf.Write(b)
}
}
return w.ResponseWriter.Write(b)
}
func (w *opsCaptureWriter) WriteString(s string) (int, error) {
if w.Status() >= 400 && w.limit > 0 && w.buf.Len() < w.limit {
remaining := w.limit - w.buf.Len()
if len(s) > remaining {
_, _ = w.buf.WriteString(s[:remaining])
} else {
_, _ = w.buf.WriteString(s)
}
}
return w.ResponseWriter.WriteString(s)
}
// OpsErrorLoggerMiddleware records error responses (status >= 400) into ops_error_logs.
//
// Notes:
// - It buffers response bodies only when status >= 400 to avoid overhead for successful traffic.
// - Streaming errors after the response has started (SSE) may still need explicit logging.
func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
return func(c *gin.Context) {
w := &opsCaptureWriter{ResponseWriter: c.Writer, limit: 64 * 1024}
c.Writer = w
c.Next()
if ops == nil {
return
}
if !ops.IsMonitoringEnabled(c.Request.Context()) {
return
}
status := c.Writer.Status()
if status < 400 {
// Even when the client request succeeds, we still want to persist upstream error attempts
// (retries/failover) so ops can observe upstream instability that gets "covered" by retries.
var events []*service.OpsUpstreamErrorEvent
if v, ok := c.Get(service.OpsUpstreamErrorsKey); ok {
if arr, ok := v.([]*service.OpsUpstreamErrorEvent); ok && len(arr) > 0 {
events = arr
}
}
// Also accept single upstream fields set by gateway services (rare for successful requests).
hasUpstreamContext := len(events) > 0
if !hasUpstreamContext {
if v, ok := c.Get(service.OpsUpstreamStatusCodeKey); ok {
switch t := v.(type) {
case int:
hasUpstreamContext = t > 0
case int64:
hasUpstreamContext = t > 0
}
}
}
if !hasUpstreamContext {
if v, ok := c.Get(service.OpsUpstreamErrorMessageKey); ok {
if s, ok := v.(string); ok && strings.TrimSpace(s) != "" {
hasUpstreamContext = true
}
}
}
if !hasUpstreamContext {
if v, ok := c.Get(service.OpsUpstreamErrorDetailKey); ok {
if s, ok := v.(string); ok && strings.TrimSpace(s) != "" {
hasUpstreamContext = true
}
}
}
if !hasUpstreamContext {
return
}
apiKey, _ := middleware2.GetAPIKeyFromContext(c)
clientRequestID, _ := c.Request.Context().Value(ctxkey.ClientRequestID).(string)
model, _ := c.Get(opsModelKey)
streamV, _ := c.Get(opsStreamKey)
accountIDV, _ := c.Get(opsAccountIDKey)
var modelName string
if s, ok := model.(string); ok {
modelName = s
}
stream := false
if b, ok := streamV.(bool); ok {
stream = b
}
// Prefer showing the account that experienced the upstream error (if we have events),
// otherwise fall back to the final selected account (best-effort).
var accountID *int64
if len(events) > 0 {
if last := events[len(events)-1]; last != nil && last.AccountID > 0 {
v := last.AccountID
accountID = &v
}
}
if accountID == nil {
if v, ok := accountIDV.(int64); ok && v > 0 {
accountID = &v
}
}
fallbackPlatform := guessPlatformFromPath(c.Request.URL.Path)
platform := resolveOpsPlatform(apiKey, fallbackPlatform)
requestID := c.Writer.Header().Get("X-Request-Id")
if requestID == "" {
requestID = c.Writer.Header().Get("x-request-id")
}
// Best-effort backfill single upstream fields from the last event (if present).
var upstreamStatusCode *int
var upstreamErrorMessage *string
var upstreamErrorDetail *string
if len(events) > 0 {
last := events[len(events)-1]
if last != nil {
if last.UpstreamStatusCode > 0 {
code := last.UpstreamStatusCode
upstreamStatusCode = &code
}
if msg := strings.TrimSpace(last.Message); msg != "" {
upstreamErrorMessage = &msg
}
if detail := strings.TrimSpace(last.Detail); detail != "" {
upstreamErrorDetail = &detail
}
}
}
if upstreamStatusCode == nil {
if v, ok := c.Get(service.OpsUpstreamStatusCodeKey); ok {
switch t := v.(type) {
case int:
if t > 0 {
code := t
upstreamStatusCode = &code
}
case int64:
if t > 0 {
code := int(t)
upstreamStatusCode = &code
}
}
}
}
if upstreamErrorMessage == nil {
if v, ok := c.Get(service.OpsUpstreamErrorMessageKey); ok {
if s, ok := v.(string); ok && strings.TrimSpace(s) != "" {
msg := strings.TrimSpace(s)
upstreamErrorMessage = &msg
}
}
}
if upstreamErrorDetail == nil {
if v, ok := c.Get(service.OpsUpstreamErrorDetailKey); ok {
if s, ok := v.(string); ok && strings.TrimSpace(s) != "" {
detail := strings.TrimSpace(s)
upstreamErrorDetail = &detail
}
}
}
// If we still have nothing meaningful, skip.
if upstreamStatusCode == nil && upstreamErrorMessage == nil && upstreamErrorDetail == nil && len(events) == 0 {
return
}
effectiveUpstreamStatus := 0
if upstreamStatusCode != nil {
effectiveUpstreamStatus = *upstreamStatusCode
}
recoveredMsg := "Recovered upstream error"
if effectiveUpstreamStatus > 0 {
recoveredMsg += " " + strconvItoa(effectiveUpstreamStatus)
}
if upstreamErrorMessage != nil && strings.TrimSpace(*upstreamErrorMessage) != "" {
recoveredMsg += ": " + strings.TrimSpace(*upstreamErrorMessage)
}
recoveredMsg = truncateString(recoveredMsg, 2048)
entry := &service.OpsInsertErrorLogInput{
RequestID: requestID,
ClientRequestID: clientRequestID,
AccountID: accountID,
Platform: platform,
Model: modelName,
RequestPath: func() string {
if c.Request != nil && c.Request.URL != nil {
return c.Request.URL.Path
}
return ""
}(),
Stream: stream,
UserAgent: c.GetHeader("User-Agent"),
ErrorPhase: "upstream",
ErrorType: "upstream_error",
// Severity/retryability should reflect the upstream failure, not the final client status (200).
Severity: classifyOpsSeverity("upstream_error", effectiveUpstreamStatus),
StatusCode: status,
IsBusinessLimited: false,
IsCountTokens: isCountTokensRequest(c),
ErrorMessage: recoveredMsg,
ErrorBody: "",
ErrorSource: "upstream_http",
ErrorOwner: "provider",
UpstreamStatusCode: upstreamStatusCode,
UpstreamErrorMessage: upstreamErrorMessage,
UpstreamErrorDetail: upstreamErrorDetail,
UpstreamErrors: events,
IsRetryable: classifyOpsIsRetryable("upstream_error", effectiveUpstreamStatus),
RetryCount: 0,
CreatedAt: time.Now(),
}
if apiKey != nil {
entry.APIKeyID = &apiKey.ID
if apiKey.User != nil {
entry.UserID = &apiKey.User.ID
}
if apiKey.GroupID != nil {
entry.GroupID = apiKey.GroupID
}
// Prefer group platform if present (more stable than inferring from path).
if apiKey.Group != nil && apiKey.Group.Platform != "" {
entry.Platform = apiKey.Group.Platform
}
}
var clientIP string
if ip := strings.TrimSpace(ip.GetClientIP(c)); ip != "" {
clientIP = ip
entry.ClientIP = &clientIP
}
var requestBody []byte
if v, ok := c.Get(opsRequestBodyKey); ok {
if b, ok := v.([]byte); ok && len(b) > 0 {
requestBody = b
}
}
// Store request headers/body only when an upstream error occurred to keep overhead minimal.
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
enqueueOpsErrorLog(ops, entry, requestBody)
return
}
body := w.buf.Bytes()
parsed := parseOpsErrorResponse(body)
// Skip logging if the error should be filtered based on settings
if shouldSkipOpsErrorLog(c.Request.Context(), ops, parsed.Message, string(body), c.Request.URL.Path) {
return
}
apiKey, _ := middleware2.GetAPIKeyFromContext(c)
clientRequestID, _ := c.Request.Context().Value(ctxkey.ClientRequestID).(string)
model, _ := c.Get(opsModelKey)
streamV, _ := c.Get(opsStreamKey)
accountIDV, _ := c.Get(opsAccountIDKey)
var modelName string
if s, ok := model.(string); ok {
modelName = s
}
stream := false
if b, ok := streamV.(bool); ok {
stream = b
}
var accountID *int64
if v, ok := accountIDV.(int64); ok && v > 0 {
accountID = &v
}
fallbackPlatform := guessPlatformFromPath(c.Request.URL.Path)
platform := resolveOpsPlatform(apiKey, fallbackPlatform)
requestID := c.Writer.Header().Get("X-Request-Id")
if requestID == "" {
requestID = c.Writer.Header().Get("x-request-id")
}
phase := classifyOpsPhase(parsed.ErrorType, parsed.Message, parsed.Code)
isBusinessLimited := classifyOpsIsBusinessLimited(parsed.ErrorType, phase, parsed.Code, status, parsed.Message)
errorOwner := classifyOpsErrorOwner(phase, parsed.Message)
errorSource := classifyOpsErrorSource(phase, parsed.Message)
entry := &service.OpsInsertErrorLogInput{
RequestID: requestID,
ClientRequestID: clientRequestID,
AccountID: accountID,
Platform: platform,
Model: modelName,
RequestPath: func() string {
if c.Request != nil && c.Request.URL != nil {
return c.Request.URL.Path
}
return ""
}(),
Stream: stream,
UserAgent: c.GetHeader("User-Agent"),
ErrorPhase: phase,
ErrorType: normalizeOpsErrorType(parsed.ErrorType, parsed.Code),
Severity: classifyOpsSeverity(parsed.ErrorType, status),
StatusCode: status,
IsBusinessLimited: isBusinessLimited,
IsCountTokens: isCountTokensRequest(c),
ErrorMessage: parsed.Message,
// Keep the full captured error body (capture is already capped at 64KB) so the
// service layer can sanitize JSON before truncating for storage.
ErrorBody: string(body),
ErrorSource: errorSource,
ErrorOwner: errorOwner,
IsRetryable: classifyOpsIsRetryable(parsed.ErrorType, status),
RetryCount: 0,
CreatedAt: time.Now(),
}
// Capture upstream error context set by gateway services (if present).
// This does NOT affect the client response; it enriches Ops troubleshooting data.
{
if v, ok := c.Get(service.OpsUpstreamStatusCodeKey); ok {
switch t := v.(type) {
case int:
if t > 0 {
code := t
entry.UpstreamStatusCode = &code
}
case int64:
if t > 0 {
code := int(t)
entry.UpstreamStatusCode = &code
}
}
}
if v, ok := c.Get(service.OpsUpstreamErrorMessageKey); ok {
if s, ok := v.(string); ok {
if msg := strings.TrimSpace(s); msg != "" {
entry.UpstreamErrorMessage = &msg
}
}
}
if v, ok := c.Get(service.OpsUpstreamErrorDetailKey); ok {
if s, ok := v.(string); ok {
if detail := strings.TrimSpace(s); detail != "" {
entry.UpstreamErrorDetail = &detail
}
}
}
if v, ok := c.Get(service.OpsUpstreamErrorsKey); ok {
if events, ok := v.([]*service.OpsUpstreamErrorEvent); ok && len(events) > 0 {
entry.UpstreamErrors = events
// Best-effort backfill the single upstream fields from the last event when missing.
last := events[len(events)-1]
if last != nil {
if entry.UpstreamStatusCode == nil && last.UpstreamStatusCode > 0 {
code := last.UpstreamStatusCode
entry.UpstreamStatusCode = &code
}
if entry.UpstreamErrorMessage == nil && strings.TrimSpace(last.Message) != "" {
msg := strings.TrimSpace(last.Message)
entry.UpstreamErrorMessage = &msg
}
if entry.UpstreamErrorDetail == nil && strings.TrimSpace(last.Detail) != "" {
detail := strings.TrimSpace(last.Detail)
entry.UpstreamErrorDetail = &detail
}
}
}
}
}
if apiKey != nil {
entry.APIKeyID = &apiKey.ID
if apiKey.User != nil {
entry.UserID = &apiKey.User.ID
}
if apiKey.GroupID != nil {
entry.GroupID = apiKey.GroupID
}
// Prefer group platform if present (more stable than inferring from path).
if apiKey.Group != nil && apiKey.Group.Platform != "" {
entry.Platform = apiKey.Group.Platform
}
}
var clientIP string
if ip := strings.TrimSpace(ip.GetClientIP(c)); ip != "" {
clientIP = ip
entry.ClientIP = &clientIP
}
var requestBody []byte
if v, ok := c.Get(opsRequestBodyKey); ok {
if b, ok := v.([]byte); ok && len(b) > 0 {
requestBody = b
}
}
// Persist only a minimal, whitelisted set of request headers to improve retry fidelity.
// Do NOT store Authorization/Cookie/etc.
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
enqueueOpsErrorLog(ops, entry, requestBody)
}
}
var opsRetryRequestHeaderAllowlist = []string{
"anthropic-beta",
"anthropic-version",
}
// isCountTokensRequest checks if the request is a count_tokens request
func isCountTokensRequest(c *gin.Context) bool {
if c == nil || c.Request == nil || c.Request.URL == nil {
return false
}
return strings.Contains(c.Request.URL.Path, "/count_tokens")
}
func extractOpsRetryRequestHeaders(c *gin.Context) *string {
if c == nil || c.Request == nil {
return nil
}
headers := make(map[string]string, 4)
for _, key := range opsRetryRequestHeaderAllowlist {
v := strings.TrimSpace(c.GetHeader(key))
if v == "" {
continue
}
// Keep headers small even if a client sends something unexpected.
headers[key] = truncateString(v, 512)
}
if len(headers) == 0 {
return nil
}
raw, err := json.Marshal(headers)
if err != nil {
return nil
}
s := string(raw)
return &s
}
type parsedOpsError struct {
ErrorType string
Message string
Code string
}
func parseOpsErrorResponse(body []byte) parsedOpsError {
if len(body) == 0 {
return parsedOpsError{}
}
// Fast path: attempt to decode into a generic map.
var m map[string]any
if err := json.Unmarshal(body, &m); err != nil {
return parsedOpsError{Message: truncateString(string(body), 1024)}
}
// Claude/OpenAI-style gateway error: { type:"error", error:{ type, message } }
if errObj, ok := m["error"].(map[string]any); ok {
t, _ := errObj["type"].(string)
msg, _ := errObj["message"].(string)
// Gemini googleError also uses "error": { code, message, status }
if msg == "" {
if v, ok := errObj["message"]; ok {
msg, _ = v.(string)
}
}
if t == "" {
// Gemini error does not have "type" field.
t = "api_error"
}
// For gemini error, capture numeric code as string for business-limited mapping if needed.
var code string
if v, ok := errObj["code"]; ok {
switch n := v.(type) {
case float64:
code = strconvItoa(int(n))
case int:
code = strconvItoa(n)
}
}
return parsedOpsError{ErrorType: t, Message: msg, Code: code}
}
// APIKeyAuth-style: { code:"INSUFFICIENT_BALANCE", message:"..." }
code, _ := m["code"].(string)
msg, _ := m["message"].(string)
if code != "" || msg != "" {
return parsedOpsError{ErrorType: "api_error", Message: msg, Code: code}
}
return parsedOpsError{Message: truncateString(string(body), 1024)}
}
func resolveOpsPlatform(apiKey *service.APIKey, fallback string) string {
if apiKey != nil && apiKey.Group != nil && apiKey.Group.Platform != "" {
return apiKey.Group.Platform
}
return fallback
}
func guessPlatformFromPath(path string) string {
p := strings.ToLower(path)
switch {
case strings.HasPrefix(p, "/antigravity/"):
return service.PlatformAntigravity
case strings.HasPrefix(p, "/v1beta/"):
return service.PlatformGemini
case strings.Contains(p, "/responses"):
return service.PlatformOpenAI
default:
return ""
}
}
func normalizeOpsErrorType(errType string, code string) string {
if errType != "" {
return errType
}
switch strings.TrimSpace(code) {
case "INSUFFICIENT_BALANCE":
return "billing_error"
case "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
return "subscription_error"
default:
return "api_error"
}
}
func classifyOpsPhase(errType, message, code string) string {
msg := strings.ToLower(message)
// Standardized phases: request|auth|routing|upstream|network|internal
// Map billing/concurrency/response => request; scheduling => routing.
switch strings.TrimSpace(code) {
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
return "request"
}
switch errType {
case "authentication_error":
return "auth"
case "billing_error", "subscription_error":
return "request"
case "rate_limit_error":
if strings.Contains(msg, "concurrency") || strings.Contains(msg, "pending") || strings.Contains(msg, "queue") {
return "request"
}
return "upstream"
case "invalid_request_error":
return "request"
case "upstream_error", "overloaded_error":
return "upstream"
case "api_error":
if strings.Contains(msg, "no available accounts") {
return "routing"
}
return "internal"
default:
return "internal"
}
}
func classifyOpsSeverity(errType string, status int) string {
switch errType {
case "invalid_request_error", "authentication_error", "billing_error", "subscription_error":
return "P3"
}
if status >= 500 {
return "P1"
}
if status == 429 {
return "P1"
}
if status >= 400 {
return "P2"
}
return "P3"
}
func classifyOpsIsRetryable(errType string, statusCode int) bool {
switch errType {
case "authentication_error", "invalid_request_error":
return false
case "timeout_error":
return true
case "rate_limit_error":
// May be transient (upstream or queue); retry can help.
return true
case "billing_error", "subscription_error":
return false
case "upstream_error", "overloaded_error":
return statusCode >= 500 || statusCode == 429 || statusCode == 529
default:
return statusCode >= 500
}
}
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
switch strings.TrimSpace(code) {
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
return true
}
if phase == "billing" || phase == "concurrency" {
// SLA/错误率排除“用户级业务限制”
return true
}
// Avoid treating upstream rate limits as business-limited.
if errType == "rate_limit_error" && strings.Contains(strings.ToLower(message), "upstream") {
return false
}
_ = status
return false
}
func classifyOpsErrorOwner(phase string, message string) string {
// Standardized owners: client|provider|platform
switch phase {
case "upstream", "network":
return "provider"
case "request", "auth":
return "client"
case "routing", "internal":
return "platform"
default:
if strings.Contains(strings.ToLower(message), "upstream") {
return "provider"
}
return "platform"
}
}
func classifyOpsErrorSource(phase string, message string) string {
// Standardized sources: client_request|upstream_http|gateway
switch phase {
case "upstream":
return "upstream_http"
case "network":
return "gateway"
case "request", "auth":
return "client_request"
case "routing", "internal":
return "gateway"
default:
if strings.Contains(strings.ToLower(message), "upstream") {
return "upstream_http"
}
return "gateway"
}
}
func truncateString(s string, max int) string {
if max <= 0 {
return ""
}
if len(s) <= max {
return s
}
cut := s[:max]
// Ensure truncation does not split multi-byte characters.
for len(cut) > 0 && !utf8.ValidString(cut) {
cut = cut[:len(cut)-1]
}
return cut
}
func strconvItoa(v int) string {
return strconv.Itoa(v)
}
// shouldSkipOpsErrorLog determines if an error should be skipped from logging based on settings.
// Returns true for errors that should be filtered according to OpsAdvancedSettings.
func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message, body, requestPath string) bool {
if ops == nil {
return false
}
// Get advanced settings to check filter configuration
settings, err := ops.GetOpsAdvancedSettings(ctx)
if err != nil || settings == nil {
// If we can't get settings, don't skip (fail open)
return false
}
msgLower := strings.ToLower(message)
bodyLower := strings.ToLower(body)
// Check if count_tokens errors should be ignored
if settings.IgnoreCountTokensErrors && strings.Contains(requestPath, "/count_tokens") {
return true
}
// Check if context canceled errors should be ignored (client disconnects)
if settings.IgnoreContextCanceled {
if strings.Contains(msgLower, "context canceled") || strings.Contains(bodyLower, "context canceled") {
return true
}
}
// Check if "no available accounts" errors should be ignored
if settings.IgnoreNoAvailableAccounts {
if strings.Contains(msgLower, "no available accounts") || strings.Contains(bodyLower, "no available accounts") {
return true
}
}
return false
}
...@@ -42,6 +42,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { ...@@ -42,6 +42,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
APIBaseURL: settings.APIBaseURL, APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo, ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL, DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: h.version, Version: h.version,
}) })
......
...@@ -19,7 +19,9 @@ func ProvideAdminHandlers( ...@@ -19,7 +19,9 @@ func ProvideAdminHandlers(
antigravityOAuthHandler *admin.AntigravityOAuthHandler, antigravityOAuthHandler *admin.AntigravityOAuthHandler,
proxyHandler *admin.ProxyHandler, proxyHandler *admin.ProxyHandler,
redeemHandler *admin.RedeemHandler, redeemHandler *admin.RedeemHandler,
promoHandler *admin.PromoHandler,
settingHandler *admin.SettingHandler, settingHandler *admin.SettingHandler,
opsHandler *admin.OpsHandler,
systemHandler *admin.SystemHandler, systemHandler *admin.SystemHandler,
subscriptionHandler *admin.SubscriptionHandler, subscriptionHandler *admin.SubscriptionHandler,
usageHandler *admin.UsageHandler, usageHandler *admin.UsageHandler,
...@@ -36,7 +38,9 @@ func ProvideAdminHandlers( ...@@ -36,7 +38,9 @@ func ProvideAdminHandlers(
AntigravityOAuth: antigravityOAuthHandler, AntigravityOAuth: antigravityOAuthHandler,
Proxy: proxyHandler, Proxy: proxyHandler,
Redeem: redeemHandler, Redeem: redeemHandler,
Promo: promoHandler,
Setting: settingHandler, Setting: settingHandler,
Ops: opsHandler,
System: systemHandler, System: systemHandler,
Subscription: subscriptionHandler, Subscription: subscriptionHandler,
Usage: usageHandler, Usage: usageHandler,
...@@ -105,7 +109,9 @@ var ProviderSet = wire.NewSet( ...@@ -105,7 +109,9 @@ var ProviderSet = wire.NewSet(
admin.NewAntigravityOAuthHandler, admin.NewAntigravityOAuthHandler,
admin.NewProxyHandler, admin.NewProxyHandler,
admin.NewRedeemHandler, admin.NewRedeemHandler,
admin.NewPromoHandler,
admin.NewSettingHandler, admin.NewSettingHandler,
admin.NewOpsHandler,
ProvideSystemHandler, ProvideSystemHandler,
admin.NewSubscriptionHandler, admin.NewSubscriptionHandler,
admin.NewUsageHandler, admin.NewUsageHandler,
......
package middleware
import (
"context"
"fmt"
"log"
"net/http"
"strconv"
"time"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
)
// RateLimitFailureMode Redis 故障策略
type RateLimitFailureMode int
const (
RateLimitFailOpen RateLimitFailureMode = iota
RateLimitFailClose
)
// RateLimitOptions 限流可选配置
type RateLimitOptions struct {
FailureMode RateLimitFailureMode
}
var rateLimitScript = redis.NewScript(`
local current = redis.call('INCR', KEYS[1])
local ttl = redis.call('PTTL', KEYS[1])
local repaired = 0
if current == 1 then
redis.call('PEXPIRE', KEYS[1], ARGV[1])
elseif ttl == -1 then
redis.call('PEXPIRE', KEYS[1], ARGV[1])
repaired = 1
end
return {current, repaired}
`)
// rateLimitRun 允许测试覆写脚本执行逻辑
var rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) {
values, err := rateLimitScript.Run(ctx, client, []string{key}, windowMillis).Slice()
if err != nil {
return 0, false, err
}
if len(values) < 2 {
return 0, false, fmt.Errorf("rate limit script returned %d values", len(values))
}
count, err := parseInt64(values[0])
if err != nil {
return 0, false, err
}
repaired, err := parseInt64(values[1])
if err != nil {
return 0, false, err
}
return count, repaired == 1, nil
}
// RateLimiter Redis 速率限制器
type RateLimiter struct {
redis *redis.Client
prefix string
}
// NewRateLimiter 创建速率限制器实例
func NewRateLimiter(redisClient *redis.Client) *RateLimiter {
return &RateLimiter{
redis: redisClient,
prefix: "rate_limit:",
}
}
// Limit 返回速率限制中间件
// key: 限制类型标识
// limit: 时间窗口内最大请求数
// window: 时间窗口
func (r *RateLimiter) Limit(key string, limit int, window time.Duration) gin.HandlerFunc {
return r.LimitWithOptions(key, limit, window, RateLimitOptions{})
}
// LimitWithOptions 返回速率限制中间件(带可选配置)
func (r *RateLimiter) LimitWithOptions(key string, limit int, window time.Duration, opts RateLimitOptions) gin.HandlerFunc {
failureMode := opts.FailureMode
if failureMode != RateLimitFailClose {
failureMode = RateLimitFailOpen
}
return func(c *gin.Context) {
ip := c.ClientIP()
redisKey := r.prefix + key + ":" + ip
ctx := c.Request.Context()
windowMillis := windowTTLMillis(window)
// 使用 Lua 脚本原子操作增加计数并设置过期
count, repaired, err := rateLimitRun(ctx, r.redis, redisKey, windowMillis)
if err != nil {
log.Printf("[RateLimit] redis error: key=%s mode=%s err=%v", redisKey, failureModeLabel(failureMode), err)
if failureMode == RateLimitFailClose {
abortRateLimit(c)
return
}
// Redis 错误时放行,避免影响正常服务
c.Next()
return
}
if repaired {
log.Printf("[RateLimit] ttl repaired: key=%s window_ms=%d", redisKey, windowMillis)
}
// 超过限制
if count > int64(limit) {
abortRateLimit(c)
return
}
c.Next()
}
}
func windowTTLMillis(window time.Duration) int64 {
ttl := window.Milliseconds()
if ttl < 1 {
return 1
}
return ttl
}
func abortRateLimit(c *gin.Context) {
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
"error": "rate limit exceeded",
"message": "Too many requests, please try again later",
})
}
func failureModeLabel(mode RateLimitFailureMode) string {
if mode == RateLimitFailClose {
return "fail-close"
}
return "fail-open"
}
func parseInt64(value any) (int64, error) {
switch v := value.(type) {
case int64:
return v, nil
case int:
return int64(v), nil
case string:
parsed, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return 0, err
}
return parsed, nil
default:
return 0, fmt.Errorf("unexpected value type %T", value)
}
}
//go:build integration
package middleware
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
)
const redisImageTag = "redis:8.4-alpine"
func TestRateLimiterSetsTTLAndDoesNotRefresh(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx := context.Background()
rdb := startRedis(t, ctx)
limiter := NewRateLimiter(rdb)
router := gin.New()
router.Use(limiter.Limit("ttl-test", 10, 2*time.Second))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
recorder := performRequest(router)
require.Equal(t, http.StatusOK, recorder.Code)
redisKey := limiter.prefix + "ttl-test:127.0.0.1"
ttlBefore, err := rdb.PTTL(ctx, redisKey).Result()
require.NoError(t, err)
require.Greater(t, ttlBefore, time.Duration(0))
require.LessOrEqual(t, ttlBefore, 2*time.Second)
time.Sleep(50 * time.Millisecond)
recorder = performRequest(router)
require.Equal(t, http.StatusOK, recorder.Code)
ttlAfter, err := rdb.PTTL(ctx, redisKey).Result()
require.NoError(t, err)
require.Less(t, ttlAfter, ttlBefore)
}
func TestRateLimiterFixesMissingTTL(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx := context.Background()
rdb := startRedis(t, ctx)
limiter := NewRateLimiter(rdb)
router := gin.New()
router.Use(limiter.Limit("ttl-missing", 10, 2*time.Second))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
redisKey := limiter.prefix + "ttl-missing:127.0.0.1"
require.NoError(t, rdb.Set(ctx, redisKey, 5, 0).Err())
ttlBefore, err := rdb.PTTL(ctx, redisKey).Result()
require.NoError(t, err)
require.Less(t, ttlBefore, time.Duration(0))
recorder := performRequest(router)
require.Equal(t, http.StatusOK, recorder.Code)
ttlAfter, err := rdb.PTTL(ctx, redisKey).Result()
require.NoError(t, err)
require.Greater(t, ttlAfter, time.Duration(0))
}
func performRequest(router *gin.Engine) *httptest.ResponseRecorder {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "127.0.0.1:1234"
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
return recorder
}
func startRedis(t *testing.T, ctx context.Context) *redis.Client {
t.Helper()
redisContainer, err := tcredis.Run(ctx, redisImageTag)
require.NoError(t, err)
t.Cleanup(func() {
_ = redisContainer.Terminate(ctx)
})
redisHost, err := redisContainer.Host(ctx)
require.NoError(t, err)
redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp")
require.NoError(t, err)
rdb := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", redisHost, redisPort.Int()),
DB: 0,
})
require.NoError(t, rdb.Ping(ctx).Err())
t.Cleanup(func() {
_ = rdb.Close()
})
return rdb
}
package middleware
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)
func TestWindowTTLMillis(t *testing.T) {
require.Equal(t, int64(1), windowTTLMillis(500*time.Microsecond))
require.Equal(t, int64(1), windowTTLMillis(1500*time.Microsecond))
require.Equal(t, int64(2), windowTTLMillis(2500*time.Microsecond))
}
func TestRateLimiterFailureModes(t *testing.T) {
gin.SetMode(gin.TestMode)
rdb := redis.NewClient(&redis.Options{
Addr: "127.0.0.1:1",
DialTimeout: 50 * time.Millisecond,
ReadTimeout: 50 * time.Millisecond,
WriteTimeout: 50 * time.Millisecond,
})
t.Cleanup(func() {
_ = rdb.Close()
})
limiter := NewRateLimiter(rdb)
failOpenRouter := gin.New()
failOpenRouter.Use(limiter.Limit("test", 1, time.Second))
failOpenRouter.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "127.0.0.1:1234"
recorder := httptest.NewRecorder()
failOpenRouter.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
failCloseRouter := gin.New()
failCloseRouter.Use(limiter.LimitWithOptions("test", 1, time.Second, RateLimitOptions{
FailureMode: RateLimitFailClose,
}))
failCloseRouter.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
req = httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "127.0.0.1:1234"
recorder = httptest.NewRecorder()
failCloseRouter.ServeHTTP(recorder, req)
require.Equal(t, http.StatusTooManyRequests, recorder.Code)
}
func TestRateLimiterSuccessAndLimit(t *testing.T) {
gin.SetMode(gin.TestMode)
originalRun := rateLimitRun
counts := []int64{1, 2}
callIndex := 0
rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) {
if callIndex >= len(counts) {
return counts[len(counts)-1], false, nil
}
value := counts[callIndex]
callIndex++
return value, false, nil
}
t.Cleanup(func() {
rateLimitRun = originalRun
})
limiter := NewRateLimiter(redis.NewClient(&redis.Options{Addr: "127.0.0.1:1"}))
router := gin.New()
router.Use(limiter.Limit("test", 1, time.Second))
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "127.0.0.1:1234"
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
req = httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "127.0.0.1:1234"
recorder = httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusTooManyRequests, recorder.Code)
}
...@@ -7,6 +7,15 @@ type Key string ...@@ -7,6 +7,15 @@ type Key string
const ( const (
// ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置 // ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置
ForcePlatform Key = "ctx_force_platform" ForcePlatform Key = "ctx_force_platform"
// IsClaudeCodeClient 是否为 Claude Code 客户端,由中间件设置
// ClientRequestID 客户端请求的唯一标识,用于追踪请求全生命周期(用于 Ops 监控与排障)。
ClientRequestID Key = "ctx_client_request_id"
// RetryCount 表示当前请求在网关层的重试次数(用于 Ops 记录与排障)。
RetryCount Key = "ctx_retry_count"
// IsClaudeCodeClient 标识当前请求是否来自 Claude Code 客户端
IsClaudeCodeClient Key = "ctx_is_claude_code_client" IsClaudeCodeClient Key = "ctx_is_claude_code_client"
// Group 认证后的分组信息,由 API Key 认证中间件设置
Group Key = "ctx_group"
) )
// Package ip 提供客户端 IP 地址提取工具。
package ip
import (
"net"
"strings"
"github.com/gin-gonic/gin"
)
// GetClientIP 从 Gin Context 中提取客户端真实 IP 地址。
// 按以下优先级检查 Header:
// 1. CF-Connecting-IP (Cloudflare)
// 2. X-Real-IP (Nginx)
// 3. X-Forwarded-For (取第一个非私有 IP)
// 4. c.ClientIP() (Gin 内置方法)
func GetClientIP(c *gin.Context) string {
// 1. Cloudflare
if ip := c.GetHeader("CF-Connecting-IP"); ip != "" {
return normalizeIP(ip)
}
// 2. Nginx X-Real-IP
if ip := c.GetHeader("X-Real-IP"); ip != "" {
return normalizeIP(ip)
}
// 3. X-Forwarded-For (多个 IP 时取第一个公网 IP)
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
ips := strings.Split(xff, ",")
for _, ip := range ips {
ip = strings.TrimSpace(ip)
if ip != "" && !isPrivateIP(ip) {
return normalizeIP(ip)
}
}
// 如果都是私有 IP,返回第一个
if len(ips) > 0 {
return normalizeIP(strings.TrimSpace(ips[0]))
}
}
// 4. Gin 内置方法
return normalizeIP(c.ClientIP())
}
// normalizeIP 规范化 IP 地址,去除端口号和空格。
func normalizeIP(ip string) string {
ip = strings.TrimSpace(ip)
// 移除端口号(如 "192.168.1.1:8080" -> "192.168.1.1")
if host, _, err := net.SplitHostPort(ip); err == nil {
return host
}
return ip
}
// isPrivateIP 检查 IP 是否为私有地址。
func isPrivateIP(ipStr string) bool {
ip := net.ParseIP(ipStr)
if ip == nil {
return false
}
// 私有 IP 范围
privateBlocks := []string{
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"127.0.0.0/8",
"::1/128",
"fc00::/7",
}
for _, block := range privateBlocks {
_, cidr, err := net.ParseCIDR(block)
if err != nil {
continue
}
if cidr.Contains(ip) {
return true
}
}
return false
}
// MatchesPattern 检查 IP 是否匹配指定的模式(支持单个 IP 或 CIDR)。
// pattern 可以是:
// - 单个 IP: "192.168.1.100"
// - CIDR 范围: "192.168.1.0/24"
func MatchesPattern(clientIP, pattern string) bool {
ip := net.ParseIP(clientIP)
if ip == nil {
return false
}
// 尝试解析为 CIDR
if strings.Contains(pattern, "/") {
_, cidr, err := net.ParseCIDR(pattern)
if err != nil {
return false
}
return cidr.Contains(ip)
}
// 作为单个 IP 处理
patternIP := net.ParseIP(pattern)
if patternIP == nil {
return false
}
return ip.Equal(patternIP)
}
// MatchesAnyPattern 检查 IP 是否匹配任意一个模式。
func MatchesAnyPattern(clientIP string, patterns []string) bool {
for _, pattern := range patterns {
if MatchesPattern(clientIP, pattern) {
return true
}
}
return false
}
// CheckIPRestriction 检查 IP 是否被 API Key 的 IP 限制允许。
// 返回值:(是否允许, 拒绝原因)
// 逻辑:
// 1. 先检查黑名单,如果在黑名单中则直接拒绝
// 2. 如果白名单不为空,IP 必须在白名单中
// 3. 如果白名单为空,允许访问(除非被黑名单拒绝)
func CheckIPRestriction(clientIP string, whitelist, blacklist []string) (bool, string) {
// 规范化 IP
clientIP = normalizeIP(clientIP)
if clientIP == "" {
return false, "access denied"
}
// 1. 检查黑名单
if len(blacklist) > 0 && MatchesAnyPattern(clientIP, blacklist) {
return false, "access denied"
}
// 2. 检查白名单(如果设置了白名单,IP 必须在其中)
if len(whitelist) > 0 && !MatchesAnyPattern(clientIP, whitelist) {
return false, "access denied"
}
return true, ""
}
// ValidateIPPattern 验证 IP 或 CIDR 格式是否有效。
func ValidateIPPattern(pattern string) bool {
if strings.Contains(pattern, "/") {
_, _, err := net.ParseCIDR(pattern)
return err == nil
}
return net.ParseIP(pattern) != nil
}
// ValidateIPPatterns 验证多个 IP 或 CIDR 格式。
// 返回无效的模式列表。
func ValidateIPPatterns(patterns []string) []string {
var invalid []string
for _, p := range patterns {
if !ValidateIPPattern(p) {
invalid = append(invalid, p)
}
}
return invalid
}
package usagestats package usagestats
// AccountStats 账号使用统计 // AccountStats 账号使用统计
//
// cost: 账号口径费用(使用 total_cost * account_rate_multiplier)
// standard_cost: 标准费用(使用 total_cost,不含倍率)
// user_cost: 用户/API Key 口径费用(使用 actual_cost,受分组倍率影响)
type AccountStats struct { type AccountStats struct {
Requests int64 `json:"requests"` Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"` Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"` Cost float64 `json:"cost"`
StandardCost float64 `json:"standard_cost"`
UserCost float64 `json:"user_cost"`
} }
...@@ -9,6 +9,12 @@ type DashboardStats struct { ...@@ -9,6 +9,12 @@ type DashboardStats struct {
TotalUsers int64 `json:"total_users"` TotalUsers int64 `json:"total_users"`
TodayNewUsers int64 `json:"today_new_users"` // 今日新增用户数 TodayNewUsers int64 `json:"today_new_users"` // 今日新增用户数
ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数 ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数
// 小时活跃用户数(UTC 当前小时)
HourlyActiveUsers int64 `json:"hourly_active_users"`
// 预聚合新鲜度
StatsUpdatedAt string `json:"stats_updated_at"`
StatsStale bool `json:"stats_stale"`
// API Key 统计 // API Key 统计
TotalAPIKeys int64 `json:"total_api_keys"` TotalAPIKeys int64 `json:"total_api_keys"`
...@@ -141,14 +147,15 @@ type UsageLogFilters struct { ...@@ -141,14 +147,15 @@ type UsageLogFilters struct {
// UsageStats represents usage statistics // UsageStats represents usage statistics
type UsageStats struct { type UsageStats struct {
TotalRequests int64 `json:"total_requests"` TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"` TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"` TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheTokens int64 `json:"total_cache_tokens"` TotalCacheTokens int64 `json:"total_cache_tokens"`
TotalTokens int64 `json:"total_tokens"` TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"` TotalCost float64 `json:"total_cost"`
TotalActualCost float64 `json:"total_actual_cost"` TotalActualCost float64 `json:"total_actual_cost"`
AverageDurationMs float64 `json:"average_duration_ms"` TotalAccountCost *float64 `json:"total_account_cost,omitempty"`
AverageDurationMs float64 `json:"average_duration_ms"`
} }
// BatchUserUsageStats represents usage stats for a single user // BatchUserUsageStats represents usage stats for a single user
...@@ -171,25 +178,29 @@ type AccountUsageHistory struct { ...@@ -171,25 +178,29 @@ type AccountUsageHistory struct {
Label string `json:"label"` Label string `json:"label"`
Requests int64 `json:"requests"` Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"` Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"` Cost float64 `json:"cost"` // 标准计费(total_cost)
ActualCost float64 `json:"actual_cost"` ActualCost float64 `json:"actual_cost"` // 账号口径费用(total_cost * account_rate_multiplier)
UserCost float64 `json:"user_cost"` // 用户口径费用(actual_cost,受分组倍率影响)
} }
// AccountUsageSummary represents summary statistics for an account // AccountUsageSummary represents summary statistics for an account
type AccountUsageSummary struct { type AccountUsageSummary struct {
Days int `json:"days"` Days int `json:"days"`
ActualDaysUsed int `json:"actual_days_used"` ActualDaysUsed int `json:"actual_days_used"`
TotalCost float64 `json:"total_cost"` TotalCost float64 `json:"total_cost"` // 账号口径费用
TotalUserCost float64 `json:"total_user_cost"` // 用户口径费用
TotalStandardCost float64 `json:"total_standard_cost"` TotalStandardCost float64 `json:"total_standard_cost"`
TotalRequests int64 `json:"total_requests"` TotalRequests int64 `json:"total_requests"`
TotalTokens int64 `json:"total_tokens"` TotalTokens int64 `json:"total_tokens"`
AvgDailyCost float64 `json:"avg_daily_cost"` AvgDailyCost float64 `json:"avg_daily_cost"` // 账号口径日均
AvgDailyUserCost float64 `json:"avg_daily_user_cost"`
AvgDailyRequests float64 `json:"avg_daily_requests"` AvgDailyRequests float64 `json:"avg_daily_requests"`
AvgDailyTokens float64 `json:"avg_daily_tokens"` AvgDailyTokens float64 `json:"avg_daily_tokens"`
AvgDurationMs float64 `json:"avg_duration_ms"` AvgDurationMs float64 `json:"avg_duration_ms"`
Today *struct { Today *struct {
Date string `json:"date"` Date string `json:"date"`
Cost float64 `json:"cost"` Cost float64 `json:"cost"`
UserCost float64 `json:"user_cost"`
Requests int64 `json:"requests"` Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"` Tokens int64 `json:"tokens"`
} `json:"today"` } `json:"today"`
...@@ -197,6 +208,7 @@ type AccountUsageSummary struct { ...@@ -197,6 +208,7 @@ type AccountUsageSummary struct {
Date string `json:"date"` Date string `json:"date"`
Label string `json:"label"` Label string `json:"label"`
Cost float64 `json:"cost"` Cost float64 `json:"cost"`
UserCost float64 `json:"user_cost"`
Requests int64 `json:"requests"` Requests int64 `json:"requests"`
} `json:"highest_cost_day"` } `json:"highest_cost_day"`
HighestRequestDay *struct { HighestRequestDay *struct {
...@@ -204,6 +216,7 @@ type AccountUsageSummary struct { ...@@ -204,6 +216,7 @@ type AccountUsageSummary struct {
Label string `json:"label"` Label string `json:"label"`
Requests int64 `json:"requests"` Requests int64 `json:"requests"`
Cost float64 `json:"cost"` Cost float64 `json:"cost"`
UserCost float64 `json:"user_cost"`
} `json:"highest_request_day"` } `json:"highest_request_day"`
} }
......
...@@ -15,6 +15,7 @@ import ( ...@@ -15,6 +15,7 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"errors" "errors"
"log"
"strconv" "strconv"
"time" "time"
...@@ -79,6 +80,10 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account ...@@ -79,6 +80,10 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
SetSchedulable(account.Schedulable). SetSchedulable(account.Schedulable).
SetAutoPauseOnExpired(account.AutoPauseOnExpired) SetAutoPauseOnExpired(account.AutoPauseOnExpired)
if account.RateMultiplier != nil {
builder.SetRateMultiplier(*account.RateMultiplier)
}
if account.ProxyID != nil { if account.ProxyID != nil {
builder.SetProxyID(*account.ProxyID) builder.SetProxyID(*account.ProxyID)
} }
...@@ -115,6 +120,9 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account ...@@ -115,6 +120,9 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
account.ID = created.ID account.ID = created.ID
account.CreatedAt = created.CreatedAt account.CreatedAt = created.CreatedAt
account.UpdatedAt = created.UpdatedAt account.UpdatedAt = created.UpdatedAt
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
log.Printf("[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err)
}
return nil return nil
} }
...@@ -287,6 +295,10 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account ...@@ -287,6 +295,10 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
SetSchedulable(account.Schedulable). SetSchedulable(account.Schedulable).
SetAutoPauseOnExpired(account.AutoPauseOnExpired) SetAutoPauseOnExpired(account.AutoPauseOnExpired)
if account.RateMultiplier != nil {
builder.SetRateMultiplier(*account.RateMultiplier)
}
if account.ProxyID != nil { if account.ProxyID != nil {
builder.SetProxyID(*account.ProxyID) builder.SetProxyID(*account.ProxyID)
} else { } else {
...@@ -341,10 +353,17 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account ...@@ -341,10 +353,17 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
return translatePersistenceError(err, service.ErrAccountNotFound, nil) return translatePersistenceError(err, service.ErrAccountNotFound, nil)
} }
account.UpdatedAt = updated.UpdatedAt account.UpdatedAt = updated.UpdatedAt
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
}
return nil return nil
} }
func (r *accountRepository) Delete(ctx context.Context, id int64) error { func (r *accountRepository) Delete(ctx context.Context, id int64) error {
groupIDs, err := r.loadAccountGroupIDs(ctx, id)
if err != nil {
return err
}
// 使用事务保证账号与关联分组的删除原子性 // 使用事务保证账号与关联分组的删除原子性
tx, err := r.client.Tx(ctx) tx, err := r.client.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) { if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
...@@ -368,7 +387,12 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error { ...@@ -368,7 +387,12 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
} }
if tx != nil { if tx != nil {
return tx.Commit() if err := tx.Commit(); err != nil {
return err
}
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, buildSchedulerGroupPayload(groupIDs)); err != nil {
log.Printf("[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err)
} }
return nil return nil
} }
...@@ -455,7 +479,18 @@ func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error ...@@ -455,7 +479,18 @@ func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error
Where(dbaccount.IDEQ(id)). Where(dbaccount.IDEQ(id)).
SetLastUsedAt(now). SetLastUsedAt(now).
Save(ctx) Save(ctx)
return err if err != nil {
return err
}
payload := map[string]any{
"last_used": map[string]int64{
strconv.FormatInt(id, 10): now.Unix(),
},
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, &id, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err)
}
return nil
} }
func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error { func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
...@@ -479,7 +514,18 @@ func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map ...@@ -479,7 +514,18 @@ func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map
args = append(args, pq.Array(ids)) args = append(args, pq.Array(ids))
_, err := r.sql.ExecContext(ctx, caseSQL, args...) _, err := r.sql.ExecContext(ctx, caseSQL, args...)
return err if err != nil {
return err
}
lastUsedPayload := make(map[string]int64, len(updates))
for id, ts := range updates {
lastUsedPayload[strconv.FormatInt(id, 10)] = ts.Unix()
}
payload := map[string]any{"last_used": lastUsedPayload}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, nil, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue batch last used failed: err=%v", err)
}
return nil
} }
func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error { func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
...@@ -488,7 +534,13 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str ...@@ -488,7 +534,13 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
SetStatus(service.StatusError). SetStatus(service.StatusError).
SetErrorMessage(errorMsg). SetErrorMessage(errorMsg).
Save(ctx) Save(ctx)
return err if err != nil {
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
}
return nil
} }
func (r *accountRepository) ClearError(ctx context.Context, id int64) error { func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
...@@ -506,7 +558,14 @@ func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID i ...@@ -506,7 +558,14 @@ func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID i
SetGroupID(groupID). SetGroupID(groupID).
SetPriority(priority). SetPriority(priority).
Save(ctx) Save(ctx)
return err if err != nil {
return err
}
payload := buildSchedulerGroupPayload([]int64{groupID})
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err)
}
return nil
} }
func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error { func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
...@@ -516,7 +575,14 @@ func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, grou ...@@ -516,7 +575,14 @@ func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, grou
dbaccountgroup.GroupIDEQ(groupID), dbaccountgroup.GroupIDEQ(groupID),
). ).
Exec(ctx) Exec(ctx)
return err if err != nil {
return err
}
payload := buildSchedulerGroupPayload([]int64{groupID})
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, groupID, err)
}
return nil
} }
func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]service.Group, error) { func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]service.Group, error) {
...@@ -537,6 +603,10 @@ func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]s ...@@ -537,6 +603,10 @@ func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]s
} }
func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error { func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
existingGroupIDs, err := r.loadAccountGroupIDs(ctx, accountID)
if err != nil {
return err
}
// 使用事务保证删除旧绑定与创建新绑定的原子性 // 使用事务保证删除旧绑定与创建新绑定的原子性
tx, err := r.client.Tx(ctx) tx, err := r.client.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) { if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
...@@ -577,7 +647,13 @@ func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, gro ...@@ -577,7 +647,13 @@ func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, gro
} }
if tx != nil { if tx != nil {
return tx.Commit() if err := tx.Commit(); err != nil {
return err
}
}
payload := buildSchedulerGroupPayload(mergeGroupIDs(existingGroupIDs, groupIDs))
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err)
} }
return nil return nil
} }
...@@ -681,7 +757,13 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA ...@@ -681,7 +757,13 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
SetRateLimitedAt(now). SetRateLimitedAt(now).
SetRateLimitResetAt(resetAt). SetRateLimitResetAt(resetAt).
Save(ctx) Save(ctx)
return err if err != nil {
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
}
return nil
} }
func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error { func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
...@@ -715,6 +797,49 @@ func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, i ...@@ -715,6 +797,49 @@ func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, i
if affected == 0 { if affected == 0 {
return service.ErrAccountNotFound return service.ErrAccountNotFound
} }
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue quota scope failed: account=%d err=%v", id, err)
}
return nil
}
func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
if scope == "" {
return nil
}
now := time.Now().UTC()
payload := map[string]string{
"rate_limited_at": now.Format(time.RFC3339),
"rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339),
}
raw, err := json.Marshal(payload)
if err != nil {
return err
}
path := "{model_rate_limits," + scope + "}"
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(
ctx,
"UPDATE accounts SET extra = jsonb_set(COALESCE(extra, '{}'::jsonb), $1::text[], $2::jsonb, true), updated_at = NOW() WHERE id = $3 AND deleted_at IS NULL",
path,
raw,
id,
)
if err != nil {
return err
}
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return service.ErrAccountNotFound
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err)
}
return nil return nil
} }
...@@ -723,7 +848,13 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t ...@@ -723,7 +848,13 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t
Where(dbaccount.IDEQ(id)). Where(dbaccount.IDEQ(id)).
SetOverloadUntil(until). SetOverloadUntil(until).
Save(ctx) Save(ctx)
return err if err != nil {
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, err)
}
return nil
} }
func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
...@@ -736,7 +867,13 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, ...@@ -736,7 +867,13 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64,
AND deleted_at IS NULL AND deleted_at IS NULL
AND (temp_unschedulable_until IS NULL OR temp_unschedulable_until < $1) AND (temp_unschedulable_until IS NULL OR temp_unschedulable_until < $1)
`, until, reason, id) `, until, reason, id)
return err if err != nil {
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
}
return nil
} }
func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64) error { func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64) error {
...@@ -748,7 +885,13 @@ func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64 ...@@ -748,7 +885,13 @@ func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64
WHERE id = $1 WHERE id = $1
AND deleted_at IS NULL AND deleted_at IS NULL
`, id) `, id)
return err if err != nil {
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err)
}
return nil
} }
func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error { func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
...@@ -758,7 +901,13 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error ...@@ -758,7 +901,13 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
ClearRateLimitResetAt(). ClearRateLimitResetAt().
ClearOverloadUntil(). ClearOverloadUntil().
Save(ctx) Save(ctx)
return err if err != nil {
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
}
return nil
} }
func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error { func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
...@@ -779,6 +928,33 @@ func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id ...@@ -779,6 +928,33 @@ func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id
if affected == 0 { if affected == 0 {
return service.ErrAccountNotFound return service.ErrAccountNotFound
} }
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err)
}
return nil
}
func (r *accountRepository) ClearModelRateLimits(ctx context.Context, id int64) error {
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(
ctx,
"UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) - 'model_rate_limits', updated_at = NOW() WHERE id = $1 AND deleted_at IS NULL",
id,
)
if err != nil {
return err
}
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return service.ErrAccountNotFound
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v", id, err)
}
return nil return nil
} }
...@@ -801,7 +977,13 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu ...@@ -801,7 +977,13 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
Where(dbaccount.IDEQ(id)). Where(dbaccount.IDEQ(id)).
SetSchedulable(schedulable). SetSchedulable(schedulable).
Save(ctx) Save(ctx)
return err if err != nil {
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
}
return nil
} }
func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) { func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
...@@ -822,6 +1004,11 @@ func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now ti ...@@ -822,6 +1004,11 @@ func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now ti
if err != nil { if err != nil {
return 0, err return 0, err
} }
if rows > 0 {
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventFullRebuild, nil, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err)
}
}
return rows, nil return rows, nil
} }
...@@ -853,6 +1040,9 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m ...@@ -853,6 +1040,9 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
if affected == 0 { if affected == 0 {
return service.ErrAccountNotFound return service.ErrAccountNotFound
} }
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
}
return nil return nil
} }
...@@ -890,6 +1080,11 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates ...@@ -890,6 +1080,11 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
args = append(args, *updates.Priority) args = append(args, *updates.Priority)
idx++ idx++
} }
if updates.RateMultiplier != nil {
setClauses = append(setClauses, "rate_multiplier = $"+itoa(idx))
args = append(args, *updates.RateMultiplier)
idx++
}
if updates.Status != nil { if updates.Status != nil {
setClauses = append(setClauses, "status = $"+itoa(idx)) setClauses = append(setClauses, "status = $"+itoa(idx))
args = append(args, *updates.Status) args = append(args, *updates.Status)
...@@ -937,6 +1132,12 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates ...@@ -937,6 +1132,12 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
if err != nil { if err != nil {
return 0, err return 0, err
} }
if rows > 0 {
payload := map[string]any{"account_ids": ids}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
}
}
return rows, nil return rows, nil
} }
...@@ -1179,11 +1380,61 @@ func (r *accountRepository) loadAccountGroups(ctx context.Context, accountIDs [] ...@@ -1179,11 +1380,61 @@ func (r *accountRepository) loadAccountGroups(ctx context.Context, accountIDs []
return groupsByAccount, groupIDsByAccount, accountGroupsByAccount, nil return groupsByAccount, groupIDsByAccount, accountGroupsByAccount, nil
} }
func (r *accountRepository) loadAccountGroupIDs(ctx context.Context, accountID int64) ([]int64, error) {
entries, err := r.client.AccountGroup.
Query().
Where(dbaccountgroup.AccountIDEQ(accountID)).
All(ctx)
if err != nil {
return nil, err
}
ids := make([]int64, 0, len(entries))
for _, entry := range entries {
ids = append(ids, entry.GroupID)
}
return ids, nil
}
func mergeGroupIDs(a []int64, b []int64) []int64 {
seen := make(map[int64]struct{}, len(a)+len(b))
out := make([]int64, 0, len(a)+len(b))
for _, id := range a {
if id <= 0 {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
out = append(out, id)
}
for _, id := range b {
if id <= 0 {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
out = append(out, id)
}
return out
}
func buildSchedulerGroupPayload(groupIDs []int64) map[string]any {
if len(groupIDs) == 0 {
return nil
}
return map[string]any{"group_ids": groupIDs}
}
func accountEntityToService(m *dbent.Account) *service.Account { func accountEntityToService(m *dbent.Account) *service.Account {
if m == nil { if m == nil {
return nil return nil
} }
rateMultiplier := m.RateMultiplier
return &service.Account{ return &service.Account{
ID: m.ID, ID: m.ID,
Name: m.Name, Name: m.Name,
...@@ -1195,6 +1446,7 @@ func accountEntityToService(m *dbent.Account) *service.Account { ...@@ -1195,6 +1446,7 @@ func accountEntityToService(m *dbent.Account) *service.Account {
ProxyID: m.ProxyID, ProxyID: m.ProxyID,
Concurrency: m.Concurrency, Concurrency: m.Concurrency,
Priority: m.Priority, Priority: m.Priority,
RateMultiplier: &rateMultiplier,
Status: m.Status, Status: m.Status,
ErrorMessage: derefString(m.ErrorMessage), ErrorMessage: derefString(m.ErrorMessage),
LastUsedAt: m.LastUsedAt, LastUsedAt: m.LastUsedAt,
......
...@@ -2,6 +2,7 @@ package repository ...@@ -2,6 +2,7 @@ package repository
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"time" "time"
...@@ -13,6 +14,7 @@ import ( ...@@ -13,6 +14,7 @@ import (
const ( const (
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:" apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
apiKeyRateLimitDuration = 24 * time.Hour apiKeyRateLimitDuration = 24 * time.Hour
apiKeyAuthCachePrefix = "apikey:auth:"
) )
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting. // apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
...@@ -20,6 +22,10 @@ func apiKeyRateLimitKey(userID int64) string { ...@@ -20,6 +22,10 @@ func apiKeyRateLimitKey(userID int64) string {
return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID) return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
} }
func apiKeyAuthCacheKey(key string) string {
return fmt.Sprintf("%s%s", apiKeyAuthCachePrefix, key)
}
type apiKeyCache struct { type apiKeyCache struct {
rdb *redis.Client rdb *redis.Client
} }
...@@ -58,3 +64,30 @@ func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) er ...@@ -58,3 +64,30 @@ func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) er
func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error { func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
return c.rdb.Expire(ctx, apiKey, ttl).Err() return c.rdb.Expire(ctx, apiKey, ttl).Err()
} }
func (c *apiKeyCache) GetAuthCache(ctx context.Context, key string) (*service.APIKeyAuthCacheEntry, error) {
val, err := c.rdb.Get(ctx, apiKeyAuthCacheKey(key)).Bytes()
if err != nil {
return nil, err
}
var entry service.APIKeyAuthCacheEntry
if err := json.Unmarshal(val, &entry); err != nil {
return nil, err
}
return &entry, nil
}
func (c *apiKeyCache) SetAuthCache(ctx context.Context, key string, entry *service.APIKeyAuthCacheEntry, ttl time.Duration) error {
if entry == nil {
return nil
}
payload, err := json.Marshal(entry)
if err != nil {
return err
}
return c.rdb.Set(ctx, apiKeyAuthCacheKey(key), payload, ttl).Err()
}
func (c *apiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
return c.rdb.Del(ctx, apiKeyAuthCacheKey(key)).Err()
}
...@@ -6,7 +6,9 @@ import ( ...@@ -6,7 +6,9 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins" "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
...@@ -26,13 +28,21 @@ func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery { ...@@ -26,13 +28,21 @@ func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery {
} }
func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) error { func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) error {
created, err := r.client.APIKey.Create(). builder := r.client.APIKey.Create().
SetUserID(key.UserID). SetUserID(key.UserID).
SetKey(key.Key). SetKey(key.Key).
SetName(key.Name). SetName(key.Name).
SetStatus(key.Status). SetStatus(key.Status).
SetNillableGroupID(key.GroupID). SetNillableGroupID(key.GroupID)
Save(ctx)
if len(key.IPWhitelist) > 0 {
builder.SetIPWhitelist(key.IPWhitelist)
}
if len(key.IPBlacklist) > 0 {
builder.SetIPBlacklist(key.IPBlacklist)
}
created, err := builder.Save(ctx)
if err == nil { if err == nil {
key.ID = created.ID key.ID = created.ID
key.CreatedAt = created.CreatedAt key.CreatedAt = created.CreatedAt
...@@ -56,23 +66,23 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIK ...@@ -56,23 +66,23 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIK
return apiKeyEntityToService(m), nil return apiKeyEntityToService(m), nil
} }
// GetOwnerID 根据 API Key ID 获取其所有者(用户)ID。 // GetKeyAndOwnerID 根据 API Key ID 获取其 key 与所有者(用户)ID。
// 相比 GetByID,此方法性能更优,因为: // 相比 GetByID,此方法性能更优,因为:
// - 使用 Select() 只查询 user_id 字段,减少数据传输量 // - 使用 Select() 只查询必要字段,减少数据传输量
// - 不加载完整的 API Key 实体及其关联数据(User、Group 等) // - 不加载完整的 API Key 实体及其关联数据(User、Group 等)
// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查) // - 适用于删除等只需 key 与用户 ID 的场景
func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) { func (r *apiKeyRepository) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
m, err := r.activeQuery(). m, err := r.activeQuery().
Where(apikey.IDEQ(id)). Where(apikey.IDEQ(id)).
Select(apikey.FieldUserID). Select(apikey.FieldKey, apikey.FieldUserID).
Only(ctx) Only(ctx)
if err != nil { if err != nil {
if dbent.IsNotFound(err) { if dbent.IsNotFound(err) {
return 0, service.ErrAPIKeyNotFound return "", 0, service.ErrAPIKeyNotFound
} }
return 0, err return "", 0, err
} }
return m.UserID, nil return m.Key, m.UserID, nil
} }
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) { func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
...@@ -90,6 +100,56 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A ...@@ -90,6 +100,56 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A
return apiKeyEntityToService(m), nil return apiKeyEntityToService(m), nil
} }
func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
m, err := r.activeQuery().
Where(apikey.KeyEQ(key)).
Select(
apikey.FieldID,
apikey.FieldUserID,
apikey.FieldGroupID,
apikey.FieldStatus,
apikey.FieldIPWhitelist,
apikey.FieldIPBlacklist,
).
WithUser(func(q *dbent.UserQuery) {
q.Select(
user.FieldID,
user.FieldStatus,
user.FieldRole,
user.FieldBalance,
user.FieldConcurrency,
)
}).
WithGroup(func(q *dbent.GroupQuery) {
q.Select(
group.FieldID,
group.FieldName,
group.FieldPlatform,
group.FieldStatus,
group.FieldSubscriptionType,
group.FieldRateMultiplier,
group.FieldDailyLimitUsd,
group.FieldWeeklyLimitUsd,
group.FieldMonthlyLimitUsd,
group.FieldImagePrice1k,
group.FieldImagePrice2k,
group.FieldImagePrice4k,
group.FieldClaudeCodeOnly,
group.FieldFallbackGroupID,
group.FieldModelRoutingEnabled,
group.FieldModelRouting,
)
}).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrAPIKeyNotFound
}
return nil, err
}
return apiKeyEntityToService(m), nil
}
func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error { func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error {
// 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。 // 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。
// 之前的实现先检查 Exist 再 UpdateOneID,若在两步之间发生软删除, // 之前的实现先检查 Exist 再 UpdateOneID,若在两步之间发生软删除,
...@@ -108,6 +168,18 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro ...@@ -108,6 +168,18 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
builder.ClearGroupID() builder.ClearGroupID()
} }
// IP 限制字段
if len(key.IPWhitelist) > 0 {
builder.SetIPWhitelist(key.IPWhitelist)
} else {
builder.ClearIPWhitelist()
}
if len(key.IPBlacklist) > 0 {
builder.SetIPBlacklist(key.IPBlacklist)
} else {
builder.ClearIPBlacklist()
}
affected, err := builder.Save(ctx) affected, err := builder.Save(ctx)
if err != nil { if err != nil {
return err return err
...@@ -263,19 +335,43 @@ func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (i ...@@ -263,19 +335,43 @@ func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (i
return int64(count), err return int64(count), err
} }
func (r *apiKeyRepository) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
keys, err := r.activeQuery().
Where(apikey.UserIDEQ(userID)).
Select(apikey.FieldKey).
Strings(ctx)
if err != nil {
return nil, err
}
return keys, nil
}
func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
keys, err := r.activeQuery().
Where(apikey.GroupIDEQ(groupID)).
Select(apikey.FieldKey).
Strings(ctx)
if err != nil {
return nil, err
}
return keys, nil
}
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey { func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
if m == nil { if m == nil {
return nil return nil
} }
out := &service.APIKey{ out := &service.APIKey{
ID: m.ID, ID: m.ID,
UserID: m.UserID, UserID: m.UserID,
Key: m.Key, Key: m.Key,
Name: m.Name, Name: m.Name,
Status: m.Status, Status: m.Status,
CreatedAt: m.CreatedAt, IPWhitelist: m.IPWhitelist,
UpdatedAt: m.UpdatedAt, IPBlacklist: m.IPBlacklist,
GroupID: m.GroupID, CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
GroupID: m.GroupID,
} }
if m.Edges.User != nil { if m.Edges.User != nil {
out.User = userEntityToService(m.Edges.User) out.User = userEntityToService(m.Edges.User)
...@@ -317,6 +413,7 @@ func groupEntityToService(g *dbent.Group) *service.Group { ...@@ -317,6 +413,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
RateMultiplier: g.RateMultiplier, RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive, IsExclusive: g.IsExclusive,
Status: g.Status, Status: g.Status,
Hydrated: true,
SubscriptionType: g.SubscriptionType, SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUsd, DailyLimitUSD: g.DailyLimitUsd,
WeeklyLimitUSD: g.WeeklyLimitUsd, WeeklyLimitUSD: g.WeeklyLimitUsd,
...@@ -327,6 +424,8 @@ func groupEntityToService(g *dbent.Group) *service.Group { ...@@ -327,6 +424,8 @@ func groupEntityToService(g *dbent.Group) *service.Group {
DefaultValidityDays: g.DefaultValidityDays, DefaultValidityDays: g.DefaultValidityDays,
ClaudeCodeOnly: g.ClaudeCodeOnly, ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID, FallbackGroupID: g.FallbackGroupID,
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
CreatedAt: g.CreatedAt, CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt, UpdatedAt: g.UpdatedAt,
} }
......
...@@ -93,7 +93,7 @@ var ( ...@@ -93,7 +93,7 @@ var (
return redis.call('ZCARD', key) return redis.call('ZCARD', key)
`) `)
// incrementWaitScript - only sets TTL on first creation to avoid refreshing // incrementWaitScript - refreshes TTL on each increment to keep queue depth accurate
// KEYS[1] = wait queue key // KEYS[1] = wait queue key
// ARGV[1] = maxWait // ARGV[1] = maxWait
// ARGV[2] = TTL in seconds // ARGV[2] = TTL in seconds
...@@ -111,15 +111,13 @@ var ( ...@@ -111,15 +111,13 @@ var (
local newVal = redis.call('INCR', KEYS[1]) local newVal = redis.call('INCR', KEYS[1])
-- Only set TTL on first creation to avoid refreshing zombie data -- Refresh TTL so long-running traffic doesn't expire active queue counters.
if newVal == 1 then redis.call('EXPIRE', KEYS[1], ARGV[2])
redis.call('EXPIRE', KEYS[1], ARGV[2])
end
return 1 return 1
`) `)
// incrementAccountWaitScript - account-level wait queue count // incrementAccountWaitScript - account-level wait queue count (refresh TTL on each increment)
incrementAccountWaitScript = redis.NewScript(` incrementAccountWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1]) local current = redis.call('GET', KEYS[1])
if current == false then if current == false then
...@@ -134,10 +132,8 @@ var ( ...@@ -134,10 +132,8 @@ var (
local newVal = redis.call('INCR', KEYS[1]) local newVal = redis.call('INCR', KEYS[1])
-- Only set TTL on first creation to avoid refreshing zombie data -- Refresh TTL so long-running traffic doesn't expire active queue counters.
if newVal == 1 then redis.call('EXPIRE', KEYS[1], ARGV[2])
redis.call('EXPIRE', KEYS[1], ARGV[2])
end
return 1 return 1
`) `)
......
package repository
import (
"context"
"database/sql"
"fmt"
"log"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
type dashboardAggregationRepository struct {
sql sqlExecutor
}
// NewDashboardAggregationRepository 创建仪表盘预聚合仓储。
func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository {
if sqlDB == nil {
return nil
}
if !isPostgresDriver(sqlDB) {
log.Printf("[DashboardAggregation] 检测到非 PostgreSQL 驱动,已自动禁用预聚合")
return nil
}
return newDashboardAggregationRepositoryWithSQL(sqlDB)
}
func newDashboardAggregationRepositoryWithSQL(sqlq sqlExecutor) *dashboardAggregationRepository {
return &dashboardAggregationRepository{sql: sqlq}
}
func isPostgresDriver(db *sql.DB) bool {
if db == nil {
return false
}
_, ok := db.Driver().(*pq.Driver)
return ok
}
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
loc := timezone.Location()
startLocal := start.In(loc)
endLocal := end.In(loc)
if !endLocal.After(startLocal) {
return nil
}
hourStart := startLocal.Truncate(time.Hour)
hourEnd := endLocal.Truncate(time.Hour)
if endLocal.After(hourEnd) {
hourEnd = hourEnd.Add(time.Hour)
}
dayStart := truncateToDay(startLocal)
dayEnd := truncateToDay(endLocal)
if endLocal.After(dayEnd) {
dayEnd = dayEnd.Add(24 * time.Hour)
}
// 以桶边界聚合,允许覆盖 end 所在桶的剩余区间。
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
return err
}
if err := r.insertDailyActiveUsers(ctx, hourStart, hourEnd); err != nil {
return err
}
if err := r.upsertHourlyAggregates(ctx, hourStart, hourEnd); err != nil {
return err
}
if err := r.upsertDailyAggregates(ctx, dayStart, dayEnd); err != nil {
return err
}
return nil
}
func (r *dashboardAggregationRepository) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
var ts time.Time
query := "SELECT last_aggregated_at FROM usage_dashboard_aggregation_watermark WHERE id = 1"
if err := scanSingleRow(ctx, r.sql, query, nil, &ts); err != nil {
if err == sql.ErrNoRows {
return time.Unix(0, 0).UTC(), nil
}
return time.Time{}, err
}
return ts.UTC(), nil
}
func (r *dashboardAggregationRepository) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
query := `
INSERT INTO usage_dashboard_aggregation_watermark (id, last_aggregated_at, updated_at)
VALUES (1, $1, NOW())
ON CONFLICT (id)
DO UPDATE SET last_aggregated_at = EXCLUDED.last_aggregated_at, updated_at = EXCLUDED.updated_at
`
_, err := r.sql.ExecContext(ctx, query, aggregatedAt.UTC())
return err
}
func (r *dashboardAggregationRepository) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
hourlyCutoffUTC := hourlyCutoff.UTC()
dailyCutoffUTC := dailyCutoff.UTC()
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly WHERE bucket_start < $1", hourlyCutoffUTC); err != nil {
return err
}
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly_users WHERE bucket_start < $1", hourlyCutoffUTC); err != nil {
return err
}
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily WHERE bucket_date < $1::date", dailyCutoffUTC); err != nil {
return err
}
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily_users WHERE bucket_date < $1::date", dailyCutoffUTC); err != nil {
return err
}
return nil
}
func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
isPartitioned, err := r.isUsageLogsPartitioned(ctx)
if err != nil {
return err
}
if isPartitioned {
return r.dropUsageLogsPartitions(ctx, cutoff)
}
_, err = r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE created_at < $1", cutoff.UTC())
return err
}
func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
isPartitioned, err := r.isUsageLogsPartitioned(ctx)
if err != nil || !isPartitioned {
return err
}
monthStart := truncateToMonthUTC(now)
prevMonth := monthStart.AddDate(0, -1, 0)
nextMonth := monthStart.AddDate(0, 1, 0)
for _, m := range []time.Time{prevMonth, monthStart, nextMonth} {
if err := r.createUsageLogsPartition(ctx, m); err != nil {
return err
}
}
return nil
}
func (r *dashboardAggregationRepository) insertHourlyActiveUsers(ctx context.Context, start, end time.Time) error {
tzName := timezone.Name()
query := `
INSERT INTO usage_dashboard_hourly_users (bucket_start, user_id)
SELECT DISTINCT
date_trunc('hour', created_at AT TIME ZONE $3) AT TIME ZONE $3 AS bucket_start,
user_id
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
ON CONFLICT DO NOTHING
`
_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
return err
}
func (r *dashboardAggregationRepository) insertDailyActiveUsers(ctx context.Context, start, end time.Time) error {
tzName := timezone.Name()
query := `
INSERT INTO usage_dashboard_daily_users (bucket_date, user_id)
SELECT DISTINCT
(bucket_start AT TIME ZONE $3)::date AS bucket_date,
user_id
FROM usage_dashboard_hourly_users
WHERE bucket_start >= $1 AND bucket_start < $2
ON CONFLICT DO NOTHING
`
_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
return err
}
func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Context, start, end time.Time) error {
tzName := timezone.Name()
query := `
WITH hourly AS (
SELECT
date_trunc('hour', created_at AT TIME ZONE $3) AT TIME ZONE $3 AS bucket_start,
COUNT(*) AS total_requests,
COALESCE(SUM(input_tokens), 0) AS input_tokens,
COALESCE(SUM(output_tokens), 0) AS output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) AS cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens,
COALESCE(SUM(total_cost), 0) AS total_cost,
COALESCE(SUM(actual_cost), 0) AS actual_cost,
COALESCE(SUM(COALESCE(duration_ms, 0)), 0) AS total_duration_ms
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
GROUP BY 1
),
user_counts AS (
SELECT bucket_start, COUNT(*) AS active_users
FROM usage_dashboard_hourly_users
WHERE bucket_start >= $1 AND bucket_start < $2
GROUP BY bucket_start
)
INSERT INTO usage_dashboard_hourly (
bucket_start,
total_requests,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
total_cost,
actual_cost,
total_duration_ms,
active_users,
computed_at
)
SELECT
hourly.bucket_start,
hourly.total_requests,
hourly.input_tokens,
hourly.output_tokens,
hourly.cache_creation_tokens,
hourly.cache_read_tokens,
hourly.total_cost,
hourly.actual_cost,
hourly.total_duration_ms,
COALESCE(user_counts.active_users, 0) AS active_users,
NOW()
FROM hourly
LEFT JOIN user_counts ON user_counts.bucket_start = hourly.bucket_start
ON CONFLICT (bucket_start)
DO UPDATE SET
total_requests = EXCLUDED.total_requests,
input_tokens = EXCLUDED.input_tokens,
output_tokens = EXCLUDED.output_tokens,
cache_creation_tokens = EXCLUDED.cache_creation_tokens,
cache_read_tokens = EXCLUDED.cache_read_tokens,
total_cost = EXCLUDED.total_cost,
actual_cost = EXCLUDED.actual_cost,
total_duration_ms = EXCLUDED.total_duration_ms,
active_users = EXCLUDED.active_users,
computed_at = EXCLUDED.computed_at
`
_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
return err
}
func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Context, start, end time.Time) error {
tzName := timezone.Name()
query := `
WITH daily AS (
SELECT
(bucket_start AT TIME ZONE $5)::date AS bucket_date,
COALESCE(SUM(total_requests), 0) AS total_requests,
COALESCE(SUM(input_tokens), 0) AS input_tokens,
COALESCE(SUM(output_tokens), 0) AS output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) AS cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens,
COALESCE(SUM(total_cost), 0) AS total_cost,
COALESCE(SUM(actual_cost), 0) AS actual_cost,
COALESCE(SUM(total_duration_ms), 0) AS total_duration_ms
FROM usage_dashboard_hourly
WHERE bucket_start >= $1 AND bucket_start < $2
GROUP BY (bucket_start AT TIME ZONE $5)::date
),
user_counts AS (
SELECT bucket_date, COUNT(*) AS active_users
FROM usage_dashboard_daily_users
WHERE bucket_date >= $3::date AND bucket_date < $4::date
GROUP BY bucket_date
)
INSERT INTO usage_dashboard_daily (
bucket_date,
total_requests,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
total_cost,
actual_cost,
total_duration_ms,
active_users,
computed_at
)
SELECT
daily.bucket_date,
daily.total_requests,
daily.input_tokens,
daily.output_tokens,
daily.cache_creation_tokens,
daily.cache_read_tokens,
daily.total_cost,
daily.actual_cost,
daily.total_duration_ms,
COALESCE(user_counts.active_users, 0) AS active_users,
NOW()
FROM daily
LEFT JOIN user_counts ON user_counts.bucket_date = daily.bucket_date
ON CONFLICT (bucket_date)
DO UPDATE SET
total_requests = EXCLUDED.total_requests,
input_tokens = EXCLUDED.input_tokens,
output_tokens = EXCLUDED.output_tokens,
cache_creation_tokens = EXCLUDED.cache_creation_tokens,
cache_read_tokens = EXCLUDED.cache_read_tokens,
total_cost = EXCLUDED.total_cost,
actual_cost = EXCLUDED.actual_cost,
total_duration_ms = EXCLUDED.total_duration_ms,
active_users = EXCLUDED.active_users,
computed_at = EXCLUDED.computed_at
`
_, err := r.sql.ExecContext(ctx, query, start, end, start, end, tzName)
return err
}
func (r *dashboardAggregationRepository) isUsageLogsPartitioned(ctx context.Context) (bool, error) {
query := `
SELECT EXISTS(
SELECT 1
FROM pg_partitioned_table pt
JOIN pg_class c ON c.oid = pt.partrelid
WHERE c.relname = 'usage_logs'
)
`
var partitioned bool
if err := scanSingleRow(ctx, r.sql, query, nil, &partitioned); err != nil {
return false, err
}
return partitioned, nil
}
func (r *dashboardAggregationRepository) dropUsageLogsPartitions(ctx context.Context, cutoff time.Time) error {
rows, err := r.sql.QueryContext(ctx, `
SELECT c.relname
FROM pg_inherits
JOIN pg_class c ON c.oid = pg_inherits.inhrelid
JOIN pg_class p ON p.oid = pg_inherits.inhparent
WHERE p.relname = 'usage_logs'
`)
if err != nil {
return err
}
defer func() {
_ = rows.Close()
}()
cutoffMonth := truncateToMonthUTC(cutoff)
for rows.Next() {
var name string
if err := rows.Scan(&name); err != nil {
return err
}
if !strings.HasPrefix(name, "usage_logs_") {
continue
}
suffix := strings.TrimPrefix(name, "usage_logs_")
month, err := time.Parse("200601", suffix)
if err != nil {
continue
}
month = month.UTC()
if month.Before(cutoffMonth) {
if _, err := r.sql.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", pq.QuoteIdentifier(name))); err != nil {
return err
}
}
}
return rows.Err()
}
func (r *dashboardAggregationRepository) createUsageLogsPartition(ctx context.Context, month time.Time) error {
monthStart := truncateToMonthUTC(month)
nextMonth := monthStart.AddDate(0, 1, 0)
name := fmt.Sprintf("usage_logs_%s", monthStart.Format("200601"))
query := fmt.Sprintf(
"CREATE TABLE IF NOT EXISTS %s PARTITION OF usage_logs FOR VALUES FROM (%s) TO (%s)",
pq.QuoteIdentifier(name),
pq.QuoteLiteral(monthStart.Format("2006-01-02")),
pq.QuoteLiteral(nextMonth.Format("2006-01-02")),
)
_, err := r.sql.ExecContext(ctx, query)
return err
}
func truncateToDay(t time.Time) time.Time {
return timezone.StartOfDay(t)
}
func truncateToMonthUTC(t time.Time) time.Time {
t = t.UTC()
return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, time.UTC)
}
package repository
import (
"context"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const dashboardStatsCacheKey = "dashboard:stats:v1"
type dashboardCache struct {
rdb *redis.Client
keyPrefix string
}
func NewDashboardCache(rdb *redis.Client, cfg *config.Config) service.DashboardStatsCache {
prefix := "sub2api:"
if cfg != nil {
prefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix)
}
if prefix != "" && !strings.HasSuffix(prefix, ":") {
prefix += ":"
}
return &dashboardCache{
rdb: rdb,
keyPrefix: prefix,
}
}
func (c *dashboardCache) GetDashboardStats(ctx context.Context) (string, error) {
val, err := c.rdb.Get(ctx, c.buildKey()).Result()
if err != nil {
if err == redis.Nil {
return "", service.ErrDashboardStatsCacheMiss
}
return "", err
}
return val, nil
}
func (c *dashboardCache) SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error {
return c.rdb.Set(ctx, c.buildKey(), data, ttl).Err()
}
func (c *dashboardCache) buildKey() string {
if c.keyPrefix == "" {
return dashboardStatsCacheKey
}
return c.keyPrefix + dashboardStatsCacheKey
}
func (c *dashboardCache) DeleteDashboardStats(ctx context.Context) error {
return c.rdb.Del(ctx, c.buildKey()).Err()
}
package repository
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestNewDashboardCacheKeyPrefix(t *testing.T) {
cache := NewDashboardCache(nil, &config.Config{
Dashboard: config.DashboardCacheConfig{
KeyPrefix: "prod",
},
})
impl, ok := cache.(*dashboardCache)
require.True(t, ok)
require.Equal(t, "prod:", impl.keyPrefix)
cache = NewDashboardCache(nil, &config.Config{
Dashboard: config.DashboardCacheConfig{
KeyPrefix: "staging:",
},
})
impl, ok = cache.(*dashboardCache)
require.True(t, ok)
require.Equal(t, "staging:", impl.keyPrefix)
}
...@@ -11,8 +11,8 @@ import ( ...@@ -11,8 +11,8 @@ import (
) )
const ( const (
geminiTokenKeyPrefix = "gemini:token:" oauthTokenKeyPrefix = "oauth:token:"
geminiRefreshLockKeyPrefix = "gemini:refresh_lock:" oauthRefreshLockKeyPrefix = "oauth:refresh_lock:"
) )
type geminiTokenCache struct { type geminiTokenCache struct {
...@@ -24,21 +24,26 @@ func NewGeminiTokenCache(rdb *redis.Client) service.GeminiTokenCache { ...@@ -24,21 +24,26 @@ func NewGeminiTokenCache(rdb *redis.Client) service.GeminiTokenCache {
} }
func (c *geminiTokenCache) GetAccessToken(ctx context.Context, cacheKey string) (string, error) { func (c *geminiTokenCache) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey) key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
return c.rdb.Get(ctx, key).Result() return c.rdb.Get(ctx, key).Result()
} }
func (c *geminiTokenCache) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error { func (c *geminiTokenCache) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey) key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
return c.rdb.Set(ctx, key, token, ttl).Err() return c.rdb.Set(ctx, key, token, ttl).Err()
} }
func (c *geminiTokenCache) DeleteAccessToken(ctx context.Context, cacheKey string) error {
key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
return c.rdb.Del(ctx, key).Err()
}
func (c *geminiTokenCache) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) { func (c *geminiTokenCache) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey) key := fmt.Sprintf("%s%s", oauthRefreshLockKeyPrefix, cacheKey)
return c.rdb.SetNX(ctx, key, 1, ttl).Result() return c.rdb.SetNX(ctx, key, 1, ttl).Result()
} }
func (c *geminiTokenCache) ReleaseRefreshLock(ctx context.Context, cacheKey string) error { func (c *geminiTokenCache) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey) key := fmt.Sprintf("%s%s", oauthRefreshLockKeyPrefix, cacheKey)
return c.rdb.Del(ctx, key).Err() return c.rdb.Del(ctx, key).Err()
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment