Commit 6901b64f authored by cyhhao's avatar cyhhao
Browse files

merge: sync upstream changes

parents 32c47b15 dae0d532
...@@ -14,6 +14,8 @@ type OpsRepository interface { ...@@ -14,6 +14,8 @@ type OpsRepository interface {
InsertRetryAttempt(ctx context.Context, input *OpsInsertRetryAttemptInput) (int64, error) InsertRetryAttempt(ctx context.Context, input *OpsInsertRetryAttemptInput) (int64, error)
UpdateRetryAttempt(ctx context.Context, input *OpsUpdateRetryAttemptInput) error UpdateRetryAttempt(ctx context.Context, input *OpsUpdateRetryAttemptInput) error
GetLatestRetryAttemptForError(ctx context.Context, sourceErrorID int64) (*OpsRetryAttempt, error) GetLatestRetryAttemptForError(ctx context.Context, sourceErrorID int64) (*OpsRetryAttempt, error)
ListRetryAttemptsByErrorID(ctx context.Context, sourceErrorID int64, limit int) ([]*OpsRetryAttempt, error)
UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedRetryID *int64, resolvedAt *time.Time) error
// Lightweight window stats (for realtime WS / quick sampling). // Lightweight window stats (for realtime WS / quick sampling).
GetWindowStats(ctx context.Context, filter *OpsDashboardFilter) (*OpsWindowStats, error) GetWindowStats(ctx context.Context, filter *OpsDashboardFilter) (*OpsWindowStats, error)
...@@ -39,12 +41,17 @@ type OpsRepository interface { ...@@ -39,12 +41,17 @@ type OpsRepository interface {
DeleteAlertRule(ctx context.Context, id int64) error DeleteAlertRule(ctx context.Context, id int64) error
ListAlertEvents(ctx context.Context, filter *OpsAlertEventFilter) ([]*OpsAlertEvent, error) ListAlertEvents(ctx context.Context, filter *OpsAlertEventFilter) ([]*OpsAlertEvent, error)
GetAlertEventByID(ctx context.Context, eventID int64) (*OpsAlertEvent, error)
GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error)
GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error)
CreateAlertEvent(ctx context.Context, event *OpsAlertEvent) (*OpsAlertEvent, error) CreateAlertEvent(ctx context.Context, event *OpsAlertEvent) (*OpsAlertEvent, error)
UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error
UpdateAlertEventEmailSent(ctx context.Context, eventID int64, emailSent bool) error UpdateAlertEventEmailSent(ctx context.Context, eventID int64, emailSent bool) error
// Alert silences
CreateAlertSilence(ctx context.Context, input *OpsAlertSilence) (*OpsAlertSilence, error)
IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error)
// Pre-aggregation (hourly/daily) used for long-window dashboard performance. // Pre-aggregation (hourly/daily) used for long-window dashboard performance.
UpsertHourlyMetrics(ctx context.Context, startTime, endTime time.Time) error UpsertHourlyMetrics(ctx context.Context, startTime, endTime time.Time) error
UpsertDailyMetrics(ctx context.Context, startTime, endTime time.Time) error UpsertDailyMetrics(ctx context.Context, startTime, endTime time.Time) error
...@@ -91,7 +98,6 @@ type OpsInsertErrorLogInput struct { ...@@ -91,7 +98,6 @@ type OpsInsertErrorLogInput struct {
// It is set by OpsService.RecordError before persisting. // It is set by OpsService.RecordError before persisting.
UpstreamErrorsJSON *string UpstreamErrorsJSON *string
DurationMs *int
TimeToFirstTokenMs *int64 TimeToFirstTokenMs *int64
RequestBodyJSON *string // sanitized json string (not raw bytes) RequestBodyJSON *string // sanitized json string (not raw bytes)
...@@ -124,7 +130,15 @@ type OpsUpdateRetryAttemptInput struct { ...@@ -124,7 +130,15 @@ type OpsUpdateRetryAttemptInput struct {
FinishedAt time.Time FinishedAt time.Time
DurationMs int64 DurationMs int64
// Optional correlation // Persisted execution results (best-effort)
Success *bool
HTTPStatusCode *int
UpstreamRequestID *string
UsedAccountID *int64
ResponsePreview *string
ResponseTruncated *bool
// Optional correlation (legacy fields kept)
ResultRequestID *string ResultRequestID *string
ResultErrorID *int64 ResultErrorID *int64
...@@ -221,6 +235,9 @@ type OpsUpsertJobHeartbeatInput struct { ...@@ -221,6 +235,9 @@ type OpsUpsertJobHeartbeatInput struct {
LastErrorAt *time.Time LastErrorAt *time.Time
LastError *string LastError *string
LastDurationMs *int64 LastDurationMs *int64
// LastResult is an optional human-readable summary of the last successful run.
LastResult *string
} }
type OpsJobHeartbeat struct { type OpsJobHeartbeat struct {
...@@ -231,6 +248,7 @@ type OpsJobHeartbeat struct { ...@@ -231,6 +248,7 @@ type OpsJobHeartbeat struct {
LastErrorAt *time.Time `json:"last_error_at"` LastErrorAt *time.Time `json:"last_error_at"`
LastError *string `json:"last_error"` LastError *string `json:"last_error"`
LastDurationMs *int64 `json:"last_duration_ms"` LastDurationMs *int64 `json:"last_duration_ms"`
LastResult *string `json:"last_result"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
} }
......
...@@ -108,6 +108,10 @@ func (w *limitedResponseWriter) truncated() bool { ...@@ -108,6 +108,10 @@ func (w *limitedResponseWriter) truncated() bool {
return w.totalWritten > int64(w.limit) return w.totalWritten > int64(w.limit)
} }
const (
OpsRetryModeUpstreamEvent = "upstream_event"
)
func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, errorID int64, mode string, pinnedAccountID *int64) (*OpsRetryResult, error) { func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, errorID int64, mode string, pinnedAccountID *int64) (*OpsRetryResult, error) {
if err := s.RequireMonitoringEnabled(ctx); err != nil { if err := s.RequireMonitoringEnabled(ctx); err != nil {
return nil, err return nil, err
...@@ -123,6 +127,81 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er ...@@ -123,6 +127,81 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
return nil, infraerrors.BadRequest("OPS_RETRY_INVALID_MODE", "mode must be client or upstream") return nil, infraerrors.BadRequest("OPS_RETRY_INVALID_MODE", "mode must be client or upstream")
} }
errorLog, err := s.GetErrorLogByID(ctx, errorID)
if err != nil {
return nil, err
}
if errorLog == nil {
return nil, infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
}
if strings.TrimSpace(errorLog.RequestBody) == "" {
return nil, infraerrors.BadRequest("OPS_RETRY_NO_REQUEST_BODY", "No request body found to retry")
}
var pinned *int64
if mode == OpsRetryModeUpstream {
if pinnedAccountID != nil && *pinnedAccountID > 0 {
pinned = pinnedAccountID
} else if errorLog.AccountID != nil && *errorLog.AccountID > 0 {
pinned = errorLog.AccountID
} else {
return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "pinned_account_id is required for upstream retry")
}
}
return s.retryWithErrorLog(ctx, requestedByUserID, errorID, mode, mode, pinned, errorLog)
}
// RetryUpstreamEvent retries a specific upstream attempt captured inside ops_error_logs.upstream_errors.
// idx is 0-based. It always pins the original event account_id.
func (s *OpsService) RetryUpstreamEvent(ctx context.Context, requestedByUserID int64, errorID int64, idx int) (*OpsRetryResult, error) {
if err := s.RequireMonitoringEnabled(ctx); err != nil {
return nil, err
}
if s.opsRepo == nil {
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
}
if idx < 0 {
return nil, infraerrors.BadRequest("OPS_RETRY_INVALID_UPSTREAM_IDX", "invalid upstream idx")
}
errorLog, err := s.GetErrorLogByID(ctx, errorID)
if err != nil {
return nil, err
}
if errorLog == nil {
return nil, infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
}
events, err := ParseOpsUpstreamErrors(errorLog.UpstreamErrors)
if err != nil {
return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_EVENTS_INVALID", "invalid upstream_errors")
}
if idx >= len(events) {
return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_IDX_OOB", "upstream idx out of range")
}
ev := events[idx]
if ev == nil {
return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_EVENT_MISSING", "upstream event missing")
}
if ev.AccountID <= 0 {
return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "account_id is required for upstream retry")
}
upstreamBody := strings.TrimSpace(ev.UpstreamRequestBody)
if upstreamBody == "" {
return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_NO_REQUEST_BODY", "No upstream request body found to retry")
}
override := *errorLog
override.RequestBody = upstreamBody
pinned := ev.AccountID
// Persist as upstream_event, execute as upstream pinned retry.
return s.retryWithErrorLog(ctx, requestedByUserID, errorID, OpsRetryModeUpstreamEvent, OpsRetryModeUpstream, &pinned, &override)
}
func (s *OpsService) retryWithErrorLog(ctx context.Context, requestedByUserID int64, errorID int64, mode string, execMode string, pinnedAccountID *int64, errorLog *OpsErrorLogDetail) (*OpsRetryResult, error) {
latest, err := s.opsRepo.GetLatestRetryAttemptForError(ctx, errorID) latest, err := s.opsRepo.GetLatestRetryAttemptForError(ctx, errorID)
if err != nil && !errors.Is(err, sql.ErrNoRows) { if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, infraerrors.InternalServer("OPS_RETRY_LOAD_LATEST_FAILED", "Failed to check retry status").WithCause(err) return nil, infraerrors.InternalServer("OPS_RETRY_LOAD_LATEST_FAILED", "Failed to check retry status").WithCause(err)
...@@ -144,22 +223,18 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er ...@@ -144,22 +223,18 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
} }
} }
errorLog, err := s.GetErrorLogByID(ctx, errorID) if errorLog == nil || strings.TrimSpace(errorLog.RequestBody) == "" {
if err != nil {
return nil, err
}
if strings.TrimSpace(errorLog.RequestBody) == "" {
return nil, infraerrors.BadRequest("OPS_RETRY_NO_REQUEST_BODY", "No request body found to retry") return nil, infraerrors.BadRequest("OPS_RETRY_NO_REQUEST_BODY", "No request body found to retry")
} }
var pinned *int64 var pinned *int64
if mode == OpsRetryModeUpstream { if execMode == OpsRetryModeUpstream {
if pinnedAccountID != nil && *pinnedAccountID > 0 { if pinnedAccountID != nil && *pinnedAccountID > 0 {
pinned = pinnedAccountID pinned = pinnedAccountID
} else if errorLog.AccountID != nil && *errorLog.AccountID > 0 { } else if errorLog.AccountID != nil && *errorLog.AccountID > 0 {
pinned = errorLog.AccountID pinned = errorLog.AccountID
} else { } else {
return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "pinned_account_id is required for upstream retry") return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "account_id is required for upstream retry")
} }
} }
...@@ -196,7 +271,7 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er ...@@ -196,7 +271,7 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
execCtx, cancel := context.WithTimeout(ctx, opsRetryTimeout) execCtx, cancel := context.WithTimeout(ctx, opsRetryTimeout)
defer cancel() defer cancel()
execRes := s.executeRetry(execCtx, errorLog, mode, pinned) execRes := s.executeRetry(execCtx, errorLog, execMode, pinned)
finishedAt := time.Now() finishedAt := time.Now()
result.FinishedAt = finishedAt result.FinishedAt = finishedAt
...@@ -220,27 +295,40 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er ...@@ -220,27 +295,40 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
msg := result.ErrorMessage msg := result.ErrorMessage
updateErrMsg = &msg updateErrMsg = &msg
} }
// Keep legacy result_request_id empty; use upstream_request_id instead.
var resultRequestID *string var resultRequestID *string
if strings.TrimSpace(result.UpstreamRequestID) != "" {
v := result.UpstreamRequestID
resultRequestID = &v
}
finalStatus := result.Status finalStatus := result.Status
if strings.TrimSpace(finalStatus) == "" { if strings.TrimSpace(finalStatus) == "" {
finalStatus = opsRetryStatusFailed finalStatus = opsRetryStatusFailed
} }
success := strings.EqualFold(finalStatus, opsRetryStatusSucceeded)
httpStatus := result.HTTPStatusCode
upstreamReqID := result.UpstreamRequestID
usedAccountID := result.UsedAccountID
preview := result.ResponsePreview
truncated := result.ResponseTruncated
if err := s.opsRepo.UpdateRetryAttempt(updateCtx, &OpsUpdateRetryAttemptInput{ if err := s.opsRepo.UpdateRetryAttempt(updateCtx, &OpsUpdateRetryAttemptInput{
ID: attemptID, ID: attemptID,
Status: finalStatus, Status: finalStatus,
FinishedAt: finishedAt, FinishedAt: finishedAt,
DurationMs: result.DurationMs, DurationMs: result.DurationMs,
ResultRequestID: resultRequestID, Success: &success,
ErrorMessage: updateErrMsg, HTTPStatusCode: &httpStatus,
UpstreamRequestID: &upstreamReqID,
UsedAccountID: usedAccountID,
ResponsePreview: &preview,
ResponseTruncated: &truncated,
ResultRequestID: resultRequestID,
ErrorMessage: updateErrMsg,
}); err != nil { }); err != nil {
// Best-effort: retry itself already executed; do not fail the API response.
log.Printf("[Ops] UpdateRetryAttempt failed: %v", err) log.Printf("[Ops] UpdateRetryAttempt failed: %v", err)
} else if success {
if err := s.opsRepo.UpdateErrorResolution(updateCtx, errorID, true, &requestedByUserID, &attemptID, &finishedAt); err != nil {
log.Printf("[Ops] UpdateErrorResolution failed: %v", err)
}
} }
return result, nil return result, nil
...@@ -426,7 +514,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry ...@@ -426,7 +514,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry
if s.gatewayService == nil { if s.gatewayService == nil {
return nil, fmt.Errorf("gateway service not available") return nil, fmt.Errorf("gateway service not available")
} }
return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs) return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs, "") // 重试不使用会话限制
default: default:
return nil, fmt.Errorf("unsupported retry type: %s", reqType) return nil, fmt.Errorf("unsupported retry type: %s", reqType)
} }
......
...@@ -177,6 +177,10 @@ func (s *OpsScheduledReportService) runOnce() { ...@@ -177,6 +177,10 @@ func (s *OpsScheduledReportService) runOnce() {
return return
} }
reportsTotal := len(reports)
reportsDue := 0
sentAttempts := 0
for _, report := range reports { for _, report := range reports {
if report == nil || !report.Enabled { if report == nil || !report.Enabled {
continue continue
...@@ -184,14 +188,18 @@ func (s *OpsScheduledReportService) runOnce() { ...@@ -184,14 +188,18 @@ func (s *OpsScheduledReportService) runOnce() {
if report.NextRunAt.After(now) { if report.NextRunAt.After(now) {
continue continue
} }
reportsDue++
if err := s.runReport(ctx, report, now); err != nil { attempts, err := s.runReport(ctx, report, now)
if err != nil {
s.recordHeartbeatError(runAt, time.Since(startedAt), err) s.recordHeartbeatError(runAt, time.Since(startedAt), err)
return return
} }
sentAttempts += attempts
} }
s.recordHeartbeatSuccess(runAt, time.Since(startedAt)) result := truncateString(fmt.Sprintf("reports=%d due=%d send_attempts=%d", reportsTotal, reportsDue, sentAttempts), 2048)
s.recordHeartbeatSuccess(runAt, time.Since(startedAt), result)
} }
type opsScheduledReport struct { type opsScheduledReport struct {
...@@ -297,9 +305,9 @@ func (s *OpsScheduledReportService) listScheduledReports(ctx context.Context, no ...@@ -297,9 +305,9 @@ func (s *OpsScheduledReportService) listScheduledReports(ctx context.Context, no
return out return out
} }
func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsScheduledReport, now time.Time) error { func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsScheduledReport, now time.Time) (int, error) {
if s == nil || s.opsService == nil || s.emailService == nil || report == nil { if s == nil || s.opsService == nil || s.emailService == nil || report == nil {
return nil return 0, nil
} }
if ctx == nil { if ctx == nil {
ctx = context.Background() ctx = context.Background()
...@@ -310,11 +318,11 @@ func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsSc ...@@ -310,11 +318,11 @@ func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsSc
content, err := s.generateReportHTML(ctx, report, now) content, err := s.generateReportHTML(ctx, report, now)
if err != nil { if err != nil {
return err return 0, err
} }
if strings.TrimSpace(content) == "" { if strings.TrimSpace(content) == "" {
// Skip sending when the report decides not to emit content (e.g., digest below min count). // Skip sending when the report decides not to emit content (e.g., digest below min count).
return nil return 0, nil
} }
recipients := report.Recipients recipients := report.Recipients
...@@ -325,22 +333,24 @@ func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsSc ...@@ -325,22 +333,24 @@ func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsSc
} }
} }
if len(recipients) == 0 { if len(recipients) == 0 {
return nil return 0, nil
} }
subject := fmt.Sprintf("[Ops Report] %s", strings.TrimSpace(report.Name)) subject := fmt.Sprintf("[Ops Report] %s", strings.TrimSpace(report.Name))
attempts := 0
for _, to := range recipients { for _, to := range recipients {
addr := strings.TrimSpace(to) addr := strings.TrimSpace(to)
if addr == "" { if addr == "" {
continue continue
} }
attempts++
if err := s.emailService.SendEmail(ctx, addr, subject, content); err != nil { if err := s.emailService.SendEmail(ctx, addr, subject, content); err != nil {
// Ignore per-recipient failures; continue best-effort. // Ignore per-recipient failures; continue best-effort.
continue continue
} }
} }
return nil return attempts, nil
} }
func (s *OpsScheduledReportService) generateReportHTML(ctx context.Context, report *opsScheduledReport, now time.Time) (string, error) { func (s *OpsScheduledReportService) generateReportHTML(ctx context.Context, report *opsScheduledReport, now time.Time) (string, error) {
...@@ -650,7 +660,7 @@ func (s *OpsScheduledReportService) setLastRunAt(ctx context.Context, reportType ...@@ -650,7 +660,7 @@ func (s *OpsScheduledReportService) setLastRunAt(ctx context.Context, reportType
_ = s.redisClient.Set(ctx, key, strconv.FormatInt(t.UTC().Unix(), 10), 14*24*time.Hour).Err() _ = s.redisClient.Set(ctx, key, strconv.FormatInt(t.UTC().Unix(), 10), 14*24*time.Hour).Err()
} }
func (s *OpsScheduledReportService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration) { func (s *OpsScheduledReportService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration, result string) {
if s == nil || s.opsService == nil || s.opsService.opsRepo == nil { if s == nil || s.opsService == nil || s.opsService.opsRepo == nil {
return return
} }
...@@ -658,11 +668,17 @@ func (s *OpsScheduledReportService) recordHeartbeatSuccess(runAt time.Time, dura ...@@ -658,11 +668,17 @@ func (s *OpsScheduledReportService) recordHeartbeatSuccess(runAt time.Time, dura
durMs := duration.Milliseconds() durMs := duration.Milliseconds()
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel() defer cancel()
msg := strings.TrimSpace(result)
if msg == "" {
msg = "ok"
}
msg = truncateString(msg, 2048)
_ = s.opsService.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{ _ = s.opsService.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
JobName: opsScheduledReportJobName, JobName: opsScheduledReportJobName,
LastRunAt: &runAt, LastRunAt: &runAt,
LastSuccessAt: &now, LastSuccessAt: &now,
LastDurationMs: &durMs, LastDurationMs: &durMs,
LastResult: &msg,
}) })
} }
......
...@@ -208,6 +208,25 @@ func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogIn ...@@ -208,6 +208,25 @@ func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogIn
out.Detail = "" out.Detail = ""
} }
out.UpstreamRequestBody = strings.TrimSpace(out.UpstreamRequestBody)
if out.UpstreamRequestBody != "" {
// Reuse the same sanitization/trimming strategy as request body storage.
// Keep it small so it is safe to persist in ops_error_logs JSON.
sanitized, truncated, _ := sanitizeAndTrimRequestBody([]byte(out.UpstreamRequestBody), 10*1024)
if sanitized != "" {
out.UpstreamRequestBody = sanitized
if truncated {
out.Kind = strings.TrimSpace(out.Kind)
if out.Kind == "" {
out.Kind = "upstream"
}
out.Kind = out.Kind + ":request_body_truncated"
}
} else {
out.UpstreamRequestBody = ""
}
}
// Drop fully-empty events (can happen if only status code was known). // Drop fully-empty events (can happen if only status code was known).
if out.UpstreamStatusCode == 0 && out.Message == "" && out.Detail == "" { if out.UpstreamStatusCode == 0 && out.Message == "" && out.Detail == "" {
continue continue
...@@ -236,7 +255,13 @@ func (s *OpsService) GetErrorLogs(ctx context.Context, filter *OpsErrorLogFilter ...@@ -236,7 +255,13 @@ func (s *OpsService) GetErrorLogs(ctx context.Context, filter *OpsErrorLogFilter
if s.opsRepo == nil { if s.opsRepo == nil {
return &OpsErrorLogList{Errors: []*OpsErrorLog{}, Total: 0, Page: 1, PageSize: 20}, nil return &OpsErrorLogList{Errors: []*OpsErrorLog{}, Total: 0, Page: 1, PageSize: 20}, nil
} }
return s.opsRepo.ListErrorLogs(ctx, filter) result, err := s.opsRepo.ListErrorLogs(ctx, filter)
if err != nil {
log.Printf("[Ops] GetErrorLogs failed: %v", err)
return nil, err
}
return result, nil
} }
func (s *OpsService) GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLogDetail, error) { func (s *OpsService) GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLogDetail, error) {
...@@ -256,6 +281,46 @@ func (s *OpsService) GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLo ...@@ -256,6 +281,46 @@ func (s *OpsService) GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLo
return detail, nil return detail, nil
} }
func (s *OpsService) ListRetryAttemptsByErrorID(ctx context.Context, errorID int64, limit int) ([]*OpsRetryAttempt, error) {
if err := s.RequireMonitoringEnabled(ctx); err != nil {
return nil, err
}
if s.opsRepo == nil {
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
}
if errorID <= 0 {
return nil, infraerrors.BadRequest("OPS_ERROR_INVALID_ID", "invalid error id")
}
items, err := s.opsRepo.ListRetryAttemptsByErrorID(ctx, errorID, limit)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return []*OpsRetryAttempt{}, nil
}
return nil, infraerrors.InternalServer("OPS_RETRY_LIST_FAILED", "Failed to list retry attempts").WithCause(err)
}
return items, nil
}
func (s *OpsService) UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedRetryID *int64) error {
if err := s.RequireMonitoringEnabled(ctx); err != nil {
return err
}
if s.opsRepo == nil {
return infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
}
if errorID <= 0 {
return infraerrors.BadRequest("OPS_ERROR_INVALID_ID", "invalid error id")
}
// Best-effort ensure the error exists
if _, err := s.opsRepo.GetErrorLogByID(ctx, errorID); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
}
return infraerrors.InternalServer("OPS_ERROR_LOAD_FAILED", "Failed to load ops error log").WithCause(err)
}
return s.opsRepo.UpdateErrorResolution(ctx, errorID, resolved, resolvedByUserID, resolvedRetryID, nil)
}
func sanitizeAndTrimRequestBody(raw []byte, maxBytes int) (jsonString string, truncated bool, bytesLen int) { func sanitizeAndTrimRequestBody(raw []byte, maxBytes int) (jsonString string, truncated bool, bytesLen int) {
bytesLen = len(raw) bytesLen = len(raw)
if len(raw) == 0 { if len(raw) == 0 {
...@@ -296,14 +361,34 @@ func sanitizeAndTrimRequestBody(raw []byte, maxBytes int) (jsonString string, tr ...@@ -296,14 +361,34 @@ func sanitizeAndTrimRequestBody(raw []byte, maxBytes int) (jsonString string, tr
} }
} }
// Last resort: store a minimal placeholder (still valid JSON). // Last resort: keep JSON shape but drop big fields.
placeholder := map[string]any{ // This avoids downstream code that expects certain top-level keys from crashing.
"request_body_truncated": true, if root, ok := decoded.(map[string]any); ok {
} placeholder := shallowCopyMap(root)
if model := extractString(decoded, "model"); model != "" { placeholder["request_body_truncated"] = true
placeholder["model"] = model
// Replace potentially huge arrays/strings, but keep the keys present.
for _, k := range []string{"messages", "contents", "input", "prompt"} {
if _, exists := placeholder[k]; exists {
placeholder[k] = []any{}
}
}
for _, k := range []string{"text"} {
if _, exists := placeholder[k]; exists {
placeholder[k] = ""
}
}
encoded4, err4 := json.Marshal(placeholder)
if err4 == nil {
if len(encoded4) <= maxBytes {
return string(encoded4), true, bytesLen
}
}
} }
encoded4, err4 := json.Marshal(placeholder)
// Final fallback: minimal valid JSON.
encoded4, err4 := json.Marshal(map[string]any{"request_body_truncated": true})
if err4 != nil { if err4 != nil {
return "", true, bytesLen return "", true, bytesLen
} }
...@@ -526,12 +611,3 @@ func sanitizeErrorBodyForStorage(raw string, maxBytes int) (sanitized string, tr ...@@ -526,12 +611,3 @@ func sanitizeErrorBodyForStorage(raw string, maxBytes int) (sanitized string, tr
} }
return raw, false return raw, false
} }
func extractString(v any, key string) string {
root, ok := v.(map[string]any)
if !ok {
return ""
}
s, _ := root[key].(string)
return strings.TrimSpace(s)
}
...@@ -368,9 +368,11 @@ func defaultOpsAdvancedSettings() *OpsAdvancedSettings { ...@@ -368,9 +368,11 @@ func defaultOpsAdvancedSettings() *OpsAdvancedSettings {
Aggregation: OpsAggregationSettings{ Aggregation: OpsAggregationSettings{
AggregationEnabled: false, AggregationEnabled: false,
}, },
IgnoreCountTokensErrors: false, IgnoreCountTokensErrors: false,
AutoRefreshEnabled: false, IgnoreContextCanceled: true, // Default to true - client disconnects are not errors
AutoRefreshIntervalSec: 30, IgnoreNoAvailableAccounts: false, // Default to false - this is a real routing issue
AutoRefreshEnabled: false,
AutoRefreshIntervalSec: 30,
} }
} }
...@@ -482,13 +484,11 @@ const SettingKeyOpsMetricThresholds = "ops_metric_thresholds" ...@@ -482,13 +484,11 @@ const SettingKeyOpsMetricThresholds = "ops_metric_thresholds"
func defaultOpsMetricThresholds() *OpsMetricThresholds { func defaultOpsMetricThresholds() *OpsMetricThresholds {
slaMin := 99.5 slaMin := 99.5
latencyMax := 2000.0
ttftMax := 500.0 ttftMax := 500.0
reqErrMax := 5.0 reqErrMax := 5.0
upstreamErrMax := 5.0 upstreamErrMax := 5.0
return &OpsMetricThresholds{ return &OpsMetricThresholds{
SLAPercentMin: &slaMin, SLAPercentMin: &slaMin,
LatencyP99MsMax: &latencyMax,
TTFTp99MsMax: &ttftMax, TTFTp99MsMax: &ttftMax,
RequestErrorRatePercentMax: &reqErrMax, RequestErrorRatePercentMax: &reqErrMax,
UpstreamErrorRatePercentMax: &upstreamErrMax, UpstreamErrorRatePercentMax: &upstreamErrMax,
...@@ -538,9 +538,6 @@ func (s *OpsService) UpdateMetricThresholds(ctx context.Context, cfg *OpsMetricT ...@@ -538,9 +538,6 @@ func (s *OpsService) UpdateMetricThresholds(ctx context.Context, cfg *OpsMetricT
if cfg.SLAPercentMin != nil && (*cfg.SLAPercentMin < 0 || *cfg.SLAPercentMin > 100) { if cfg.SLAPercentMin != nil && (*cfg.SLAPercentMin < 0 || *cfg.SLAPercentMin > 100) {
return nil, errors.New("sla_percent_min must be between 0 and 100") return nil, errors.New("sla_percent_min must be between 0 and 100")
} }
if cfg.LatencyP99MsMax != nil && *cfg.LatencyP99MsMax < 0 {
return nil, errors.New("latency_p99_ms_max must be >= 0")
}
if cfg.TTFTp99MsMax != nil && *cfg.TTFTp99MsMax < 0 { if cfg.TTFTp99MsMax != nil && *cfg.TTFTp99MsMax < 0 {
return nil, errors.New("ttft_p99_ms_max must be >= 0") return nil, errors.New("ttft_p99_ms_max must be >= 0")
} }
......
...@@ -63,7 +63,6 @@ type OpsAlertSilencingSettings struct { ...@@ -63,7 +63,6 @@ type OpsAlertSilencingSettings struct {
type OpsMetricThresholds struct { type OpsMetricThresholds struct {
SLAPercentMin *float64 `json:"sla_percent_min,omitempty"` // SLA低于此值变红 SLAPercentMin *float64 `json:"sla_percent_min,omitempty"` // SLA低于此值变红
LatencyP99MsMax *float64 `json:"latency_p99_ms_max,omitempty"` // 延迟P99高于此值变红
TTFTp99MsMax *float64 `json:"ttft_p99_ms_max,omitempty"` // TTFT P99高于此值变红 TTFTp99MsMax *float64 `json:"ttft_p99_ms_max,omitempty"` // TTFT P99高于此值变红
RequestErrorRatePercentMax *float64 `json:"request_error_rate_percent_max,omitempty"` // 请求错误率高于此值变红 RequestErrorRatePercentMax *float64 `json:"request_error_rate_percent_max,omitempty"` // 请求错误率高于此值变红
UpstreamErrorRatePercentMax *float64 `json:"upstream_error_rate_percent_max,omitempty"` // 上游错误率高于此值变红 UpstreamErrorRatePercentMax *float64 `json:"upstream_error_rate_percent_max,omitempty"` // 上游错误率高于此值变红
...@@ -79,11 +78,13 @@ type OpsAlertRuntimeSettings struct { ...@@ -79,11 +78,13 @@ type OpsAlertRuntimeSettings struct {
// OpsAdvancedSettings stores advanced ops configuration (data retention, aggregation). // OpsAdvancedSettings stores advanced ops configuration (data retention, aggregation).
type OpsAdvancedSettings struct { type OpsAdvancedSettings struct {
DataRetention OpsDataRetentionSettings `json:"data_retention"` DataRetention OpsDataRetentionSettings `json:"data_retention"`
Aggregation OpsAggregationSettings `json:"aggregation"` Aggregation OpsAggregationSettings `json:"aggregation"`
IgnoreCountTokensErrors bool `json:"ignore_count_tokens_errors"` IgnoreCountTokensErrors bool `json:"ignore_count_tokens_errors"`
AutoRefreshEnabled bool `json:"auto_refresh_enabled"` IgnoreContextCanceled bool `json:"ignore_context_canceled"`
AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"` IgnoreNoAvailableAccounts bool `json:"ignore_no_available_accounts"`
AutoRefreshEnabled bool `json:"auto_refresh_enabled"`
AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"`
} }
type OpsDataRetentionSettings struct { type OpsDataRetentionSettings struct {
......
...@@ -15,6 +15,11 @@ const ( ...@@ -15,6 +15,11 @@ const (
OpsUpstreamErrorMessageKey = "ops_upstream_error_message" OpsUpstreamErrorMessageKey = "ops_upstream_error_message"
OpsUpstreamErrorDetailKey = "ops_upstream_error_detail" OpsUpstreamErrorDetailKey = "ops_upstream_error_detail"
OpsUpstreamErrorsKey = "ops_upstream_errors" OpsUpstreamErrorsKey = "ops_upstream_errors"
// Best-effort capture of the current upstream request body so ops can
// retry the specific upstream attempt (not just the client request).
// This value is sanitized+trimmed before being persisted.
OpsUpstreamRequestBodyKey = "ops_upstream_request_body"
) )
func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage, upstreamDetail string) { func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage, upstreamDetail string) {
...@@ -38,13 +43,21 @@ type OpsUpstreamErrorEvent struct { ...@@ -38,13 +43,21 @@ type OpsUpstreamErrorEvent struct {
AtUnixMs int64 `json:"at_unix_ms,omitempty"` AtUnixMs int64 `json:"at_unix_ms,omitempty"`
// Context // Context
Platform string `json:"platform,omitempty"` Platform string `json:"platform,omitempty"`
AccountID int64 `json:"account_id,omitempty"` AccountID int64 `json:"account_id,omitempty"`
AccountName string `json:"account_name,omitempty"`
// Outcome // Outcome
UpstreamStatusCode int `json:"upstream_status_code,omitempty"` UpstreamStatusCode int `json:"upstream_status_code,omitempty"`
UpstreamRequestID string `json:"upstream_request_id,omitempty"` UpstreamRequestID string `json:"upstream_request_id,omitempty"`
// Best-effort upstream request capture (sanitized+trimmed).
// Required for retrying a specific upstream attempt.
UpstreamRequestBody string `json:"upstream_request_body,omitempty"`
// Best-effort upstream response capture (sanitized+trimmed).
UpstreamResponseBody string `json:"upstream_response_body,omitempty"`
// Kind: http_error | request_error | retry_exhausted | failover // Kind: http_error | request_error | retry_exhausted | failover
Kind string `json:"kind,omitempty"` Kind string `json:"kind,omitempty"`
...@@ -61,6 +74,8 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) { ...@@ -61,6 +74,8 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
} }
ev.Platform = strings.TrimSpace(ev.Platform) ev.Platform = strings.TrimSpace(ev.Platform)
ev.UpstreamRequestID = strings.TrimSpace(ev.UpstreamRequestID) ev.UpstreamRequestID = strings.TrimSpace(ev.UpstreamRequestID)
ev.UpstreamRequestBody = strings.TrimSpace(ev.UpstreamRequestBody)
ev.UpstreamResponseBody = strings.TrimSpace(ev.UpstreamResponseBody)
ev.Kind = strings.TrimSpace(ev.Kind) ev.Kind = strings.TrimSpace(ev.Kind)
ev.Message = strings.TrimSpace(ev.Message) ev.Message = strings.TrimSpace(ev.Message)
ev.Detail = strings.TrimSpace(ev.Detail) ev.Detail = strings.TrimSpace(ev.Detail)
...@@ -68,6 +83,16 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) { ...@@ -68,6 +83,16 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
ev.Message = sanitizeUpstreamErrorMessage(ev.Message) ev.Message = sanitizeUpstreamErrorMessage(ev.Message)
} }
// If the caller didn't explicitly pass upstream request body but the gateway
// stored it on the context, attach it so ops can retry this specific attempt.
if ev.UpstreamRequestBody == "" {
if v, ok := c.Get(OpsUpstreamRequestBodyKey); ok {
if s, ok := v.(string); ok {
ev.UpstreamRequestBody = strings.TrimSpace(s)
}
}
}
var existing []*OpsUpstreamErrorEvent var existing []*OpsUpstreamErrorEvent
if v, ok := c.Get(OpsUpstreamErrorsKey); ok { if v, ok := c.Get(OpsUpstreamErrorsKey); ok {
if arr, ok := v.([]*OpsUpstreamErrorEvent); ok { if arr, ok := v.([]*OpsUpstreamErrorEvent); ok {
...@@ -92,3 +117,15 @@ func marshalOpsUpstreamErrors(events []*OpsUpstreamErrorEvent) *string { ...@@ -92,3 +117,15 @@ func marshalOpsUpstreamErrors(events []*OpsUpstreamErrorEvent) *string {
s := string(raw) s := string(raw)
return &s return &s
} }
func ParseOpsUpstreamErrors(raw string) ([]*OpsUpstreamErrorEvent, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return []*OpsUpstreamErrorEvent{}, nil
}
var out []*OpsUpstreamErrorEvent
if err := json.Unmarshal([]byte(raw), &out); err != nil {
return nil, err
}
return out, nil
}
...@@ -31,5 +31,21 @@ func (p *Proxy) URL() string { ...@@ -31,5 +31,21 @@ func (p *Proxy) URL() string {
type ProxyWithAccountCount struct { type ProxyWithAccountCount struct {
Proxy Proxy
AccountCount int64 AccountCount int64
LatencyMs *int64
LatencyStatus string
LatencyMessage string
IPAddress string
Country string
CountryCode string
Region string
City string
}
type ProxyAccountSummary struct {
ID int64
Name string
Platform string
Type string
Notes *string
} }
package service
import (
"context"
"time"
)
type ProxyLatencyInfo struct {
Success bool `json:"success"`
LatencyMs *int64 `json:"latency_ms,omitempty"`
Message string `json:"message,omitempty"`
IPAddress string `json:"ip_address,omitempty"`
Country string `json:"country,omitempty"`
CountryCode string `json:"country_code,omitempty"`
Region string `json:"region,omitempty"`
City string `json:"city,omitempty"`
UpdatedAt time.Time `json:"updated_at"`
}
type ProxyLatencyCache interface {
GetProxyLatencies(ctx context.Context, proxyIDs []int64) (map[int64]*ProxyLatencyInfo, error)
SetProxyLatency(ctx context.Context, proxyID int64, info *ProxyLatencyInfo) error
}
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
var ( var (
ErrProxyNotFound = infraerrors.NotFound("PROXY_NOT_FOUND", "proxy not found") ErrProxyNotFound = infraerrors.NotFound("PROXY_NOT_FOUND", "proxy not found")
ErrProxyInUse = infraerrors.Conflict("PROXY_IN_USE", "proxy is in use by accounts")
) )
type ProxyRepository interface { type ProxyRepository interface {
...@@ -26,6 +27,7 @@ type ProxyRepository interface { ...@@ -26,6 +27,7 @@ type ProxyRepository interface {
ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error)
CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error)
ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error)
} }
// CreateProxyRequest 创建代理请求 // CreateProxyRequest 创建代理请求
......
...@@ -3,7 +3,7 @@ package service ...@@ -3,7 +3,7 @@ package service
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"log" "log/slog"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
...@@ -15,15 +15,16 @@ import ( ...@@ -15,15 +15,16 @@ import (
// RateLimitService 处理限流和过载状态管理 // RateLimitService 处理限流和过载状态管理
type RateLimitService struct { type RateLimitService struct {
accountRepo AccountRepository accountRepo AccountRepository
usageRepo UsageLogRepository usageRepo UsageLogRepository
cfg *config.Config cfg *config.Config
geminiQuotaService *GeminiQuotaService geminiQuotaService *GeminiQuotaService
tempUnschedCache TempUnschedCache tempUnschedCache TempUnschedCache
timeoutCounterCache TimeoutCounterCache timeoutCounterCache TimeoutCounterCache
settingService *SettingService settingService *SettingService
usageCacheMu sync.RWMutex tokenCacheInvalidator TokenCacheInvalidator
usageCache map[int64]*geminiUsageCacheEntry usageCacheMu sync.RWMutex
usageCache map[int64]*geminiUsageCacheEntry
} }
type geminiUsageCacheEntry struct { type geminiUsageCacheEntry struct {
...@@ -56,6 +57,11 @@ func (s *RateLimitService) SetSettingService(settingService *SettingService) { ...@@ -56,6 +57,11 @@ func (s *RateLimitService) SetSettingService(settingService *SettingService) {
s.settingService = settingService s.settingService = settingService
} }
// SetTokenCacheInvalidator 设置 token 缓存清理器(可选依赖)
func (s *RateLimitService) SetTokenCacheInvalidator(invalidator TokenCacheInvalidator) {
s.tokenCacheInvalidator = invalidator
}
// HandleUpstreamError 处理上游错误响应,标记账号状态 // HandleUpstreamError 处理上游错误响应,标记账号状态
// 返回是否应该停止该账号的调度 // 返回是否应该停止该账号的调度
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) { func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
...@@ -63,11 +69,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc ...@@ -63,11 +69,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
// 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载) // 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载)
customErrorCodesEnabled := account.IsCustomErrorCodesEnabled() customErrorCodesEnabled := account.IsCustomErrorCodesEnabled()
if !account.ShouldHandleErrorCode(statusCode) { if !account.ShouldHandleErrorCode(statusCode) {
log.Printf("Account %d: error %d skipped (not in custom error codes)", account.ID, statusCode) slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode)
return false return false
} }
tempMatched := s.tryTempUnschedulable(ctx, account, statusCode, responseBody) tempMatched := false
if statusCode != 401 {
tempMatched = s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
}
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(responseBody)) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(responseBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if upstreamMsg != "" { if upstreamMsg != "" {
...@@ -76,7 +85,25 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc ...@@ -76,7 +85,25 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
switch statusCode { switch statusCode {
case 401: case 401:
// 认证失败:停止调度,记录错误 // 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新
if account.Type == AccountTypeOAuth {
// 1. 失效缓存
if s.tokenCacheInvalidator != nil {
if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err)
}
}
// 2. 设置 expires_at 为当前时间,强制下次请求刷新 token
if account.Credentials == nil {
account.Credentials = make(map[string]any)
}
account.Credentials["expires_at"] = time.Now().Format(time.RFC3339)
if err := s.accountRepo.Update(ctx, account); err != nil {
slog.Warn("oauth_401_force_refresh_update_failed", "account_id", account.ID, "error", err)
} else {
slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform)
}
}
msg := "Authentication failed (401): invalid or expired credentials" msg := "Authentication failed (401): invalid or expired credentials"
if upstreamMsg != "" { if upstreamMsg != "" {
msg = "Authentication failed (401): " + upstreamMsg msg = "Authentication failed (401): " + upstreamMsg
...@@ -100,7 +127,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc ...@@ -100,7 +127,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
s.handleAuthError(ctx, account, msg) s.handleAuthError(ctx, account, msg)
shouldDisable = true shouldDisable = true
case 429: case 429:
s.handle429(ctx, account, headers) s.handle429(ctx, account, headers, responseBody)
shouldDisable = false shouldDisable = false
case 529: case 529:
s.handle529(ctx, account) s.handle529(ctx, account)
...@@ -116,7 +143,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc ...@@ -116,7 +143,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
shouldDisable = true shouldDisable = true
} else if statusCode >= 500 { } else if statusCode >= 500 {
// 未启用自定义错误码时:仅记录5xx错误 // 未启用自定义错误码时:仅记录5xx错误
log.Printf("Account %d received upstream error %d", account.ID, statusCode) slog.Warn("account_upstream_error", "account_id", account.ID, "status_code", statusCode)
shouldDisable = false shouldDisable = false
} }
} }
...@@ -163,7 +190,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, ...@@ -163,7 +190,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
start := geminiDailyWindowStart(now) start := geminiDailyWindowStart(now)
totals, ok := s.getGeminiUsageTotals(account.ID, start, now) totals, ok := s.getGeminiUsageTotals(account.ID, start, now)
if !ok { if !ok {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID) stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil)
if err != nil { if err != nil {
return true, err return true, err
} }
...@@ -188,7 +215,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, ...@@ -188,7 +215,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
// NOTE: // NOTE:
// - This is a local precheck to reduce upstream 429s. // - This is a local precheck to reduce upstream 429s.
// - Do NOT mark the account as rate-limited here; rate_limit_reset_at should reflect real upstream 429s. // - Do NOT mark the account as rate-limited here; rate_limit_reset_at should reflect real upstream 429s.
log.Printf("[Gemini PreCheck] Account %d reached daily quota (%d/%d), skip until %v", account.ID, used, limit, resetAt) slog.Info("gemini_precheck_daily_quota_reached", "account_id", account.ID, "used", used, "limit", limit, "reset_at", resetAt)
return false, nil return false, nil
} }
} }
...@@ -210,7 +237,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, ...@@ -210,7 +237,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
if limit > 0 { if limit > 0 {
start := now.Truncate(time.Minute) start := now.Truncate(time.Minute)
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID) stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil)
if err != nil { if err != nil {
return true, err return true, err
} }
...@@ -231,7 +258,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, ...@@ -231,7 +258,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
if used >= limit { if used >= limit {
resetAt := start.Add(time.Minute) resetAt := start.Add(time.Minute)
// Do not persist "rate limited" status from local precheck. See note above. // Do not persist "rate limited" status from local precheck. See note above.
log.Printf("[Gemini PreCheck] Account %d reached minute quota (%d/%d), skip until %v", account.ID, used, limit, resetAt) slog.Info("gemini_precheck_minute_quota_reached", "account_id", account.ID, "used", used, "limit", limit, "reset_at", resetAt)
return false, nil return false, nil
} }
} }
...@@ -288,32 +315,40 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account) ...@@ -288,32 +315,40 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account)
// handleAuthError 处理认证类错误(401/403),停止账号调度 // handleAuthError 处理认证类错误(401/403),停止账号调度
func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) { func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) {
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil { if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
log.Printf("SetError failed for account %d: %v", account.ID, err) slog.Warn("account_set_error_failed", "account_id", account.ID, "error", err)
return return
} }
log.Printf("Account %d disabled due to auth error: %s", account.ID, errorMsg) slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg)
} }
// handleCustomErrorCode 处理自定义错误码,停止账号调度 // handleCustomErrorCode 处理自定义错误码,停止账号调度
func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) { func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) {
msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg
if err := s.accountRepo.SetError(ctx, account.ID, msg); err != nil { if err := s.accountRepo.SetError(ctx, account.ID, msg); err != nil {
log.Printf("SetError failed for account %d: %v", account.ID, err) slog.Warn("account_set_error_failed", "account_id", account.ID, "status_code", statusCode, "error", err)
return return
} }
log.Printf("Account %d disabled due to custom error code %d: %s", account.ID, statusCode, errorMsg) slog.Warn("account_disabled_custom_error", "account_id", account.ID, "status_code", statusCode, "error", errorMsg)
} }
// handle429 处理429限流错误 // handle429 处理429限流错误
// 解析响应头获取重置时间,标记账号为限流状态 // 解析响应头获取重置时间,标记账号为限流状态
func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header) { func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header, responseBody []byte) {
// 解析重置时间戳 // 解析重置时间戳
resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset") resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset")
if resetTimestamp == "" { if resetTimestamp == "" {
// 没有重置时间,使用默认5分钟 // 没有重置时间,使用默认5分钟
resetAt := time.Now().Add(5 * time.Minute) resetAt := time.Now().Add(5 * time.Minute)
if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) {
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelRateLimitScopeClaudeSonnet, resetAt); err != nil {
slog.Warn("model_rate_limit_set_failed", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "error", err)
} else {
slog.Info("account_model_rate_limited", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "reset_at", resetAt)
}
return
}
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err) slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
} }
return return
} }
...@@ -321,19 +356,36 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head ...@@ -321,19 +356,36 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
// 解析Unix时间戳 // 解析Unix时间戳
ts, err := strconv.ParseInt(resetTimestamp, 10, 64) ts, err := strconv.ParseInt(resetTimestamp, 10, 64)
if err != nil { if err != nil {
log.Printf("Parse reset timestamp failed: %v", err) slog.Warn("rate_limit_reset_parse_failed", "reset_timestamp", resetTimestamp, "error", err)
resetAt := time.Now().Add(5 * time.Minute) resetAt := time.Now().Add(5 * time.Minute)
if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) {
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelRateLimitScopeClaudeSonnet, resetAt); err != nil {
slog.Warn("model_rate_limit_set_failed", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "error", err)
} else {
slog.Info("account_model_rate_limited", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "reset_at", resetAt)
}
return
}
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err) slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
} }
return return
} }
resetAt := time.Unix(ts, 0) resetAt := time.Unix(ts, 0)
if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) {
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelRateLimitScopeClaudeSonnet, resetAt); err != nil {
slog.Warn("model_rate_limit_set_failed", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "error", err)
return
}
slog.Info("account_model_rate_limited", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "reset_at", resetAt)
return
}
// 标记限流状态 // 标记限流状态
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err) slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return return
} }
...@@ -341,10 +393,21 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head ...@@ -341,10 +393,21 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
windowEnd := resetAt windowEnd := resetAt
windowStart := resetAt.Add(-5 * time.Hour) windowStart := resetAt.Add(-5 * time.Hour)
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil { if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err) slog.Warn("rate_limit_update_session_window_failed", "account_id", account.ID, "error", err)
} }
log.Printf("Account %d rate limited until %v", account.ID, resetAt) slog.Info("account_rate_limited", "account_id", account.ID, "reset_at", resetAt)
}
func (s *RateLimitService) shouldScopeClaudeSonnetRateLimit(account *Account, responseBody []byte) bool {
if account == nil || account.Platform != PlatformAnthropic {
return false
}
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(responseBody)))
if msg == "" {
return false
}
return strings.Contains(msg, "sonnet")
} }
// handle529 处理529过载错误 // handle529 处理529过载错误
...@@ -357,11 +420,11 @@ func (s *RateLimitService) handle529(ctx context.Context, account *Account) { ...@@ -357,11 +420,11 @@ func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute) until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil { if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil {
log.Printf("SetOverloaded failed for account %d: %v", account.ID, err) slog.Warn("overload_set_failed", "account_id", account.ID, "error", err)
return return
} }
log.Printf("Account %d overloaded until %v", account.ID, until) slog.Info("account_overloaded", "account_id", account.ID, "until", until)
} }
// UpdateSessionWindow 从成功响应更新5h窗口状态 // UpdateSessionWindow 从成功响应更新5h窗口状态
...@@ -384,17 +447,17 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc ...@@ -384,17 +447,17 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
end := start.Add(5 * time.Hour) end := start.Add(5 * time.Hour)
windowStart = &start windowStart = &start
windowEnd = &end windowEnd = &end
log.Printf("Account %d: initializing 5h window from %v to %v (status: %s)", account.ID, start, end, status) slog.Info("account_session_window_initialized", "account_id", account.ID, "window_start", start, "window_end", end, "status", status)
} }
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil { if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil {
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err) slog.Warn("session_window_update_failed", "account_id", account.ID, "error", err)
} }
// 如果状态为allowed且之前有限流,说明窗口已重置,清除限流状态 // 如果状态为allowed且之前有限流,说明窗口已重置,清除限流状态
if status == "allowed" && account.IsRateLimited() { if status == "allowed" && account.IsRateLimited() {
if err := s.ClearRateLimit(ctx, account.ID); err != nil { if err := s.ClearRateLimit(ctx, account.ID); err != nil {
log.Printf("ClearRateLimit failed for account %d: %v", account.ID, err) slog.Warn("rate_limit_clear_failed", "account_id", account.ID, "error", err)
} }
} }
} }
...@@ -404,7 +467,10 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) ...@@ -404,7 +467,10 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64)
if err := s.accountRepo.ClearRateLimit(ctx, accountID); err != nil { if err := s.accountRepo.ClearRateLimit(ctx, accountID); err != nil {
return err return err
} }
return s.accountRepo.ClearAntigravityQuotaScopes(ctx, accountID) if err := s.accountRepo.ClearAntigravityQuotaScopes(ctx, accountID); err != nil {
return err
}
return s.accountRepo.ClearModelRateLimits(ctx, accountID)
} }
func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID int64) error { func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID int64) error {
...@@ -413,7 +479,7 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID ...@@ -413,7 +479,7 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID
} }
if s.tempUnschedCache != nil { if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.DeleteTempUnsched(ctx, accountID); err != nil { if err := s.tempUnschedCache.DeleteTempUnsched(ctx, accountID); err != nil {
log.Printf("DeleteTempUnsched failed for account %d: %v", accountID, err) slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err)
} }
} }
return nil return nil
...@@ -460,7 +526,7 @@ func (s *RateLimitService) GetTempUnschedStatus(ctx context.Context, accountID i ...@@ -460,7 +526,7 @@ func (s *RateLimitService) GetTempUnschedStatus(ctx context.Context, accountID i
if s.tempUnschedCache != nil { if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.SetTempUnsched(ctx, accountID, state); err != nil { if err := s.tempUnschedCache.SetTempUnsched(ctx, accountID, state); err != nil {
log.Printf("SetTempUnsched failed for account %d: %v", accountID, err) slog.Warn("temp_unsched_cache_set_failed", "account_id", accountID, "error", err)
} }
} }
...@@ -563,17 +629,17 @@ func (s *RateLimitService) triggerTempUnschedulable(ctx context.Context, account ...@@ -563,17 +629,17 @@ func (s *RateLimitService) triggerTempUnschedulable(ctx context.Context, account
} }
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil { if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
log.Printf("SetTempUnschedulable failed for account %d: %v", account.ID, err) slog.Warn("temp_unsched_set_failed", "account_id", account.ID, "error", err)
return false return false
} }
if s.tempUnschedCache != nil { if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil { if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil {
log.Printf("SetTempUnsched cache failed for account %d: %v", account.ID, err) slog.Warn("temp_unsched_cache_set_failed", "account_id", account.ID, "error", err)
} }
} }
log.Printf("Account %d temp unschedulable until %v (rule %d, code %d)", account.ID, until, ruleIndex, statusCode) slog.Info("account_temp_unschedulable", "account_id", account.ID, "until", until, "rule_index", ruleIndex, "status_code", statusCode)
return true return true
} }
...@@ -597,13 +663,13 @@ func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Acc ...@@ -597,13 +663,13 @@ func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Acc
// 获取系统设置 // 获取系统设置
if s.settingService == nil { if s.settingService == nil {
log.Printf("[StreamTimeout] settingService not configured, skipping timeout handling for account %d", account.ID) slog.Warn("stream_timeout_setting_service_missing", "account_id", account.ID)
return false return false
} }
settings, err := s.settingService.GetStreamTimeoutSettings(ctx) settings, err := s.settingService.GetStreamTimeoutSettings(ctx)
if err != nil { if err != nil {
log.Printf("[StreamTimeout] Failed to get settings: %v", err) slog.Warn("stream_timeout_get_settings_failed", "account_id", account.ID, "error", err)
return false return false
} }
...@@ -620,14 +686,13 @@ func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Acc ...@@ -620,14 +686,13 @@ func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Acc
if s.timeoutCounterCache != nil { if s.timeoutCounterCache != nil {
count, err = s.timeoutCounterCache.IncrementTimeoutCount(ctx, account.ID, settings.ThresholdWindowMinutes) count, err = s.timeoutCounterCache.IncrementTimeoutCount(ctx, account.ID, settings.ThresholdWindowMinutes)
if err != nil { if err != nil {
log.Printf("[StreamTimeout] Failed to increment timeout count for account %d: %v", account.ID, err) slog.Warn("stream_timeout_increment_count_failed", "account_id", account.ID, "error", err)
// 继续处理,使用 count=1 // 继续处理,使用 count=1
count = 1 count = 1
} }
} }
log.Printf("[StreamTimeout] Account %d timeout count: %d/%d (window: %d min, model: %s)", slog.Info("stream_timeout_count", "account_id", account.ID, "count", count, "threshold", settings.ThresholdCount, "window_minutes", settings.ThresholdWindowMinutes, "model", model)
account.ID, count, settings.ThresholdCount, settings.ThresholdWindowMinutes, model)
// 检查是否达到阈值 // 检查是否达到阈值
if count < int64(settings.ThresholdCount) { if count < int64(settings.ThresholdCount) {
...@@ -668,24 +733,24 @@ func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context, ...@@ -668,24 +733,24 @@ func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context,
} }
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil { if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
log.Printf("[StreamTimeout] SetTempUnschedulable failed for account %d: %v", account.ID, err) slog.Warn("stream_timeout_set_temp_unsched_failed", "account_id", account.ID, "error", err)
return false return false
} }
if s.tempUnschedCache != nil { if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil { if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil {
log.Printf("[StreamTimeout] SetTempUnsched cache failed for account %d: %v", account.ID, err) slog.Warn("stream_timeout_set_temp_unsched_cache_failed", "account_id", account.ID, "error", err)
} }
} }
// 重置超时计数 // 重置超时计数
if s.timeoutCounterCache != nil { if s.timeoutCounterCache != nil {
if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil { if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil {
log.Printf("[StreamTimeout] ResetTimeoutCount failed for account %d: %v", account.ID, err) slog.Warn("stream_timeout_reset_count_failed", "account_id", account.ID, "error", err)
} }
} }
log.Printf("[StreamTimeout] Account %d marked as temp unschedulable until %v (model: %s)", account.ID, until, model) slog.Info("stream_timeout_temp_unschedulable", "account_id", account.ID, "until", until, "model", model)
return true return true
} }
...@@ -694,17 +759,17 @@ func (s *RateLimitService) triggerStreamTimeoutError(ctx context.Context, accoun ...@@ -694,17 +759,17 @@ func (s *RateLimitService) triggerStreamTimeoutError(ctx context.Context, accoun
errorMsg := "Stream data interval timeout (repeated failures) for model: " + model errorMsg := "Stream data interval timeout (repeated failures) for model: " + model
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil { if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
log.Printf("[StreamTimeout] SetError failed for account %d: %v", account.ID, err) slog.Warn("stream_timeout_set_error_failed", "account_id", account.ID, "error", err)
return false return false
} }
// 重置超时计数 // 重置超时计数
if s.timeoutCounterCache != nil { if s.timeoutCounterCache != nil {
if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil { if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil {
log.Printf("[StreamTimeout] ResetTimeoutCount failed for account %d: %v", account.ID, err) slog.Warn("stream_timeout_reset_count_failed", "account_id", account.ID, "error", err)
} }
} }
log.Printf("[StreamTimeout] Account %d marked as error (model: %s)", account.ID, model) slog.Warn("stream_timeout_account_error", "account_id", account.ID, "model", model)
return true return true
} }
//go:build unit
package service
import (
"context"
"errors"
"net/http"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type rateLimitAccountRepoStub struct {
mockAccountRepoForGemini
setErrorCalls int
tempCalls int
lastErrorMsg string
}
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
r.setErrorCalls++
r.lastErrorMsg = errorMsg
return nil
}
func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
r.tempCalls++
return nil
}
type tokenCacheInvalidatorRecorder struct {
accounts []*Account
err error
}
func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error {
r.accounts = append(r.accounts, account)
return r.err
}
func TestRateLimitService_HandleUpstreamError_OAuth401MarksError(t *testing.T) {
tests := []struct {
name string
platform string
}{
{name: "gemini", platform: PlatformGemini},
{name: "antigravity", platform: PlatformAntigravity},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 100,
Platform: tt.platform,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": 401,
"keywords": []any{"unauthorized"},
"duration_minutes": 30,
"description": "custom rule",
},
},
},
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.setErrorCalls)
require.Equal(t, 0, repo.tempCalls)
require.Contains(t, repo.lastErrorMsg, "Authentication failed (401)")
require.Len(t, invalidator.accounts, 1)
})
}
}
func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{err: errors.New("boom")}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 101,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.setErrorCalls)
require.Len(t, invalidator.accounts, 1)
}
func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 102,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.setErrorCalls)
require.Empty(t, invalidator.accounts)
}
package service
import (
"context"
"time"
)
// SessionLimitCache 管理账号级别的活跃会话跟踪
// 用于 Anthropic OAuth/SetupToken 账号的会话数量限制
//
// Key 格式: session_limit:account:{accountID}
// 数据结构: Sorted Set (member=sessionUUID, score=timestamp)
//
// 会话在空闲超时后自动过期,无需手动清理
type SessionLimitCache interface {
// RegisterSession 注册会话活动
// - 如果会话已存在,刷新其时间戳并返回 true
// - 如果会话不存在且活跃会话数 < maxSessions,添加新会话并返回 true
// - 如果会话不存在且活跃会话数 >= maxSessions,返回 false(拒绝)
//
// 参数:
// accountID: 账号 ID
// sessionUUID: 从 metadata.user_id 中提取的会话 UUID
// maxSessions: 最大并发会话数限制
// idleTimeout: 会话空闲超时时间
//
// 返回:
// allowed: true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话)
// error: 操作错误
RegisterSession(ctx context.Context, accountID int64, sessionUUID string, maxSessions int, idleTimeout time.Duration) (allowed bool, err error)
// RefreshSession 刷新现有会话的时间戳
// 用于活跃会话保持活动状态
RefreshSession(ctx context.Context, accountID int64, sessionUUID string, idleTimeout time.Duration) error
// GetActiveSessionCount 获取当前活跃会话数
// 返回未过期的会话数量
GetActiveSessionCount(ctx context.Context, accountID int64) (int, error)
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
// 返回 map[accountID]count,查询失败的账号不在 map 中
GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)
// IsSessionActive 检查特定会话是否活跃(未过期)
IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error)
// ========== 5h窗口费用缓存 ==========
// Key 格式: window_cost:account:{accountID}
// 用于缓存账号在当前5h窗口内的标准费用,减少数据库聚合查询压力
// GetWindowCost 获取缓存的窗口费用
// 返回 (cost, true, nil) 如果缓存命中
// 返回 (0, false, nil) 如果缓存未命中
// 返回 (0, false, err) 如果发生错误
GetWindowCost(ctx context.Context, accountID int64) (cost float64, hit bool, err error)
// SetWindowCost 设置窗口费用缓存
SetWindowCost(ctx context.Context, accountID int64, cost float64) error
// GetWindowCostBatch 批量获取窗口费用缓存
// 返回 map[accountID]cost,缓存未命中的账号不在 map 中
GetWindowCostBatch(ctx context.Context, accountIDs []int64) (map[int64]float64, error)
}
package service package service
import ( import (
"fmt"
"log" "log"
"sync" "sync"
"time" "time"
...@@ -8,6 +9,8 @@ import ( ...@@ -8,6 +9,8 @@ import (
"github.com/zeromicro/go-zero/core/collection" "github.com/zeromicro/go-zero/core/collection"
) )
var newTimingWheel = collection.NewTimingWheel
// TimingWheelService wraps go-zero's TimingWheel for task scheduling // TimingWheelService wraps go-zero's TimingWheel for task scheduling
type TimingWheelService struct { type TimingWheelService struct {
tw *collection.TimingWheel tw *collection.TimingWheel
...@@ -15,18 +18,18 @@ type TimingWheelService struct { ...@@ -15,18 +18,18 @@ type TimingWheelService struct {
} }
// NewTimingWheelService creates a new TimingWheelService instance // NewTimingWheelService creates a new TimingWheelService instance
func NewTimingWheelService() *TimingWheelService { func NewTimingWheelService() (*TimingWheelService, error) {
// 1 second tick, 3600 slots = supports up to 1 hour delay // 1 second tick, 3600 slots = supports up to 1 hour delay
// execute function: runs func() type tasks // execute function: runs func() type tasks
tw, err := collection.NewTimingWheel(1*time.Second, 3600, func(key, value any) { tw, err := newTimingWheel(1*time.Second, 3600, func(key, value any) {
if fn, ok := value.(func()); ok { if fn, ok := value.(func()); ok {
fn() fn()
} }
}) })
if err != nil { if err != nil {
panic(err) return nil, fmt.Errorf("创建 timing wheel 失败: %w", err)
} }
return &TimingWheelService{tw: tw} return &TimingWheelService{tw: tw}, nil
} }
// Start starts the timing wheel // Start starts the timing wheel
......
package service
import (
"errors"
"sync/atomic"
"testing"
"time"
"github.com/zeromicro/go-zero/core/collection"
)
func TestNewTimingWheelService_InitFail_NoPanicAndReturnError(t *testing.T) {
original := newTimingWheel
t.Cleanup(func() { newTimingWheel = original })
newTimingWheel = func(_ time.Duration, _ int, _ collection.Execute) (*collection.TimingWheel, error) {
return nil, errors.New("boom")
}
svc, err := NewTimingWheelService()
if err == nil {
t.Fatalf("期望返回 error,但得到 nil")
}
if svc != nil {
t.Fatalf("期望返回 nil svc,但得到非空")
}
}
func TestNewTimingWheelService_Success(t *testing.T) {
svc, err := NewTimingWheelService()
if err != nil {
t.Fatalf("期望 err 为 nil,但得到: %v", err)
}
if svc == nil {
t.Fatalf("期望 svc 非空,但得到 nil")
}
svc.Stop()
}
func TestNewTimingWheelService_ExecuteCallbackRunsFunc(t *testing.T) {
original := newTimingWheel
t.Cleanup(func() { newTimingWheel = original })
var captured collection.Execute
newTimingWheel = func(interval time.Duration, numSlots int, execute collection.Execute) (*collection.TimingWheel, error) {
captured = execute
return original(interval, numSlots, execute)
}
svc, err := NewTimingWheelService()
if err != nil {
t.Fatalf("期望 err 为 nil,但得到: %v", err)
}
if captured == nil {
t.Fatalf("期望 captured 非空,但得到 nil")
}
called := false
captured("k", func() { called = true })
if !called {
t.Fatalf("期望 execute 回调触发传入函数执行")
}
svc.Stop()
}
func TestTimingWheelService_Schedule_ExecutesOnce(t *testing.T) {
original := newTimingWheel
t.Cleanup(func() { newTimingWheel = original })
newTimingWheel = func(_ time.Duration, _ int, execute collection.Execute) (*collection.TimingWheel, error) {
return original(10*time.Millisecond, 128, execute)
}
svc, err := NewTimingWheelService()
if err != nil {
t.Fatalf("期望 err 为 nil,但得到: %v", err)
}
defer svc.Stop()
ch := make(chan struct{}, 1)
svc.Schedule("once", 30*time.Millisecond, func() { ch <- struct{}{} })
select {
case <-ch:
case <-time.After(500 * time.Millisecond):
t.Fatalf("等待任务执行超时")
}
select {
case <-ch:
t.Fatalf("任务不应重复执行")
case <-time.After(80 * time.Millisecond):
}
}
func TestTimingWheelService_Cancel_PreventsExecution(t *testing.T) {
original := newTimingWheel
t.Cleanup(func() { newTimingWheel = original })
newTimingWheel = func(_ time.Duration, _ int, execute collection.Execute) (*collection.TimingWheel, error) {
return original(10*time.Millisecond, 128, execute)
}
svc, err := NewTimingWheelService()
if err != nil {
t.Fatalf("期望 err 为 nil,但得到: %v", err)
}
defer svc.Stop()
ch := make(chan struct{}, 1)
svc.Schedule("cancel", 80*time.Millisecond, func() { ch <- struct{}{} })
svc.Cancel("cancel")
select {
case <-ch:
t.Fatalf("任务已取消,不应执行")
case <-time.After(200 * time.Millisecond):
}
}
func TestTimingWheelService_ScheduleRecurring_ExecutesMultipleTimes(t *testing.T) {
original := newTimingWheel
t.Cleanup(func() { newTimingWheel = original })
newTimingWheel = func(_ time.Duration, _ int, execute collection.Execute) (*collection.TimingWheel, error) {
return original(10*time.Millisecond, 128, execute)
}
svc, err := NewTimingWheelService()
if err != nil {
t.Fatalf("期望 err 为 nil,但得到: %v", err)
}
defer svc.Stop()
var count int32
svc.ScheduleRecurring("rec", 30*time.Millisecond, func() { atomic.AddInt32(&count, 1) })
deadline := time.Now().Add(500 * time.Millisecond)
for atomic.LoadInt32(&count) < 2 && time.Now().Before(deadline) {
time.Sleep(10 * time.Millisecond)
}
if atomic.LoadInt32(&count) < 2 {
t.Fatalf("期望周期任务至少执行 2 次,但只执行了 %d 次", atomic.LoadInt32(&count))
}
}
package service
import "context"
type TokenCacheInvalidator interface {
InvalidateToken(ctx context.Context, account *Account) error
}
type CompositeTokenCacheInvalidator struct {
cache GeminiTokenCache // 统一使用一个缓存接口,通过缓存键前缀区分平台
}
func NewCompositeTokenCacheInvalidator(cache GeminiTokenCache) *CompositeTokenCacheInvalidator {
return &CompositeTokenCacheInvalidator{
cache: cache,
}
}
func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, account *Account) error {
if c == nil || c.cache == nil || account == nil {
return nil
}
if account.Type != AccountTypeOAuth {
return nil
}
var cacheKey string
switch account.Platform {
case PlatformGemini:
cacheKey = GeminiTokenCacheKey(account)
case PlatformAntigravity:
cacheKey = AntigravityTokenCacheKey(account)
case PlatformOpenAI:
cacheKey = OpenAITokenCacheKey(account)
case PlatformAnthropic:
cacheKey = ClaudeTokenCacheKey(account)
default:
return nil
}
return c.cache.DeleteAccessToken(ctx, cacheKey)
}
//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/require"
)
type geminiTokenCacheStub struct {
deletedKeys []string
deleteErr error
}
func (s *geminiTokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
return "", nil
}
func (s *geminiTokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
return nil
}
func (s *geminiTokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error {
s.deletedKeys = append(s.deletedKeys, cacheKey)
return s.deleteErr
}
func (s *geminiTokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
return true, nil
}
func (s *geminiTokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
return nil
}
func TestCompositeTokenCacheInvalidator_Gemini(t *testing.T) {
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
account := &Account{
ID: 10,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"project_id": "project-x",
},
}
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, []string{"gemini:project-x"}, cache.deletedKeys)
}
func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) {
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
account := &Account{
ID: 99,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"project_id": "ag-project",
},
}
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, []string{"ag:ag-project"}, cache.deletedKeys)
}
func TestCompositeTokenCacheInvalidator_OpenAI(t *testing.T) {
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
account := &Account{
ID: 500,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "openai-token",
},
}
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, []string{"openai:account:500"}, cache.deletedKeys)
}
func TestCompositeTokenCacheInvalidator_Claude(t *testing.T) {
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
account := &Account{
ID: 600,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "claude-token",
},
}
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, []string{"claude:account:600"}, cache.deletedKeys)
}
func TestCompositeTokenCacheInvalidator_SkipNonOAuth(t *testing.T) {
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
tests := []struct {
name string
account *Account
}{
{
name: "gemini_api_key",
account: &Account{
ID: 1,
Platform: PlatformGemini,
Type: AccountTypeAPIKey,
},
},
{
name: "openai_api_key",
account: &Account{
ID: 2,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
},
},
{
name: "claude_api_key",
account: &Account{
ID: 3,
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
},
},
{
name: "claude_setup_token",
account: &Account{
ID: 4,
Platform: PlatformAnthropic,
Type: AccountTypeSetupToken,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cache.deletedKeys = nil
err := invalidator.InvalidateToken(context.Background(), tt.account)
require.NoError(t, err)
require.Empty(t, cache.deletedKeys)
})
}
}
func TestCompositeTokenCacheInvalidator_SkipUnsupportedPlatform(t *testing.T) {
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
account := &Account{
ID: 100,
Platform: "unknown-platform",
Type: AccountTypeOAuth,
}
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
require.Empty(t, cache.deletedKeys)
}
func TestCompositeTokenCacheInvalidator_NilCache(t *testing.T) {
invalidator := NewCompositeTokenCacheInvalidator(nil)
account := &Account{
ID: 2,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
}
func TestCompositeTokenCacheInvalidator_NilAccount(t *testing.T) {
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
err := invalidator.InvalidateToken(context.Background(), nil)
require.NoError(t, err)
require.Empty(t, cache.deletedKeys)
}
func TestCompositeTokenCacheInvalidator_NilInvalidator(t *testing.T) {
var invalidator *CompositeTokenCacheInvalidator
account := &Account{
ID: 5,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
}
err := invalidator.InvalidateToken(context.Background(), account)
require.NoError(t, err)
}
func TestCompositeTokenCacheInvalidator_DeleteError(t *testing.T) {
expectedErr := errors.New("redis connection failed")
cache := &geminiTokenCacheStub{deleteErr: expectedErr}
invalidator := NewCompositeTokenCacheInvalidator(cache)
tests := []struct {
name string
account *Account
}{
{
name: "openai_delete_error",
account: &Account{
ID: 700,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
},
},
{
name: "claude_delete_error",
account: &Account{
ID: 800,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := invalidator.InvalidateToken(context.Background(), tt.account)
require.Error(t, err)
require.Equal(t, expectedErr, err)
})
}
}
func TestCompositeTokenCacheInvalidator_AllPlatformsIntegration(t *testing.T) {
// 测试所有平台的缓存键生成和删除
cache := &geminiTokenCacheStub{}
invalidator := NewCompositeTokenCacheInvalidator(cache)
accounts := []*Account{
{ID: 1, Platform: PlatformGemini, Type: AccountTypeOAuth, Credentials: map[string]any{"project_id": "gemini-proj"}},
{ID: 2, Platform: PlatformAntigravity, Type: AccountTypeOAuth, Credentials: map[string]any{"project_id": "ag-proj"}},
{ID: 3, Platform: PlatformOpenAI, Type: AccountTypeOAuth},
{ID: 4, Platform: PlatformAnthropic, Type: AccountTypeOAuth},
}
expectedKeys := []string{
"gemini:gemini-proj",
"ag:ag-proj",
"openai:account:3",
"claude:account:4",
}
for _, acc := range accounts {
err := invalidator.InvalidateToken(context.Background(), acc)
require.NoError(t, err)
}
require.Equal(t, expectedKeys, cache.deletedKeys)
}
package service
import "strconv"
// OpenAITokenCacheKey 生成 OpenAI OAuth 账号的缓存键
// 格式: "openai:account:{account_id}"
func OpenAITokenCacheKey(account *Account) string {
return "openai:account:" + strconv.FormatInt(account.ID, 10)
}
// ClaudeTokenCacheKey 生成 Claude (Anthropic) OAuth 账号的缓存键
// 格式: "claude:account:{account_id}"
func ClaudeTokenCacheKey(account *Account) string {
return "claude:account:" + strconv.FormatInt(account.ID, 10)
}
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestGeminiTokenCacheKey(t *testing.T) {
tests := []struct {
name string
account *Account
expected string
}{
{
name: "with_project_id",
account: &Account{
ID: 100,
Credentials: map[string]any{
"project_id": "my-project-123",
},
},
expected: "gemini:my-project-123",
},
{
name: "project_id_with_whitespace",
account: &Account{
ID: 101,
Credentials: map[string]any{
"project_id": " project-with-spaces ",
},
},
expected: "gemini:project-with-spaces",
},
{
name: "empty_project_id_fallback_to_account_id",
account: &Account{
ID: 102,
Credentials: map[string]any{
"project_id": "",
},
},
expected: "gemini:account:102",
},
{
name: "whitespace_only_project_id_fallback_to_account_id",
account: &Account{
ID: 103,
Credentials: map[string]any{
"project_id": " ",
},
},
expected: "gemini:account:103",
},
{
name: "no_project_id_key_fallback_to_account_id",
account: &Account{
ID: 104,
Credentials: map[string]any{},
},
expected: "gemini:account:104",
},
{
name: "nil_credentials_fallback_to_account_id",
account: &Account{
ID: 105,
Credentials: nil,
},
expected: "gemini:account:105",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GeminiTokenCacheKey(tt.account)
require.Equal(t, tt.expected, result)
})
}
}
func TestAntigravityTokenCacheKey(t *testing.T) {
tests := []struct {
name string
account *Account
expected string
}{
{
name: "with_project_id",
account: &Account{
ID: 200,
Credentials: map[string]any{
"project_id": "ag-project-456",
},
},
expected: "ag:ag-project-456",
},
{
name: "project_id_with_whitespace",
account: &Account{
ID: 201,
Credentials: map[string]any{
"project_id": " ag-project-spaces ",
},
},
expected: "ag:ag-project-spaces",
},
{
name: "empty_project_id_fallback_to_account_id",
account: &Account{
ID: 202,
Credentials: map[string]any{
"project_id": "",
},
},
expected: "ag:account:202",
},
{
name: "whitespace_only_project_id_fallback_to_account_id",
account: &Account{
ID: 203,
Credentials: map[string]any{
"project_id": " ",
},
},
expected: "ag:account:203",
},
{
name: "no_project_id_key_fallback_to_account_id",
account: &Account{
ID: 204,
Credentials: map[string]any{},
},
expected: "ag:account:204",
},
{
name: "nil_credentials_fallback_to_account_id",
account: &Account{
ID: 205,
Credentials: nil,
},
expected: "ag:account:205",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := AntigravityTokenCacheKey(tt.account)
require.Equal(t, tt.expected, result)
})
}
}
func TestOpenAITokenCacheKey(t *testing.T) {
tests := []struct {
name string
account *Account
expected string
}{
{
name: "basic_account",
account: &Account{
ID: 300,
},
expected: "openai:account:300",
},
{
name: "account_with_credentials",
account: &Account{
ID: 301,
Credentials: map[string]any{
"access_token": "test-token",
},
},
expected: "openai:account:301",
},
{
name: "account_id_zero",
account: &Account{
ID: 0,
},
expected: "openai:account:0",
},
{
name: "large_account_id",
account: &Account{
ID: 9999999999,
},
expected: "openai:account:9999999999",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := OpenAITokenCacheKey(tt.account)
require.Equal(t, tt.expected, result)
})
}
}
func TestClaudeTokenCacheKey(t *testing.T) {
tests := []struct {
name string
account *Account
expected string
}{
{
name: "basic_account",
account: &Account{
ID: 400,
},
expected: "claude:account:400",
},
{
name: "account_with_credentials",
account: &Account{
ID: 401,
Credentials: map[string]any{
"access_token": "claude-token",
},
},
expected: "claude:account:401",
},
{
name: "account_id_zero",
account: &Account{
ID: 0,
},
expected: "claude:account:0",
},
{
name: "large_account_id",
account: &Account{
ID: 9999999999,
},
expected: "claude:account:9999999999",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ClaudeTokenCacheKey(tt.account)
require.Equal(t, tt.expected, result)
})
}
}
func TestCacheKeyUniqueness(t *testing.T) {
// 确保不同平台的缓存键不会冲突
account := &Account{ID: 123}
openaiKey := OpenAITokenCacheKey(account)
claudeKey := ClaudeTokenCacheKey(account)
require.NotEqual(t, openaiKey, claudeKey, "OpenAI and Claude cache keys should be different")
require.Contains(t, openaiKey, "openai:")
require.Contains(t, claudeKey, "claude:")
}
...@@ -14,9 +14,10 @@ import ( ...@@ -14,9 +14,10 @@ import (
// TokenRefreshService OAuth token自动刷新服务 // TokenRefreshService OAuth token自动刷新服务
// 定期检查并刷新即将过期的token // 定期检查并刷新即将过期的token
type TokenRefreshService struct { type TokenRefreshService struct {
accountRepo AccountRepository accountRepo AccountRepository
refreshers []TokenRefresher refreshers []TokenRefresher
cfg *config.TokenRefreshConfig cfg *config.TokenRefreshConfig
cacheInvalidator TokenCacheInvalidator
stopCh chan struct{} stopCh chan struct{}
wg sync.WaitGroup wg sync.WaitGroup
...@@ -29,12 +30,14 @@ func NewTokenRefreshService( ...@@ -29,12 +30,14 @@ func NewTokenRefreshService(
openaiOAuthService *OpenAIOAuthService, openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService, geminiOAuthService *GeminiOAuthService,
antigravityOAuthService *AntigravityOAuthService, antigravityOAuthService *AntigravityOAuthService,
cacheInvalidator TokenCacheInvalidator,
cfg *config.Config, cfg *config.Config,
) *TokenRefreshService { ) *TokenRefreshService {
s := &TokenRefreshService{ s := &TokenRefreshService{
accountRepo: accountRepo, accountRepo: accountRepo,
cfg: &cfg.TokenRefresh, cfg: &cfg.TokenRefresh,
stopCh: make(chan struct{}), cacheInvalidator: cacheInvalidator,
stopCh: make(chan struct{}),
} }
// 注册平台特定的刷新器 // 注册平台特定的刷新器
...@@ -169,6 +172,14 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc ...@@ -169,6 +172,14 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
if err := s.accountRepo.Update(ctx, account); err != nil { if err := s.accountRepo.Update(ctx, account); err != nil {
return fmt.Errorf("failed to save credentials: %w", err) return fmt.Errorf("failed to save credentials: %w", err)
} }
// 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理)
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth {
if err := s.cacheInvalidator.InvalidateToken(ctx, account); err != nil {
log.Printf("[TokenRefresh] Failed to invalidate token cache for account %d: %v", account.ID, err)
} else {
log.Printf("[TokenRefresh] Token cache invalidated for account %d", account.ID)
}
}
return nil return nil
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment