Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
6901b64f
Commit
6901b64f
authored
Jan 17, 2026
by
cyhhao
Browse files
merge: sync upstream changes
parents
32c47b15
dae0d532
Changes
189
Hide whitespace changes
Inline
Side-by-side
backend/internal/service/ops_port.go
View file @
6901b64f
...
@@ -14,6 +14,8 @@ type OpsRepository interface {
...
@@ -14,6 +14,8 @@ type OpsRepository interface {
InsertRetryAttempt
(
ctx
context
.
Context
,
input
*
OpsInsertRetryAttemptInput
)
(
int64
,
error
)
InsertRetryAttempt
(
ctx
context
.
Context
,
input
*
OpsInsertRetryAttemptInput
)
(
int64
,
error
)
UpdateRetryAttempt
(
ctx
context
.
Context
,
input
*
OpsUpdateRetryAttemptInput
)
error
UpdateRetryAttempt
(
ctx
context
.
Context
,
input
*
OpsUpdateRetryAttemptInput
)
error
GetLatestRetryAttemptForError
(
ctx
context
.
Context
,
sourceErrorID
int64
)
(
*
OpsRetryAttempt
,
error
)
GetLatestRetryAttemptForError
(
ctx
context
.
Context
,
sourceErrorID
int64
)
(
*
OpsRetryAttempt
,
error
)
ListRetryAttemptsByErrorID
(
ctx
context
.
Context
,
sourceErrorID
int64
,
limit
int
)
([]
*
OpsRetryAttempt
,
error
)
UpdateErrorResolution
(
ctx
context
.
Context
,
errorID
int64
,
resolved
bool
,
resolvedByUserID
*
int64
,
resolvedRetryID
*
int64
,
resolvedAt
*
time
.
Time
)
error
// Lightweight window stats (for realtime WS / quick sampling).
// Lightweight window stats (for realtime WS / quick sampling).
GetWindowStats
(
ctx
context
.
Context
,
filter
*
OpsDashboardFilter
)
(
*
OpsWindowStats
,
error
)
GetWindowStats
(
ctx
context
.
Context
,
filter
*
OpsDashboardFilter
)
(
*
OpsWindowStats
,
error
)
...
@@ -39,12 +41,17 @@ type OpsRepository interface {
...
@@ -39,12 +41,17 @@ type OpsRepository interface {
DeleteAlertRule
(
ctx
context
.
Context
,
id
int64
)
error
DeleteAlertRule
(
ctx
context
.
Context
,
id
int64
)
error
ListAlertEvents
(
ctx
context
.
Context
,
filter
*
OpsAlertEventFilter
)
([]
*
OpsAlertEvent
,
error
)
ListAlertEvents
(
ctx
context
.
Context
,
filter
*
OpsAlertEventFilter
)
([]
*
OpsAlertEvent
,
error
)
GetAlertEventByID
(
ctx
context
.
Context
,
eventID
int64
)
(
*
OpsAlertEvent
,
error
)
GetActiveAlertEvent
(
ctx
context
.
Context
,
ruleID
int64
)
(
*
OpsAlertEvent
,
error
)
GetActiveAlertEvent
(
ctx
context
.
Context
,
ruleID
int64
)
(
*
OpsAlertEvent
,
error
)
GetLatestAlertEvent
(
ctx
context
.
Context
,
ruleID
int64
)
(
*
OpsAlertEvent
,
error
)
GetLatestAlertEvent
(
ctx
context
.
Context
,
ruleID
int64
)
(
*
OpsAlertEvent
,
error
)
CreateAlertEvent
(
ctx
context
.
Context
,
event
*
OpsAlertEvent
)
(
*
OpsAlertEvent
,
error
)
CreateAlertEvent
(
ctx
context
.
Context
,
event
*
OpsAlertEvent
)
(
*
OpsAlertEvent
,
error
)
UpdateAlertEventStatus
(
ctx
context
.
Context
,
eventID
int64
,
status
string
,
resolvedAt
*
time
.
Time
)
error
UpdateAlertEventStatus
(
ctx
context
.
Context
,
eventID
int64
,
status
string
,
resolvedAt
*
time
.
Time
)
error
UpdateAlertEventEmailSent
(
ctx
context
.
Context
,
eventID
int64
,
emailSent
bool
)
error
UpdateAlertEventEmailSent
(
ctx
context
.
Context
,
eventID
int64
,
emailSent
bool
)
error
// Alert silences
CreateAlertSilence
(
ctx
context
.
Context
,
input
*
OpsAlertSilence
)
(
*
OpsAlertSilence
,
error
)
IsAlertSilenced
(
ctx
context
.
Context
,
ruleID
int64
,
platform
string
,
groupID
*
int64
,
region
*
string
,
now
time
.
Time
)
(
bool
,
error
)
// Pre-aggregation (hourly/daily) used for long-window dashboard performance.
// Pre-aggregation (hourly/daily) used for long-window dashboard performance.
UpsertHourlyMetrics
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
)
error
UpsertHourlyMetrics
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
)
error
UpsertDailyMetrics
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
)
error
UpsertDailyMetrics
(
ctx
context
.
Context
,
startTime
,
endTime
time
.
Time
)
error
...
@@ -91,7 +98,6 @@ type OpsInsertErrorLogInput struct {
...
@@ -91,7 +98,6 @@ type OpsInsertErrorLogInput struct {
// It is set by OpsService.RecordError before persisting.
// It is set by OpsService.RecordError before persisting.
UpstreamErrorsJSON
*
string
UpstreamErrorsJSON
*
string
DurationMs
*
int
TimeToFirstTokenMs
*
int64
TimeToFirstTokenMs
*
int64
RequestBodyJSON
*
string
// sanitized json string (not raw bytes)
RequestBodyJSON
*
string
// sanitized json string (not raw bytes)
...
@@ -124,7 +130,15 @@ type OpsUpdateRetryAttemptInput struct {
...
@@ -124,7 +130,15 @@ type OpsUpdateRetryAttemptInput struct {
FinishedAt
time
.
Time
FinishedAt
time
.
Time
DurationMs
int64
DurationMs
int64
// Optional correlation
// Persisted execution results (best-effort)
Success
*
bool
HTTPStatusCode
*
int
UpstreamRequestID
*
string
UsedAccountID
*
int64
ResponsePreview
*
string
ResponseTruncated
*
bool
// Optional correlation (legacy fields kept)
ResultRequestID
*
string
ResultRequestID
*
string
ResultErrorID
*
int64
ResultErrorID
*
int64
...
@@ -221,6 +235,9 @@ type OpsUpsertJobHeartbeatInput struct {
...
@@ -221,6 +235,9 @@ type OpsUpsertJobHeartbeatInput struct {
LastErrorAt
*
time
.
Time
LastErrorAt
*
time
.
Time
LastError
*
string
LastError
*
string
LastDurationMs
*
int64
LastDurationMs
*
int64
// LastResult is an optional human-readable summary of the last successful run.
LastResult
*
string
}
}
type
OpsJobHeartbeat
struct
{
type
OpsJobHeartbeat
struct
{
...
@@ -231,6 +248,7 @@ type OpsJobHeartbeat struct {
...
@@ -231,6 +248,7 @@ type OpsJobHeartbeat struct {
LastErrorAt
*
time
.
Time
`json:"last_error_at"`
LastErrorAt
*
time
.
Time
`json:"last_error_at"`
LastError
*
string
`json:"last_error"`
LastError
*
string
`json:"last_error"`
LastDurationMs
*
int64
`json:"last_duration_ms"`
LastDurationMs
*
int64
`json:"last_duration_ms"`
LastResult
*
string
`json:"last_result"`
UpdatedAt
time
.
Time
`json:"updated_at"`
UpdatedAt
time
.
Time
`json:"updated_at"`
}
}
...
...
backend/internal/service/ops_retry.go
View file @
6901b64f
...
@@ -108,6 +108,10 @@ func (w *limitedResponseWriter) truncated() bool {
...
@@ -108,6 +108,10 @@ func (w *limitedResponseWriter) truncated() bool {
return
w
.
totalWritten
>
int64
(
w
.
limit
)
return
w
.
totalWritten
>
int64
(
w
.
limit
)
}
}
const
(
OpsRetryModeUpstreamEvent
=
"upstream_event"
)
func
(
s
*
OpsService
)
RetryError
(
ctx
context
.
Context
,
requestedByUserID
int64
,
errorID
int64
,
mode
string
,
pinnedAccountID
*
int64
)
(
*
OpsRetryResult
,
error
)
{
func
(
s
*
OpsService
)
RetryError
(
ctx
context
.
Context
,
requestedByUserID
int64
,
errorID
int64
,
mode
string
,
pinnedAccountID
*
int64
)
(
*
OpsRetryResult
,
error
)
{
if
err
:=
s
.
RequireMonitoringEnabled
(
ctx
);
err
!=
nil
{
if
err
:=
s
.
RequireMonitoringEnabled
(
ctx
);
err
!=
nil
{
return
nil
,
err
return
nil
,
err
...
@@ -123,6 +127,81 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
...
@@ -123,6 +127,81 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
return
nil
,
infraerrors
.
BadRequest
(
"OPS_RETRY_INVALID_MODE"
,
"mode must be client or upstream"
)
return
nil
,
infraerrors
.
BadRequest
(
"OPS_RETRY_INVALID_MODE"
,
"mode must be client or upstream"
)
}
}
errorLog
,
err
:=
s
.
GetErrorLogByID
(
ctx
,
errorID
)
if
err
!=
nil
{
return
nil
,
err
}
if
errorLog
==
nil
{
return
nil
,
infraerrors
.
NotFound
(
"OPS_ERROR_NOT_FOUND"
,
"ops error log not found"
)
}
if
strings
.
TrimSpace
(
errorLog
.
RequestBody
)
==
""
{
return
nil
,
infraerrors
.
BadRequest
(
"OPS_RETRY_NO_REQUEST_BODY"
,
"No request body found to retry"
)
}
var
pinned
*
int64
if
mode
==
OpsRetryModeUpstream
{
if
pinnedAccountID
!=
nil
&&
*
pinnedAccountID
>
0
{
pinned
=
pinnedAccountID
}
else
if
errorLog
.
AccountID
!=
nil
&&
*
errorLog
.
AccountID
>
0
{
pinned
=
errorLog
.
AccountID
}
else
{
return
nil
,
infraerrors
.
BadRequest
(
"OPS_RETRY_PINNED_ACCOUNT_REQUIRED"
,
"pinned_account_id is required for upstream retry"
)
}
}
return
s
.
retryWithErrorLog
(
ctx
,
requestedByUserID
,
errorID
,
mode
,
mode
,
pinned
,
errorLog
)
}
// RetryUpstreamEvent retries a specific upstream attempt captured inside ops_error_logs.upstream_errors.
// idx is 0-based. It always pins the original event account_id.
func
(
s
*
OpsService
)
RetryUpstreamEvent
(
ctx
context
.
Context
,
requestedByUserID
int64
,
errorID
int64
,
idx
int
)
(
*
OpsRetryResult
,
error
)
{
if
err
:=
s
.
RequireMonitoringEnabled
(
ctx
);
err
!=
nil
{
return
nil
,
err
}
if
s
.
opsRepo
==
nil
{
return
nil
,
infraerrors
.
ServiceUnavailable
(
"OPS_REPO_UNAVAILABLE"
,
"Ops repository not available"
)
}
if
idx
<
0
{
return
nil
,
infraerrors
.
BadRequest
(
"OPS_RETRY_INVALID_UPSTREAM_IDX"
,
"invalid upstream idx"
)
}
errorLog
,
err
:=
s
.
GetErrorLogByID
(
ctx
,
errorID
)
if
err
!=
nil
{
return
nil
,
err
}
if
errorLog
==
nil
{
return
nil
,
infraerrors
.
NotFound
(
"OPS_ERROR_NOT_FOUND"
,
"ops error log not found"
)
}
events
,
err
:=
ParseOpsUpstreamErrors
(
errorLog
.
UpstreamErrors
)
if
err
!=
nil
{
return
nil
,
infraerrors
.
BadRequest
(
"OPS_RETRY_UPSTREAM_EVENTS_INVALID"
,
"invalid upstream_errors"
)
}
if
idx
>=
len
(
events
)
{
return
nil
,
infraerrors
.
BadRequest
(
"OPS_RETRY_UPSTREAM_IDX_OOB"
,
"upstream idx out of range"
)
}
ev
:=
events
[
idx
]
if
ev
==
nil
{
return
nil
,
infraerrors
.
BadRequest
(
"OPS_RETRY_UPSTREAM_EVENT_MISSING"
,
"upstream event missing"
)
}
if
ev
.
AccountID
<=
0
{
return
nil
,
infraerrors
.
BadRequest
(
"OPS_RETRY_PINNED_ACCOUNT_REQUIRED"
,
"account_id is required for upstream retry"
)
}
upstreamBody
:=
strings
.
TrimSpace
(
ev
.
UpstreamRequestBody
)
if
upstreamBody
==
""
{
return
nil
,
infraerrors
.
BadRequest
(
"OPS_RETRY_UPSTREAM_NO_REQUEST_BODY"
,
"No upstream request body found to retry"
)
}
override
:=
*
errorLog
override
.
RequestBody
=
upstreamBody
pinned
:=
ev
.
AccountID
// Persist as upstream_event, execute as upstream pinned retry.
return
s
.
retryWithErrorLog
(
ctx
,
requestedByUserID
,
errorID
,
OpsRetryModeUpstreamEvent
,
OpsRetryModeUpstream
,
&
pinned
,
&
override
)
}
func
(
s
*
OpsService
)
retryWithErrorLog
(
ctx
context
.
Context
,
requestedByUserID
int64
,
errorID
int64
,
mode
string
,
execMode
string
,
pinnedAccountID
*
int64
,
errorLog
*
OpsErrorLogDetail
)
(
*
OpsRetryResult
,
error
)
{
latest
,
err
:=
s
.
opsRepo
.
GetLatestRetryAttemptForError
(
ctx
,
errorID
)
latest
,
err
:=
s
.
opsRepo
.
GetLatestRetryAttemptForError
(
ctx
,
errorID
)
if
err
!=
nil
&&
!
errors
.
Is
(
err
,
sql
.
ErrNoRows
)
{
if
err
!=
nil
&&
!
errors
.
Is
(
err
,
sql
.
ErrNoRows
)
{
return
nil
,
infraerrors
.
InternalServer
(
"OPS_RETRY_LOAD_LATEST_FAILED"
,
"Failed to check retry status"
)
.
WithCause
(
err
)
return
nil
,
infraerrors
.
InternalServer
(
"OPS_RETRY_LOAD_LATEST_FAILED"
,
"Failed to check retry status"
)
.
WithCause
(
err
)
...
@@ -144,22 +223,18 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
...
@@ -144,22 +223,18 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
}
}
}
}
errorLog
,
err
:=
s
.
GetErrorLogByID
(
ctx
,
errorID
)
if
errorLog
==
nil
||
strings
.
TrimSpace
(
errorLog
.
RequestBody
)
==
""
{
if
err
!=
nil
{
return
nil
,
err
}
if
strings
.
TrimSpace
(
errorLog
.
RequestBody
)
==
""
{
return
nil
,
infraerrors
.
BadRequest
(
"OPS_RETRY_NO_REQUEST_BODY"
,
"No request body found to retry"
)
return
nil
,
infraerrors
.
BadRequest
(
"OPS_RETRY_NO_REQUEST_BODY"
,
"No request body found to retry"
)
}
}
var
pinned
*
int64
var
pinned
*
int64
if
m
ode
==
OpsRetryModeUpstream
{
if
execM
ode
==
OpsRetryModeUpstream
{
if
pinnedAccountID
!=
nil
&&
*
pinnedAccountID
>
0
{
if
pinnedAccountID
!=
nil
&&
*
pinnedAccountID
>
0
{
pinned
=
pinnedAccountID
pinned
=
pinnedAccountID
}
else
if
errorLog
.
AccountID
!=
nil
&&
*
errorLog
.
AccountID
>
0
{
}
else
if
errorLog
.
AccountID
!=
nil
&&
*
errorLog
.
AccountID
>
0
{
pinned
=
errorLog
.
AccountID
pinned
=
errorLog
.
AccountID
}
else
{
}
else
{
return
nil
,
infraerrors
.
BadRequest
(
"OPS_RETRY_PINNED_ACCOUNT_REQUIRED"
,
"
pinned_
account_id is required for upstream retry"
)
return
nil
,
infraerrors
.
BadRequest
(
"OPS_RETRY_PINNED_ACCOUNT_REQUIRED"
,
"account_id is required for upstream retry"
)
}
}
}
}
...
@@ -196,7 +271,7 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
...
@@ -196,7 +271,7 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
execCtx
,
cancel
:=
context
.
WithTimeout
(
ctx
,
opsRetryTimeout
)
execCtx
,
cancel
:=
context
.
WithTimeout
(
ctx
,
opsRetryTimeout
)
defer
cancel
()
defer
cancel
()
execRes
:=
s
.
executeRetry
(
execCtx
,
errorLog
,
m
ode
,
pinned
)
execRes
:=
s
.
executeRetry
(
execCtx
,
errorLog
,
execM
ode
,
pinned
)
finishedAt
:=
time
.
Now
()
finishedAt
:=
time
.
Now
()
result
.
FinishedAt
=
finishedAt
result
.
FinishedAt
=
finishedAt
...
@@ -220,27 +295,40 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
...
@@ -220,27 +295,40 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
msg
:=
result
.
ErrorMessage
msg
:=
result
.
ErrorMessage
updateErrMsg
=
&
msg
updateErrMsg
=
&
msg
}
}
// Keep legacy result_request_id empty; use upstream_request_id instead.
var
resultRequestID
*
string
var
resultRequestID
*
string
if
strings
.
TrimSpace
(
result
.
UpstreamRequestID
)
!=
""
{
v
:=
result
.
UpstreamRequestID
resultRequestID
=
&
v
}
finalStatus
:=
result
.
Status
finalStatus
:=
result
.
Status
if
strings
.
TrimSpace
(
finalStatus
)
==
""
{
if
strings
.
TrimSpace
(
finalStatus
)
==
""
{
finalStatus
=
opsRetryStatusFailed
finalStatus
=
opsRetryStatusFailed
}
}
success
:=
strings
.
EqualFold
(
finalStatus
,
opsRetryStatusSucceeded
)
httpStatus
:=
result
.
HTTPStatusCode
upstreamReqID
:=
result
.
UpstreamRequestID
usedAccountID
:=
result
.
UsedAccountID
preview
:=
result
.
ResponsePreview
truncated
:=
result
.
ResponseTruncated
if
err
:=
s
.
opsRepo
.
UpdateRetryAttempt
(
updateCtx
,
&
OpsUpdateRetryAttemptInput
{
if
err
:=
s
.
opsRepo
.
UpdateRetryAttempt
(
updateCtx
,
&
OpsUpdateRetryAttemptInput
{
ID
:
attemptID
,
ID
:
attemptID
,
Status
:
finalStatus
,
Status
:
finalStatus
,
FinishedAt
:
finishedAt
,
FinishedAt
:
finishedAt
,
DurationMs
:
result
.
DurationMs
,
DurationMs
:
result
.
DurationMs
,
ResultRequestID
:
resultRequestID
,
Success
:
&
success
,
ErrorMessage
:
updateErrMsg
,
HTTPStatusCode
:
&
httpStatus
,
UpstreamRequestID
:
&
upstreamReqID
,
UsedAccountID
:
usedAccountID
,
ResponsePreview
:
&
preview
,
ResponseTruncated
:
&
truncated
,
ResultRequestID
:
resultRequestID
,
ErrorMessage
:
updateErrMsg
,
});
err
!=
nil
{
});
err
!=
nil
{
// Best-effort: retry itself already executed; do not fail the API response.
log
.
Printf
(
"[Ops] UpdateRetryAttempt failed: %v"
,
err
)
log
.
Printf
(
"[Ops] UpdateRetryAttempt failed: %v"
,
err
)
}
else
if
success
{
if
err
:=
s
.
opsRepo
.
UpdateErrorResolution
(
updateCtx
,
errorID
,
true
,
&
requestedByUserID
,
&
attemptID
,
&
finishedAt
);
err
!=
nil
{
log
.
Printf
(
"[Ops] UpdateErrorResolution failed: %v"
,
err
)
}
}
}
return
result
,
nil
return
result
,
nil
...
@@ -426,7 +514,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry
...
@@ -426,7 +514,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry
if
s
.
gatewayService
==
nil
{
if
s
.
gatewayService
==
nil
{
return
nil
,
fmt
.
Errorf
(
"gateway service not available"
)
return
nil
,
fmt
.
Errorf
(
"gateway service not available"
)
}
}
return
s
.
gatewayService
.
SelectAccountWithLoadAwareness
(
ctx
,
groupID
,
""
,
model
,
excludedIDs
)
return
s
.
gatewayService
.
SelectAccountWithLoadAwareness
(
ctx
,
groupID
,
""
,
model
,
excludedIDs
,
""
)
// 重试不使用会话限制
default
:
default
:
return
nil
,
fmt
.
Errorf
(
"unsupported retry type: %s"
,
reqType
)
return
nil
,
fmt
.
Errorf
(
"unsupported retry type: %s"
,
reqType
)
}
}
...
...
backend/internal/service/ops_scheduled_report_service.go
View file @
6901b64f
...
@@ -177,6 +177,10 @@ func (s *OpsScheduledReportService) runOnce() {
...
@@ -177,6 +177,10 @@ func (s *OpsScheduledReportService) runOnce() {
return
return
}
}
reportsTotal
:=
len
(
reports
)
reportsDue
:=
0
sentAttempts
:=
0
for
_
,
report
:=
range
reports
{
for
_
,
report
:=
range
reports
{
if
report
==
nil
||
!
report
.
Enabled
{
if
report
==
nil
||
!
report
.
Enabled
{
continue
continue
...
@@ -184,14 +188,18 @@ func (s *OpsScheduledReportService) runOnce() {
...
@@ -184,14 +188,18 @@ func (s *OpsScheduledReportService) runOnce() {
if
report
.
NextRunAt
.
After
(
now
)
{
if
report
.
NextRunAt
.
After
(
now
)
{
continue
continue
}
}
reportsDue
++
if
err
:=
s
.
runReport
(
ctx
,
report
,
now
);
err
!=
nil
{
attempts
,
err
:=
s
.
runReport
(
ctx
,
report
,
now
)
if
err
!=
nil
{
s
.
recordHeartbeatError
(
runAt
,
time
.
Since
(
startedAt
),
err
)
s
.
recordHeartbeatError
(
runAt
,
time
.
Since
(
startedAt
),
err
)
return
return
}
}
sentAttempts
+=
attempts
}
}
s
.
recordHeartbeatSuccess
(
runAt
,
time
.
Since
(
startedAt
))
result
:=
truncateString
(
fmt
.
Sprintf
(
"reports=%d due=%d send_attempts=%d"
,
reportsTotal
,
reportsDue
,
sentAttempts
),
2048
)
s
.
recordHeartbeatSuccess
(
runAt
,
time
.
Since
(
startedAt
),
result
)
}
}
type
opsScheduledReport
struct
{
type
opsScheduledReport
struct
{
...
@@ -297,9 +305,9 @@ func (s *OpsScheduledReportService) listScheduledReports(ctx context.Context, no
...
@@ -297,9 +305,9 @@ func (s *OpsScheduledReportService) listScheduledReports(ctx context.Context, no
return
out
return
out
}
}
func
(
s
*
OpsScheduledReportService
)
runReport
(
ctx
context
.
Context
,
report
*
opsScheduledReport
,
now
time
.
Time
)
error
{
func
(
s
*
OpsScheduledReportService
)
runReport
(
ctx
context
.
Context
,
report
*
opsScheduledReport
,
now
time
.
Time
)
(
int
,
error
)
{
if
s
==
nil
||
s
.
opsService
==
nil
||
s
.
emailService
==
nil
||
report
==
nil
{
if
s
==
nil
||
s
.
opsService
==
nil
||
s
.
emailService
==
nil
||
report
==
nil
{
return
nil
return
0
,
nil
}
}
if
ctx
==
nil
{
if
ctx
==
nil
{
ctx
=
context
.
Background
()
ctx
=
context
.
Background
()
...
@@ -310,11 +318,11 @@ func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsSc
...
@@ -310,11 +318,11 @@ func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsSc
content
,
err
:=
s
.
generateReportHTML
(
ctx
,
report
,
now
)
content
,
err
:=
s
.
generateReportHTML
(
ctx
,
report
,
now
)
if
err
!=
nil
{
if
err
!=
nil
{
return
err
return
0
,
err
}
}
if
strings
.
TrimSpace
(
content
)
==
""
{
if
strings
.
TrimSpace
(
content
)
==
""
{
// Skip sending when the report decides not to emit content (e.g., digest below min count).
// Skip sending when the report decides not to emit content (e.g., digest below min count).
return
nil
return
0
,
nil
}
}
recipients
:=
report
.
Recipients
recipients
:=
report
.
Recipients
...
@@ -325,22 +333,24 @@ func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsSc
...
@@ -325,22 +333,24 @@ func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsSc
}
}
}
}
if
len
(
recipients
)
==
0
{
if
len
(
recipients
)
==
0
{
return
nil
return
0
,
nil
}
}
subject
:=
fmt
.
Sprintf
(
"[Ops Report] %s"
,
strings
.
TrimSpace
(
report
.
Name
))
subject
:=
fmt
.
Sprintf
(
"[Ops Report] %s"
,
strings
.
TrimSpace
(
report
.
Name
))
attempts
:=
0
for
_
,
to
:=
range
recipients
{
for
_
,
to
:=
range
recipients
{
addr
:=
strings
.
TrimSpace
(
to
)
addr
:=
strings
.
TrimSpace
(
to
)
if
addr
==
""
{
if
addr
==
""
{
continue
continue
}
}
attempts
++
if
err
:=
s
.
emailService
.
SendEmail
(
ctx
,
addr
,
subject
,
content
);
err
!=
nil
{
if
err
:=
s
.
emailService
.
SendEmail
(
ctx
,
addr
,
subject
,
content
);
err
!=
nil
{
// Ignore per-recipient failures; continue best-effort.
// Ignore per-recipient failures; continue best-effort.
continue
continue
}
}
}
}
return
nil
return
attempts
,
nil
}
}
func
(
s
*
OpsScheduledReportService
)
generateReportHTML
(
ctx
context
.
Context
,
report
*
opsScheduledReport
,
now
time
.
Time
)
(
string
,
error
)
{
func
(
s
*
OpsScheduledReportService
)
generateReportHTML
(
ctx
context
.
Context
,
report
*
opsScheduledReport
,
now
time
.
Time
)
(
string
,
error
)
{
...
@@ -650,7 +660,7 @@ func (s *OpsScheduledReportService) setLastRunAt(ctx context.Context, reportType
...
@@ -650,7 +660,7 @@ func (s *OpsScheduledReportService) setLastRunAt(ctx context.Context, reportType
_
=
s
.
redisClient
.
Set
(
ctx
,
key
,
strconv
.
FormatInt
(
t
.
UTC
()
.
Unix
(),
10
),
14
*
24
*
time
.
Hour
)
.
Err
()
_
=
s
.
redisClient
.
Set
(
ctx
,
key
,
strconv
.
FormatInt
(
t
.
UTC
()
.
Unix
(),
10
),
14
*
24
*
time
.
Hour
)
.
Err
()
}
}
func
(
s
*
OpsScheduledReportService
)
recordHeartbeatSuccess
(
runAt
time
.
Time
,
duration
time
.
Duration
)
{
func
(
s
*
OpsScheduledReportService
)
recordHeartbeatSuccess
(
runAt
time
.
Time
,
duration
time
.
Duration
,
result
string
)
{
if
s
==
nil
||
s
.
opsService
==
nil
||
s
.
opsService
.
opsRepo
==
nil
{
if
s
==
nil
||
s
.
opsService
==
nil
||
s
.
opsService
.
opsRepo
==
nil
{
return
return
}
}
...
@@ -658,11 +668,17 @@ func (s *OpsScheduledReportService) recordHeartbeatSuccess(runAt time.Time, dura
...
@@ -658,11 +668,17 @@ func (s *OpsScheduledReportService) recordHeartbeatSuccess(runAt time.Time, dura
durMs
:=
duration
.
Milliseconds
()
durMs
:=
duration
.
Milliseconds
()
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
2
*
time
.
Second
)
defer
cancel
()
defer
cancel
()
msg
:=
strings
.
TrimSpace
(
result
)
if
msg
==
""
{
msg
=
"ok"
}
msg
=
truncateString
(
msg
,
2048
)
_
=
s
.
opsService
.
opsRepo
.
UpsertJobHeartbeat
(
ctx
,
&
OpsUpsertJobHeartbeatInput
{
_
=
s
.
opsService
.
opsRepo
.
UpsertJobHeartbeat
(
ctx
,
&
OpsUpsertJobHeartbeatInput
{
JobName
:
opsScheduledReportJobName
,
JobName
:
opsScheduledReportJobName
,
LastRunAt
:
&
runAt
,
LastRunAt
:
&
runAt
,
LastSuccessAt
:
&
now
,
LastSuccessAt
:
&
now
,
LastDurationMs
:
&
durMs
,
LastDurationMs
:
&
durMs
,
LastResult
:
&
msg
,
})
})
}
}
...
...
backend/internal/service/ops_service.go
View file @
6901b64f
...
@@ -208,6 +208,25 @@ func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogIn
...
@@ -208,6 +208,25 @@ func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogIn
out
.
Detail
=
""
out
.
Detail
=
""
}
}
out
.
UpstreamRequestBody
=
strings
.
TrimSpace
(
out
.
UpstreamRequestBody
)
if
out
.
UpstreamRequestBody
!=
""
{
// Reuse the same sanitization/trimming strategy as request body storage.
// Keep it small so it is safe to persist in ops_error_logs JSON.
sanitized
,
truncated
,
_
:=
sanitizeAndTrimRequestBody
([]
byte
(
out
.
UpstreamRequestBody
),
10
*
1024
)
if
sanitized
!=
""
{
out
.
UpstreamRequestBody
=
sanitized
if
truncated
{
out
.
Kind
=
strings
.
TrimSpace
(
out
.
Kind
)
if
out
.
Kind
==
""
{
out
.
Kind
=
"upstream"
}
out
.
Kind
=
out
.
Kind
+
":request_body_truncated"
}
}
else
{
out
.
UpstreamRequestBody
=
""
}
}
// Drop fully-empty events (can happen if only status code was known).
// Drop fully-empty events (can happen if only status code was known).
if
out
.
UpstreamStatusCode
==
0
&&
out
.
Message
==
""
&&
out
.
Detail
==
""
{
if
out
.
UpstreamStatusCode
==
0
&&
out
.
Message
==
""
&&
out
.
Detail
==
""
{
continue
continue
...
@@ -236,7 +255,13 @@ func (s *OpsService) GetErrorLogs(ctx context.Context, filter *OpsErrorLogFilter
...
@@ -236,7 +255,13 @@ func (s *OpsService) GetErrorLogs(ctx context.Context, filter *OpsErrorLogFilter
if
s
.
opsRepo
==
nil
{
if
s
.
opsRepo
==
nil
{
return
&
OpsErrorLogList
{
Errors
:
[]
*
OpsErrorLog
{},
Total
:
0
,
Page
:
1
,
PageSize
:
20
},
nil
return
&
OpsErrorLogList
{
Errors
:
[]
*
OpsErrorLog
{},
Total
:
0
,
Page
:
1
,
PageSize
:
20
},
nil
}
}
return
s
.
opsRepo
.
ListErrorLogs
(
ctx
,
filter
)
result
,
err
:=
s
.
opsRepo
.
ListErrorLogs
(
ctx
,
filter
)
if
err
!=
nil
{
log
.
Printf
(
"[Ops] GetErrorLogs failed: %v"
,
err
)
return
nil
,
err
}
return
result
,
nil
}
}
func
(
s
*
OpsService
)
GetErrorLogByID
(
ctx
context
.
Context
,
id
int64
)
(
*
OpsErrorLogDetail
,
error
)
{
func
(
s
*
OpsService
)
GetErrorLogByID
(
ctx
context
.
Context
,
id
int64
)
(
*
OpsErrorLogDetail
,
error
)
{
...
@@ -256,6 +281,46 @@ func (s *OpsService) GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLo
...
@@ -256,6 +281,46 @@ func (s *OpsService) GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLo
return
detail
,
nil
return
detail
,
nil
}
}
func
(
s
*
OpsService
)
ListRetryAttemptsByErrorID
(
ctx
context
.
Context
,
errorID
int64
,
limit
int
)
([]
*
OpsRetryAttempt
,
error
)
{
if
err
:=
s
.
RequireMonitoringEnabled
(
ctx
);
err
!=
nil
{
return
nil
,
err
}
if
s
.
opsRepo
==
nil
{
return
nil
,
infraerrors
.
ServiceUnavailable
(
"OPS_REPO_UNAVAILABLE"
,
"Ops repository not available"
)
}
if
errorID
<=
0
{
return
nil
,
infraerrors
.
BadRequest
(
"OPS_ERROR_INVALID_ID"
,
"invalid error id"
)
}
items
,
err
:=
s
.
opsRepo
.
ListRetryAttemptsByErrorID
(
ctx
,
errorID
,
limit
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
sql
.
ErrNoRows
)
{
return
[]
*
OpsRetryAttempt
{},
nil
}
return
nil
,
infraerrors
.
InternalServer
(
"OPS_RETRY_LIST_FAILED"
,
"Failed to list retry attempts"
)
.
WithCause
(
err
)
}
return
items
,
nil
}
func
(
s
*
OpsService
)
UpdateErrorResolution
(
ctx
context
.
Context
,
errorID
int64
,
resolved
bool
,
resolvedByUserID
*
int64
,
resolvedRetryID
*
int64
)
error
{
if
err
:=
s
.
RequireMonitoringEnabled
(
ctx
);
err
!=
nil
{
return
err
}
if
s
.
opsRepo
==
nil
{
return
infraerrors
.
ServiceUnavailable
(
"OPS_REPO_UNAVAILABLE"
,
"Ops repository not available"
)
}
if
errorID
<=
0
{
return
infraerrors
.
BadRequest
(
"OPS_ERROR_INVALID_ID"
,
"invalid error id"
)
}
// Best-effort ensure the error exists
if
_
,
err
:=
s
.
opsRepo
.
GetErrorLogByID
(
ctx
,
errorID
);
err
!=
nil
{
if
errors
.
Is
(
err
,
sql
.
ErrNoRows
)
{
return
infraerrors
.
NotFound
(
"OPS_ERROR_NOT_FOUND"
,
"ops error log not found"
)
}
return
infraerrors
.
InternalServer
(
"OPS_ERROR_LOAD_FAILED"
,
"Failed to load ops error log"
)
.
WithCause
(
err
)
}
return
s
.
opsRepo
.
UpdateErrorResolution
(
ctx
,
errorID
,
resolved
,
resolvedByUserID
,
resolvedRetryID
,
nil
)
}
func
sanitizeAndTrimRequestBody
(
raw
[]
byte
,
maxBytes
int
)
(
jsonString
string
,
truncated
bool
,
bytesLen
int
)
{
func
sanitizeAndTrimRequestBody
(
raw
[]
byte
,
maxBytes
int
)
(
jsonString
string
,
truncated
bool
,
bytesLen
int
)
{
bytesLen
=
len
(
raw
)
bytesLen
=
len
(
raw
)
if
len
(
raw
)
==
0
{
if
len
(
raw
)
==
0
{
...
@@ -296,14 +361,34 @@ func sanitizeAndTrimRequestBody(raw []byte, maxBytes int) (jsonString string, tr
...
@@ -296,14 +361,34 @@ func sanitizeAndTrimRequestBody(raw []byte, maxBytes int) (jsonString string, tr
}
}
}
}
// Last resort: store a minimal placeholder (still valid JSON).
// Last resort: keep JSON shape but drop big fields.
placeholder
:=
map
[
string
]
any
{
// This avoids downstream code that expects certain top-level keys from crashing.
"request_body_truncated"
:
true
,
if
root
,
ok
:=
decoded
.
(
map
[
string
]
any
);
ok
{
}
placeholder
:=
shallowCopyMap
(
root
)
if
model
:=
extractString
(
decoded
,
"model"
);
model
!=
""
{
placeholder
[
"request_body_truncated"
]
=
true
placeholder
[
"model"
]
=
model
// Replace potentially huge arrays/strings, but keep the keys present.
for
_
,
k
:=
range
[]
string
{
"messages"
,
"contents"
,
"input"
,
"prompt"
}
{
if
_
,
exists
:=
placeholder
[
k
];
exists
{
placeholder
[
k
]
=
[]
any
{}
}
}
for
_
,
k
:=
range
[]
string
{
"text"
}
{
if
_
,
exists
:=
placeholder
[
k
];
exists
{
placeholder
[
k
]
=
""
}
}
encoded4
,
err4
:=
json
.
Marshal
(
placeholder
)
if
err4
==
nil
{
if
len
(
encoded4
)
<=
maxBytes
{
return
string
(
encoded4
),
true
,
bytesLen
}
}
}
}
encoded4
,
err4
:=
json
.
Marshal
(
placeholder
)
// Final fallback: minimal valid JSON.
encoded4
,
err4
:=
json
.
Marshal
(
map
[
string
]
any
{
"request_body_truncated"
:
true
})
if
err4
!=
nil
{
if
err4
!=
nil
{
return
""
,
true
,
bytesLen
return
""
,
true
,
bytesLen
}
}
...
@@ -526,12 +611,3 @@ func sanitizeErrorBodyForStorage(raw string, maxBytes int) (sanitized string, tr
...
@@ -526,12 +611,3 @@ func sanitizeErrorBodyForStorage(raw string, maxBytes int) (sanitized string, tr
}
}
return
raw
,
false
return
raw
,
false
}
}
func
extractString
(
v
any
,
key
string
)
string
{
root
,
ok
:=
v
.
(
map
[
string
]
any
)
if
!
ok
{
return
""
}
s
,
_
:=
root
[
key
]
.
(
string
)
return
strings
.
TrimSpace
(
s
)
}
backend/internal/service/ops_settings.go
View file @
6901b64f
...
@@ -368,9 +368,11 @@ func defaultOpsAdvancedSettings() *OpsAdvancedSettings {
...
@@ -368,9 +368,11 @@ func defaultOpsAdvancedSettings() *OpsAdvancedSettings {
Aggregation
:
OpsAggregationSettings
{
Aggregation
:
OpsAggregationSettings
{
AggregationEnabled
:
false
,
AggregationEnabled
:
false
,
},
},
IgnoreCountTokensErrors
:
false
,
IgnoreCountTokensErrors
:
false
,
AutoRefreshEnabled
:
false
,
IgnoreContextCanceled
:
true
,
// Default to true - client disconnects are not errors
AutoRefreshIntervalSec
:
30
,
IgnoreNoAvailableAccounts
:
false
,
// Default to false - this is a real routing issue
AutoRefreshEnabled
:
false
,
AutoRefreshIntervalSec
:
30
,
}
}
}
}
...
@@ -482,13 +484,11 @@ const SettingKeyOpsMetricThresholds = "ops_metric_thresholds"
...
@@ -482,13 +484,11 @@ const SettingKeyOpsMetricThresholds = "ops_metric_thresholds"
func
defaultOpsMetricThresholds
()
*
OpsMetricThresholds
{
func
defaultOpsMetricThresholds
()
*
OpsMetricThresholds
{
slaMin
:=
99.5
slaMin
:=
99.5
latencyMax
:=
2000.0
ttftMax
:=
500.0
ttftMax
:=
500.0
reqErrMax
:=
5.0
reqErrMax
:=
5.0
upstreamErrMax
:=
5.0
upstreamErrMax
:=
5.0
return
&
OpsMetricThresholds
{
return
&
OpsMetricThresholds
{
SLAPercentMin
:
&
slaMin
,
SLAPercentMin
:
&
slaMin
,
LatencyP99MsMax
:
&
latencyMax
,
TTFTp99MsMax
:
&
ttftMax
,
TTFTp99MsMax
:
&
ttftMax
,
RequestErrorRatePercentMax
:
&
reqErrMax
,
RequestErrorRatePercentMax
:
&
reqErrMax
,
UpstreamErrorRatePercentMax
:
&
upstreamErrMax
,
UpstreamErrorRatePercentMax
:
&
upstreamErrMax
,
...
@@ -538,9 +538,6 @@ func (s *OpsService) UpdateMetricThresholds(ctx context.Context, cfg *OpsMetricT
...
@@ -538,9 +538,6 @@ func (s *OpsService) UpdateMetricThresholds(ctx context.Context, cfg *OpsMetricT
if
cfg
.
SLAPercentMin
!=
nil
&&
(
*
cfg
.
SLAPercentMin
<
0
||
*
cfg
.
SLAPercentMin
>
100
)
{
if
cfg
.
SLAPercentMin
!=
nil
&&
(
*
cfg
.
SLAPercentMin
<
0
||
*
cfg
.
SLAPercentMin
>
100
)
{
return
nil
,
errors
.
New
(
"sla_percent_min must be between 0 and 100"
)
return
nil
,
errors
.
New
(
"sla_percent_min must be between 0 and 100"
)
}
}
if
cfg
.
LatencyP99MsMax
!=
nil
&&
*
cfg
.
LatencyP99MsMax
<
0
{
return
nil
,
errors
.
New
(
"latency_p99_ms_max must be >= 0"
)
}
if
cfg
.
TTFTp99MsMax
!=
nil
&&
*
cfg
.
TTFTp99MsMax
<
0
{
if
cfg
.
TTFTp99MsMax
!=
nil
&&
*
cfg
.
TTFTp99MsMax
<
0
{
return
nil
,
errors
.
New
(
"ttft_p99_ms_max must be >= 0"
)
return
nil
,
errors
.
New
(
"ttft_p99_ms_max must be >= 0"
)
}
}
...
...
backend/internal/service/ops_settings_models.go
View file @
6901b64f
...
@@ -63,7 +63,6 @@ type OpsAlertSilencingSettings struct {
...
@@ -63,7 +63,6 @@ type OpsAlertSilencingSettings struct {
type
OpsMetricThresholds
struct
{
type
OpsMetricThresholds
struct
{
SLAPercentMin
*
float64
`json:"sla_percent_min,omitempty"`
// SLA低于此值变红
SLAPercentMin
*
float64
`json:"sla_percent_min,omitempty"`
// SLA低于此值变红
LatencyP99MsMax
*
float64
`json:"latency_p99_ms_max,omitempty"`
// 延迟P99高于此值变红
TTFTp99MsMax
*
float64
`json:"ttft_p99_ms_max,omitempty"`
// TTFT P99高于此值变红
TTFTp99MsMax
*
float64
`json:"ttft_p99_ms_max,omitempty"`
// TTFT P99高于此值变红
RequestErrorRatePercentMax
*
float64
`json:"request_error_rate_percent_max,omitempty"`
// 请求错误率高于此值变红
RequestErrorRatePercentMax
*
float64
`json:"request_error_rate_percent_max,omitempty"`
// 请求错误率高于此值变红
UpstreamErrorRatePercentMax
*
float64
`json:"upstream_error_rate_percent_max,omitempty"`
// 上游错误率高于此值变红
UpstreamErrorRatePercentMax
*
float64
`json:"upstream_error_rate_percent_max,omitempty"`
// 上游错误率高于此值变红
...
@@ -79,11 +78,13 @@ type OpsAlertRuntimeSettings struct {
...
@@ -79,11 +78,13 @@ type OpsAlertRuntimeSettings struct {
// OpsAdvancedSettings stores advanced ops configuration (data retention, aggregation).
// OpsAdvancedSettings stores advanced ops configuration (data retention, aggregation).
type
OpsAdvancedSettings
struct
{
type
OpsAdvancedSettings
struct
{
DataRetention
OpsDataRetentionSettings
`json:"data_retention"`
DataRetention
OpsDataRetentionSettings
`json:"data_retention"`
Aggregation
OpsAggregationSettings
`json:"aggregation"`
Aggregation
OpsAggregationSettings
`json:"aggregation"`
IgnoreCountTokensErrors
bool
`json:"ignore_count_tokens_errors"`
IgnoreCountTokensErrors
bool
`json:"ignore_count_tokens_errors"`
AutoRefreshEnabled
bool
`json:"auto_refresh_enabled"`
IgnoreContextCanceled
bool
`json:"ignore_context_canceled"`
AutoRefreshIntervalSec
int
`json:"auto_refresh_interval_seconds"`
IgnoreNoAvailableAccounts
bool
`json:"ignore_no_available_accounts"`
AutoRefreshEnabled
bool
`json:"auto_refresh_enabled"`
AutoRefreshIntervalSec
int
`json:"auto_refresh_interval_seconds"`
}
}
type
OpsDataRetentionSettings
struct
{
type
OpsDataRetentionSettings
struct
{
...
...
backend/internal/service/ops_upstream_context.go
View file @
6901b64f
...
@@ -15,6 +15,11 @@ const (
...
@@ -15,6 +15,11 @@ const (
OpsUpstreamErrorMessageKey
=
"ops_upstream_error_message"
OpsUpstreamErrorMessageKey
=
"ops_upstream_error_message"
OpsUpstreamErrorDetailKey
=
"ops_upstream_error_detail"
OpsUpstreamErrorDetailKey
=
"ops_upstream_error_detail"
OpsUpstreamErrorsKey
=
"ops_upstream_errors"
OpsUpstreamErrorsKey
=
"ops_upstream_errors"
// Best-effort capture of the current upstream request body so ops can
// retry the specific upstream attempt (not just the client request).
// This value is sanitized+trimmed before being persisted.
OpsUpstreamRequestBodyKey
=
"ops_upstream_request_body"
)
)
func
setOpsUpstreamError
(
c
*
gin
.
Context
,
upstreamStatusCode
int
,
upstreamMessage
,
upstreamDetail
string
)
{
func
setOpsUpstreamError
(
c
*
gin
.
Context
,
upstreamStatusCode
int
,
upstreamMessage
,
upstreamDetail
string
)
{
...
@@ -38,13 +43,21 @@ type OpsUpstreamErrorEvent struct {
...
@@ -38,13 +43,21 @@ type OpsUpstreamErrorEvent struct {
AtUnixMs
int64
`json:"at_unix_ms,omitempty"`
AtUnixMs
int64
`json:"at_unix_ms,omitempty"`
// Context
// Context
Platform
string
`json:"platform,omitempty"`
Platform
string
`json:"platform,omitempty"`
AccountID
int64
`json:"account_id,omitempty"`
AccountID
int64
`json:"account_id,omitempty"`
AccountName
string
`json:"account_name,omitempty"`
// Outcome
// Outcome
UpstreamStatusCode
int
`json:"upstream_status_code,omitempty"`
UpstreamStatusCode
int
`json:"upstream_status_code,omitempty"`
UpstreamRequestID
string
`json:"upstream_request_id,omitempty"`
UpstreamRequestID
string
`json:"upstream_request_id,omitempty"`
// Best-effort upstream request capture (sanitized+trimmed).
// Required for retrying a specific upstream attempt.
UpstreamRequestBody
string
`json:"upstream_request_body,omitempty"`
// Best-effort upstream response capture (sanitized+trimmed).
UpstreamResponseBody
string
`json:"upstream_response_body,omitempty"`
// Kind: http_error | request_error | retry_exhausted | failover
// Kind: http_error | request_error | retry_exhausted | failover
Kind
string
`json:"kind,omitempty"`
Kind
string
`json:"kind,omitempty"`
...
@@ -61,6 +74,8 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
...
@@ -61,6 +74,8 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
}
}
ev
.
Platform
=
strings
.
TrimSpace
(
ev
.
Platform
)
ev
.
Platform
=
strings
.
TrimSpace
(
ev
.
Platform
)
ev
.
UpstreamRequestID
=
strings
.
TrimSpace
(
ev
.
UpstreamRequestID
)
ev
.
UpstreamRequestID
=
strings
.
TrimSpace
(
ev
.
UpstreamRequestID
)
ev
.
UpstreamRequestBody
=
strings
.
TrimSpace
(
ev
.
UpstreamRequestBody
)
ev
.
UpstreamResponseBody
=
strings
.
TrimSpace
(
ev
.
UpstreamResponseBody
)
ev
.
Kind
=
strings
.
TrimSpace
(
ev
.
Kind
)
ev
.
Kind
=
strings
.
TrimSpace
(
ev
.
Kind
)
ev
.
Message
=
strings
.
TrimSpace
(
ev
.
Message
)
ev
.
Message
=
strings
.
TrimSpace
(
ev
.
Message
)
ev
.
Detail
=
strings
.
TrimSpace
(
ev
.
Detail
)
ev
.
Detail
=
strings
.
TrimSpace
(
ev
.
Detail
)
...
@@ -68,6 +83,16 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
...
@@ -68,6 +83,16 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
ev
.
Message
=
sanitizeUpstreamErrorMessage
(
ev
.
Message
)
ev
.
Message
=
sanitizeUpstreamErrorMessage
(
ev
.
Message
)
}
}
// If the caller didn't explicitly pass upstream request body but the gateway
// stored it on the context, attach it so ops can retry this specific attempt.
if
ev
.
UpstreamRequestBody
==
""
{
if
v
,
ok
:=
c
.
Get
(
OpsUpstreamRequestBodyKey
);
ok
{
if
s
,
ok
:=
v
.
(
string
);
ok
{
ev
.
UpstreamRequestBody
=
strings
.
TrimSpace
(
s
)
}
}
}
var
existing
[]
*
OpsUpstreamErrorEvent
var
existing
[]
*
OpsUpstreamErrorEvent
if
v
,
ok
:=
c
.
Get
(
OpsUpstreamErrorsKey
);
ok
{
if
v
,
ok
:=
c
.
Get
(
OpsUpstreamErrorsKey
);
ok
{
if
arr
,
ok
:=
v
.
([]
*
OpsUpstreamErrorEvent
);
ok
{
if
arr
,
ok
:=
v
.
([]
*
OpsUpstreamErrorEvent
);
ok
{
...
@@ -92,3 +117,15 @@ func marshalOpsUpstreamErrors(events []*OpsUpstreamErrorEvent) *string {
...
@@ -92,3 +117,15 @@ func marshalOpsUpstreamErrors(events []*OpsUpstreamErrorEvent) *string {
s
:=
string
(
raw
)
s
:=
string
(
raw
)
return
&
s
return
&
s
}
}
func
ParseOpsUpstreamErrors
(
raw
string
)
([]
*
OpsUpstreamErrorEvent
,
error
)
{
raw
=
strings
.
TrimSpace
(
raw
)
if
raw
==
""
{
return
[]
*
OpsUpstreamErrorEvent
{},
nil
}
var
out
[]
*
OpsUpstreamErrorEvent
if
err
:=
json
.
Unmarshal
([]
byte
(
raw
),
&
out
);
err
!=
nil
{
return
nil
,
err
}
return
out
,
nil
}
backend/internal/service/proxy.go
View file @
6901b64f
...
@@ -31,5 +31,21 @@ func (p *Proxy) URL() string {
...
@@ -31,5 +31,21 @@ func (p *Proxy) URL() string {
type
ProxyWithAccountCount
struct
{
type
ProxyWithAccountCount
struct
{
Proxy
Proxy
AccountCount
int64
AccountCount
int64
LatencyMs
*
int64
LatencyStatus
string
LatencyMessage
string
IPAddress
string
Country
string
CountryCode
string
Region
string
City
string
}
type
ProxyAccountSummary
struct
{
ID
int64
Name
string
Platform
string
Type
string
Notes
*
string
}
}
backend/internal/service/proxy_latency_cache.go
0 → 100644
View file @
6901b64f
package
service
import
(
"context"
"time"
)
type
ProxyLatencyInfo
struct
{
Success
bool
`json:"success"`
LatencyMs
*
int64
`json:"latency_ms,omitempty"`
Message
string
`json:"message,omitempty"`
IPAddress
string
`json:"ip_address,omitempty"`
Country
string
`json:"country,omitempty"`
CountryCode
string
`json:"country_code,omitempty"`
Region
string
`json:"region,omitempty"`
City
string
`json:"city,omitempty"`
UpdatedAt
time
.
Time
`json:"updated_at"`
}
type
ProxyLatencyCache
interface
{
GetProxyLatencies
(
ctx
context
.
Context
,
proxyIDs
[]
int64
)
(
map
[
int64
]
*
ProxyLatencyInfo
,
error
)
SetProxyLatency
(
ctx
context
.
Context
,
proxyID
int64
,
info
*
ProxyLatencyInfo
)
error
}
backend/internal/service/proxy_service.go
View file @
6901b64f
...
@@ -10,6 +10,7 @@ import (
...
@@ -10,6 +10,7 @@ import (
var
(
var
(
ErrProxyNotFound
=
infraerrors
.
NotFound
(
"PROXY_NOT_FOUND"
,
"proxy not found"
)
ErrProxyNotFound
=
infraerrors
.
NotFound
(
"PROXY_NOT_FOUND"
,
"proxy not found"
)
ErrProxyInUse
=
infraerrors
.
Conflict
(
"PROXY_IN_USE"
,
"proxy is in use by accounts"
)
)
)
type
ProxyRepository
interface
{
type
ProxyRepository
interface
{
...
@@ -26,6 +27,7 @@ type ProxyRepository interface {
...
@@ -26,6 +27,7 @@ type ProxyRepository interface {
ExistsByHostPortAuth
(
ctx
context
.
Context
,
host
string
,
port
int
,
username
,
password
string
)
(
bool
,
error
)
ExistsByHostPortAuth
(
ctx
context
.
Context
,
host
string
,
port
int
,
username
,
password
string
)
(
bool
,
error
)
CountAccountsByProxyID
(
ctx
context
.
Context
,
proxyID
int64
)
(
int64
,
error
)
CountAccountsByProxyID
(
ctx
context
.
Context
,
proxyID
int64
)
(
int64
,
error
)
ListAccountSummariesByProxyID
(
ctx
context
.
Context
,
proxyID
int64
)
([]
ProxyAccountSummary
,
error
)
}
}
// CreateProxyRequest 创建代理请求
// CreateProxyRequest 创建代理请求
...
...
backend/internal/service/ratelimit_service.go
View file @
6901b64f
...
@@ -3,7 +3,7 @@ package service
...
@@ -3,7 +3,7 @@ package service
import
(
import
(
"context"
"context"
"encoding/json"
"encoding/json"
"log"
"log
/slog
"
"net/http"
"net/http"
"strconv"
"strconv"
"strings"
"strings"
...
@@ -15,15 +15,16 @@ import (
...
@@ -15,15 +15,16 @@ import (
// RateLimitService 处理限流和过载状态管理
// RateLimitService 处理限流和过载状态管理
type
RateLimitService
struct
{
type
RateLimitService
struct
{
accountRepo
AccountRepository
accountRepo
AccountRepository
usageRepo
UsageLogRepository
usageRepo
UsageLogRepository
cfg
*
config
.
Config
cfg
*
config
.
Config
geminiQuotaService
*
GeminiQuotaService
geminiQuotaService
*
GeminiQuotaService
tempUnschedCache
TempUnschedCache
tempUnschedCache
TempUnschedCache
timeoutCounterCache
TimeoutCounterCache
timeoutCounterCache
TimeoutCounterCache
settingService
*
SettingService
settingService
*
SettingService
usageCacheMu
sync
.
RWMutex
tokenCacheInvalidator
TokenCacheInvalidator
usageCache
map
[
int64
]
*
geminiUsageCacheEntry
usageCacheMu
sync
.
RWMutex
usageCache
map
[
int64
]
*
geminiUsageCacheEntry
}
}
type
geminiUsageCacheEntry
struct
{
type
geminiUsageCacheEntry
struct
{
...
@@ -56,6 +57,11 @@ func (s *RateLimitService) SetSettingService(settingService *SettingService) {
...
@@ -56,6 +57,11 @@ func (s *RateLimitService) SetSettingService(settingService *SettingService) {
s
.
settingService
=
settingService
s
.
settingService
=
settingService
}
}
// SetTokenCacheInvalidator 设置 token 缓存清理器(可选依赖)
func
(
s
*
RateLimitService
)
SetTokenCacheInvalidator
(
invalidator
TokenCacheInvalidator
)
{
s
.
tokenCacheInvalidator
=
invalidator
}
// HandleUpstreamError 处理上游错误响应,标记账号状态
// HandleUpstreamError 处理上游错误响应,标记账号状态
// 返回是否应该停止该账号的调度
// 返回是否应该停止该账号的调度
func
(
s
*
RateLimitService
)
HandleUpstreamError
(
ctx
context
.
Context
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
responseBody
[]
byte
)
(
shouldDisable
bool
)
{
func
(
s
*
RateLimitService
)
HandleUpstreamError
(
ctx
context
.
Context
,
account
*
Account
,
statusCode
int
,
headers
http
.
Header
,
responseBody
[]
byte
)
(
shouldDisable
bool
)
{
...
@@ -63,11 +69,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
...
@@ -63,11 +69,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
// 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载)
// 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载)
customErrorCodesEnabled
:=
account
.
IsCustomErrorCodesEnabled
()
customErrorCodesEnabled
:=
account
.
IsCustomErrorCodesEnabled
()
if
!
account
.
ShouldHandleErrorCode
(
statusCode
)
{
if
!
account
.
ShouldHandleErrorCode
(
statusCode
)
{
log
.
Printf
(
"
A
ccount
%d: error %d skipped (not in custom error codes)"
,
account
.
ID
,
statusCode
)
s
log
.
Info
(
"
a
ccount
_error_code_skipped"
,
"account_id"
,
account
.
ID
,
"status_code"
,
statusCode
)
return
false
return
false
}
}
tempMatched
:=
s
.
tryTempUnschedulable
(
ctx
,
account
,
statusCode
,
responseBody
)
tempMatched
:=
false
if
statusCode
!=
401
{
tempMatched
=
s
.
tryTempUnschedulable
(
ctx
,
account
,
statusCode
,
responseBody
)
}
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
responseBody
))
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
responseBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
if
upstreamMsg
!=
""
{
if
upstreamMsg
!=
""
{
...
@@ -76,7 +85,25 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
...
@@ -76,7 +85,25 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
switch
statusCode
{
switch
statusCode
{
case
401
:
case
401
:
// 认证失败:停止调度,记录错误
// 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新
if
account
.
Type
==
AccountTypeOAuth
{
// 1. 失效缓存
if
s
.
tokenCacheInvalidator
!=
nil
{
if
err
:=
s
.
tokenCacheInvalidator
.
InvalidateToken
(
ctx
,
account
);
err
!=
nil
{
slog
.
Warn
(
"oauth_401_invalidate_cache_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
}
// 2. 设置 expires_at 为当前时间,强制下次请求刷新 token
if
account
.
Credentials
==
nil
{
account
.
Credentials
=
make
(
map
[
string
]
any
)
}
account
.
Credentials
[
"expires_at"
]
=
time
.
Now
()
.
Format
(
time
.
RFC3339
)
if
err
:=
s
.
accountRepo
.
Update
(
ctx
,
account
);
err
!=
nil
{
slog
.
Warn
(
"oauth_401_force_refresh_update_failed"
,
"account_id"
,
account
.
ID
,
"error"
,
err
)
}
else
{
slog
.
Info
(
"oauth_401_force_refresh_set"
,
"account_id"
,
account
.
ID
,
"platform"
,
account
.
Platform
)
}
}
msg
:=
"Authentication failed (401): invalid or expired credentials"
msg
:=
"Authentication failed (401): invalid or expired credentials"
if
upstreamMsg
!=
""
{
if
upstreamMsg
!=
""
{
msg
=
"Authentication failed (401): "
+
upstreamMsg
msg
=
"Authentication failed (401): "
+
upstreamMsg
...
@@ -100,7 +127,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
...
@@ -100,7 +127,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
s
.
handleAuthError
(
ctx
,
account
,
msg
)
s
.
handleAuthError
(
ctx
,
account
,
msg
)
shouldDisable
=
true
shouldDisable
=
true
case
429
:
case
429
:
s
.
handle429
(
ctx
,
account
,
headers
)
s
.
handle429
(
ctx
,
account
,
headers
,
responseBody
)
shouldDisable
=
false
shouldDisable
=
false
case
529
:
case
529
:
s
.
handle529
(
ctx
,
account
)
s
.
handle529
(
ctx
,
account
)
...
@@ -116,7 +143,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
...
@@ -116,7 +143,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
shouldDisable
=
true
shouldDisable
=
true
}
else
if
statusCode
>=
500
{
}
else
if
statusCode
>=
500
{
// 未启用自定义错误码时:仅记录5xx错误
// 未启用自定义错误码时:仅记录5xx错误
log
.
Printf
(
"
A
ccount
%d received
upstream
error
%
d"
,
account
.
ID
,
statusCode
)
s
log
.
Warn
(
"
a
ccount
_
upstream
_
error
"
,
"account_i
d"
,
account
.
ID
,
"status_code"
,
statusCode
)
shouldDisable
=
false
shouldDisable
=
false
}
}
}
}
...
@@ -163,7 +190,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
...
@@ -163,7 +190,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
start
:=
geminiDailyWindowStart
(
now
)
start
:=
geminiDailyWindowStart
(
now
)
totals
,
ok
:=
s
.
getGeminiUsageTotals
(
account
.
ID
,
start
,
now
)
totals
,
ok
:=
s
.
getGeminiUsageTotals
(
account
.
ID
,
start
,
now
)
if
!
ok
{
if
!
ok
{
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
start
,
now
,
0
,
0
,
account
.
ID
)
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
start
,
now
,
0
,
0
,
account
.
ID
,
0
,
nil
)
if
err
!=
nil
{
if
err
!=
nil
{
return
true
,
err
return
true
,
err
}
}
...
@@ -188,7 +215,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
...
@@ -188,7 +215,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
// NOTE:
// NOTE:
// - This is a local precheck to reduce upstream 429s.
// - This is a local precheck to reduce upstream 429s.
// - Do NOT mark the account as rate-limited here; rate_limit_reset_at should reflect real upstream 429s.
// - Do NOT mark the account as rate-limited here; rate_limit_reset_at should reflect real upstream 429s.
log
.
Printf
(
"[G
emini
P
re
C
heck
] Account %d reached
daily
quota
(%d/%d), skip until %v
"
,
account
.
ID
,
used
,
limit
,
resetAt
)
s
log
.
Info
(
"g
emini
_p
re
c
heck
_
daily
_
quota
_reached"
,
"account_id
"
,
account
.
ID
,
"
used
"
,
used
,
"limit"
,
limit
,
"reset_at"
,
resetAt
)
return
false
,
nil
return
false
,
nil
}
}
}
}
...
@@ -210,7 +237,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
...
@@ -210,7 +237,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
if
limit
>
0
{
if
limit
>
0
{
start
:=
now
.
Truncate
(
time
.
Minute
)
start
:=
now
.
Truncate
(
time
.
Minute
)
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
start
,
now
,
0
,
0
,
account
.
ID
)
stats
,
err
:=
s
.
usageRepo
.
GetModelStatsWithFilters
(
ctx
,
start
,
now
,
0
,
0
,
account
.
ID
,
0
,
nil
)
if
err
!=
nil
{
if
err
!=
nil
{
return
true
,
err
return
true
,
err
}
}
...
@@ -231,7 +258,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
...
@@ -231,7 +258,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
if
used
>=
limit
{
if
used
>=
limit
{
resetAt
:=
start
.
Add
(
time
.
Minute
)
resetAt
:=
start
.
Add
(
time
.
Minute
)
// Do not persist "rate limited" status from local precheck. See note above.
// Do not persist "rate limited" status from local precheck. See note above.
log
.
Printf
(
"[G
emini
P
re
C
heck
] Account %d reached
minute
quota
(%d/%d), skip until %v
"
,
account
.
ID
,
used
,
limit
,
resetAt
)
s
log
.
Info
(
"g
emini
_p
re
c
heck
_
minute
_
quota
_reached"
,
"account_id
"
,
account
.
ID
,
"
used
"
,
used
,
"limit"
,
limit
,
"reset_at"
,
resetAt
)
return
false
,
nil
return
false
,
nil
}
}
}
}
...
@@ -288,32 +315,40 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account)
...
@@ -288,32 +315,40 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account)
// handleAuthError 处理认证类错误(401/403),停止账号调度
// handleAuthError 处理认证类错误(401/403),停止账号调度
func
(
s
*
RateLimitService
)
handleAuthError
(
ctx
context
.
Context
,
account
*
Account
,
errorMsg
string
)
{
func
(
s
*
RateLimitService
)
handleAuthError
(
ctx
context
.
Context
,
account
*
Account
,
errorMsg
string
)
{
if
err
:=
s
.
accountRepo
.
SetError
(
ctx
,
account
.
ID
,
errorMsg
);
err
!=
nil
{
if
err
:=
s
.
accountRepo
.
SetError
(
ctx
,
account
.
ID
,
errorMsg
);
err
!=
nil
{
log
.
Printf
(
"SetE
rror
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"account_set_e
rror
_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"error"
,
err
)
return
return
}
}
log
.
Printf
(
"
A
ccount
%d
disabled
due to
auth
error
: %s
"
,
account
.
ID
,
errorMsg
)
s
log
.
Warn
(
"
a
ccount
_
disabled
_
auth
_
error
"
,
"account_id
"
,
account
.
ID
,
"error"
,
errorMsg
)
}
}
// handleCustomErrorCode 处理自定义错误码,停止账号调度
// handleCustomErrorCode 处理自定义错误码,停止账号调度
func
(
s
*
RateLimitService
)
handleCustomErrorCode
(
ctx
context
.
Context
,
account
*
Account
,
statusCode
int
,
errorMsg
string
)
{
func
(
s
*
RateLimitService
)
handleCustomErrorCode
(
ctx
context
.
Context
,
account
*
Account
,
statusCode
int
,
errorMsg
string
)
{
msg
:=
"Custom error code "
+
strconv
.
Itoa
(
statusCode
)
+
": "
+
errorMsg
msg
:=
"Custom error code "
+
strconv
.
Itoa
(
statusCode
)
+
": "
+
errorMsg
if
err
:=
s
.
accountRepo
.
SetError
(
ctx
,
account
.
ID
,
msg
);
err
!=
nil
{
if
err
:=
s
.
accountRepo
.
SetError
(
ctx
,
account
.
ID
,
msg
);
err
!=
nil
{
log
.
Printf
(
"SetE
rror
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"account_set_e
rror
_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"status_code"
,
statusCode
,
"error"
,
err
)
return
return
}
}
log
.
Printf
(
"
A
ccount
%d
disabled
due to
custom
error
code %d: %s
"
,
account
.
ID
,
statusCode
,
errorMsg
)
s
log
.
Warn
(
"
a
ccount
_
disabled
_
custom
_
error
"
,
"account_id
"
,
account
.
ID
,
"status_code"
,
statusCode
,
"error"
,
errorMsg
)
}
}
// handle429 处理429限流错误
// handle429 处理429限流错误
// 解析响应头获取重置时间,标记账号为限流状态
// 解析响应头获取重置时间,标记账号为限流状态
func
(
s
*
RateLimitService
)
handle429
(
ctx
context
.
Context
,
account
*
Account
,
headers
http
.
Header
)
{
func
(
s
*
RateLimitService
)
handle429
(
ctx
context
.
Context
,
account
*
Account
,
headers
http
.
Header
,
responseBody
[]
byte
)
{
// 解析重置时间戳
// 解析重置时间戳
resetTimestamp
:=
headers
.
Get
(
"anthropic-ratelimit-unified-reset"
)
resetTimestamp
:=
headers
.
Get
(
"anthropic-ratelimit-unified-reset"
)
if
resetTimestamp
==
""
{
if
resetTimestamp
==
""
{
// 没有重置时间,使用默认5分钟
// 没有重置时间,使用默认5分钟
resetAt
:=
time
.
Now
()
.
Add
(
5
*
time
.
Minute
)
resetAt
:=
time
.
Now
()
.
Add
(
5
*
time
.
Minute
)
if
s
.
shouldScopeClaudeSonnetRateLimit
(
account
,
responseBody
)
{
if
err
:=
s
.
accountRepo
.
SetModelRateLimit
(
ctx
,
account
.
ID
,
modelRateLimitScopeClaudeSonnet
,
resetAt
);
err
!=
nil
{
slog
.
Warn
(
"model_rate_limit_set_failed"
,
"account_id"
,
account
.
ID
,
"scope"
,
modelRateLimitScopeClaudeSonnet
,
"error"
,
err
)
}
else
{
slog
.
Info
(
"account_model_rate_limited"
,
"account_id"
,
account
.
ID
,
"scope"
,
modelRateLimitScopeClaudeSonnet
,
"reset_at"
,
resetAt
)
}
return
}
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
resetAt
);
err
!=
nil
{
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
resetAt
);
err
!=
nil
{
log
.
Printf
(
"SetR
ate
L
imit
ed
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"r
ate
_l
imit
_set_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"error"
,
err
)
}
}
return
return
}
}
...
@@ -321,19 +356,36 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
...
@@ -321,19 +356,36 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
// 解析Unix时间戳
// 解析Unix时间戳
ts
,
err
:=
strconv
.
ParseInt
(
resetTimestamp
,
10
,
64
)
ts
,
err
:=
strconv
.
ParseInt
(
resetTimestamp
,
10
,
64
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Printf
(
"Parse
reset
timestamp
failed: %v
"
,
err
)
s
log
.
Warn
(
"rate_limit_reset_parse_failed"
,
"
reset
_
timestamp
"
,
resetTimestamp
,
"error
"
,
err
)
resetAt
:=
time
.
Now
()
.
Add
(
5
*
time
.
Minute
)
resetAt
:=
time
.
Now
()
.
Add
(
5
*
time
.
Minute
)
if
s
.
shouldScopeClaudeSonnetRateLimit
(
account
,
responseBody
)
{
if
err
:=
s
.
accountRepo
.
SetModelRateLimit
(
ctx
,
account
.
ID
,
modelRateLimitScopeClaudeSonnet
,
resetAt
);
err
!=
nil
{
slog
.
Warn
(
"model_rate_limit_set_failed"
,
"account_id"
,
account
.
ID
,
"scope"
,
modelRateLimitScopeClaudeSonnet
,
"error"
,
err
)
}
else
{
slog
.
Info
(
"account_model_rate_limited"
,
"account_id"
,
account
.
ID
,
"scope"
,
modelRateLimitScopeClaudeSonnet
,
"reset_at"
,
resetAt
)
}
return
}
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
resetAt
);
err
!=
nil
{
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
resetAt
);
err
!=
nil
{
log
.
Printf
(
"SetR
ate
L
imit
ed
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"r
ate
_l
imit
_set_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"error"
,
err
)
}
}
return
return
}
}
resetAt
:=
time
.
Unix
(
ts
,
0
)
resetAt
:=
time
.
Unix
(
ts
,
0
)
if
s
.
shouldScopeClaudeSonnetRateLimit
(
account
,
responseBody
)
{
if
err
:=
s
.
accountRepo
.
SetModelRateLimit
(
ctx
,
account
.
ID
,
modelRateLimitScopeClaudeSonnet
,
resetAt
);
err
!=
nil
{
slog
.
Warn
(
"model_rate_limit_set_failed"
,
"account_id"
,
account
.
ID
,
"scope"
,
modelRateLimitScopeClaudeSonnet
,
"error"
,
err
)
return
}
slog
.
Info
(
"account_model_rate_limited"
,
"account_id"
,
account
.
ID
,
"scope"
,
modelRateLimitScopeClaudeSonnet
,
"reset_at"
,
resetAt
)
return
}
// 标记限流状态
// 标记限流状态
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
resetAt
);
err
!=
nil
{
if
err
:=
s
.
accountRepo
.
SetRateLimited
(
ctx
,
account
.
ID
,
resetAt
);
err
!=
nil
{
log
.
Printf
(
"SetR
ate
L
imit
ed
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"r
ate
_l
imit
_set_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"error"
,
err
)
return
return
}
}
...
@@ -341,10 +393,21 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
...
@@ -341,10 +393,21 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
windowEnd
:=
resetAt
windowEnd
:=
resetAt
windowStart
:=
resetAt
.
Add
(
-
5
*
time
.
Hour
)
windowStart
:=
resetAt
.
Add
(
-
5
*
time
.
Hour
)
if
err
:=
s
.
accountRepo
.
UpdateSessionWindow
(
ctx
,
account
.
ID
,
&
windowStart
,
&
windowEnd
,
"rejected"
);
err
!=
nil
{
if
err
:=
s
.
accountRepo
.
UpdateSessionWindow
(
ctx
,
account
.
ID
,
&
windowStart
,
&
windowEnd
,
"rejected"
);
err
!=
nil
{
log
.
Printf
(
"U
pdate
S
ession
W
indow
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"rate_limit_u
pdate
_s
ession
_w
indow
_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"error"
,
err
)
}
}
log
.
Printf
(
"Account %d rate limited until %v"
,
account
.
ID
,
resetAt
)
slog
.
Info
(
"account_rate_limited"
,
"account_id"
,
account
.
ID
,
"reset_at"
,
resetAt
)
}
func
(
s
*
RateLimitService
)
shouldScopeClaudeSonnetRateLimit
(
account
*
Account
,
responseBody
[]
byte
)
bool
{
if
account
==
nil
||
account
.
Platform
!=
PlatformAnthropic
{
return
false
}
msg
:=
strings
.
ToLower
(
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
responseBody
)))
if
msg
==
""
{
return
false
}
return
strings
.
Contains
(
msg
,
"sonnet"
)
}
}
// handle529 处理529过载错误
// handle529 处理529过载错误
...
@@ -357,11 +420,11 @@ func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
...
@@ -357,11 +420,11 @@ func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
until
:=
time
.
Now
()
.
Add
(
time
.
Duration
(
cooldownMinutes
)
*
time
.
Minute
)
until
:=
time
.
Now
()
.
Add
(
time
.
Duration
(
cooldownMinutes
)
*
time
.
Minute
)
if
err
:=
s
.
accountRepo
.
SetOverloaded
(
ctx
,
account
.
ID
,
until
);
err
!=
nil
{
if
err
:=
s
.
accountRepo
.
SetOverloaded
(
ctx
,
account
.
ID
,
until
);
err
!=
nil
{
log
.
Printf
(
"SetO
verload
ed
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"o
verload
_set_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"error"
,
err
)
return
return
}
}
log
.
Printf
(
"
A
ccount
%d
overloaded
until %v
"
,
account
.
ID
,
until
)
s
log
.
Info
(
"
a
ccount
_
overloaded
"
,
"account_id
"
,
account
.
ID
,
"until"
,
until
)
}
}
// UpdateSessionWindow 从成功响应更新5h窗口状态
// UpdateSessionWindow 从成功响应更新5h窗口状态
...
@@ -384,17 +447,17 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
...
@@ -384,17 +447,17 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
end
:=
start
.
Add
(
5
*
time
.
Hour
)
end
:=
start
.
Add
(
5
*
time
.
Hour
)
windowStart
=
&
start
windowStart
=
&
start
windowEnd
=
&
end
windowEnd
=
&
end
log
.
Printf
(
"
A
ccount
%d: initializing 5h window from %v to %v (status: %s)
"
,
account
.
ID
,
start
,
end
,
status
)
s
log
.
Info
(
"
a
ccount
_session_window_initialized"
,
"account_id
"
,
account
.
ID
,
"window_
start
"
,
start
,
"window_end"
,
end
,
"status"
,
status
)
}
}
if
err
:=
s
.
accountRepo
.
UpdateSessionWindow
(
ctx
,
account
.
ID
,
windowStart
,
windowEnd
,
status
);
err
!=
nil
{
if
err
:=
s
.
accountRepo
.
UpdateSessionWindow
(
ctx
,
account
.
ID
,
windowStart
,
windowEnd
,
status
);
err
!=
nil
{
log
.
Printf
(
"UpdateS
ession
W
indow
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"s
ession
_w
indow
_update_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"error"
,
err
)
}
}
// 如果状态为allowed且之前有限流,说明窗口已重置,清除限流状态
// 如果状态为allowed且之前有限流,说明窗口已重置,清除限流状态
if
status
==
"allowed"
&&
account
.
IsRateLimited
()
{
if
status
==
"allowed"
&&
account
.
IsRateLimited
()
{
if
err
:=
s
.
ClearRateLimit
(
ctx
,
account
.
ID
);
err
!=
nil
{
if
err
:=
s
.
ClearRateLimit
(
ctx
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"ClearR
ate
L
imit
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"r
ate
_l
imit
_clear_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"error"
,
err
)
}
}
}
}
}
}
...
@@ -404,7 +467,10 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64)
...
@@ -404,7 +467,10 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64)
if
err
:=
s
.
accountRepo
.
ClearRateLimit
(
ctx
,
accountID
);
err
!=
nil
{
if
err
:=
s
.
accountRepo
.
ClearRateLimit
(
ctx
,
accountID
);
err
!=
nil
{
return
err
return
err
}
}
return
s
.
accountRepo
.
ClearAntigravityQuotaScopes
(
ctx
,
accountID
)
if
err
:=
s
.
accountRepo
.
ClearAntigravityQuotaScopes
(
ctx
,
accountID
);
err
!=
nil
{
return
err
}
return
s
.
accountRepo
.
ClearModelRateLimits
(
ctx
,
accountID
)
}
}
func
(
s
*
RateLimitService
)
ClearTempUnschedulable
(
ctx
context
.
Context
,
accountID
int64
)
error
{
func
(
s
*
RateLimitService
)
ClearTempUnschedulable
(
ctx
context
.
Context
,
accountID
int64
)
error
{
...
@@ -413,7 +479,7 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID
...
@@ -413,7 +479,7 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID
}
}
if
s
.
tempUnschedCache
!=
nil
{
if
s
.
tempUnschedCache
!=
nil
{
if
err
:=
s
.
tempUnschedCache
.
DeleteTempUnsched
(
ctx
,
accountID
);
err
!=
nil
{
if
err
:=
s
.
tempUnschedCache
.
DeleteTempUnsched
(
ctx
,
accountID
);
err
!=
nil
{
log
.
Printf
(
"DeleteTempUnsched
failed
for
account
%d: %v
"
,
accountID
,
err
)
s
log
.
Warn
(
"temp_unsched_cache_delete_
failed
"
,
"
account
_id
"
,
accountID
,
"error"
,
err
)
}
}
}
}
return
nil
return
nil
...
@@ -460,7 +526,7 @@ func (s *RateLimitService) GetTempUnschedStatus(ctx context.Context, accountID i
...
@@ -460,7 +526,7 @@ func (s *RateLimitService) GetTempUnschedStatus(ctx context.Context, accountID i
if
s
.
tempUnschedCache
!=
nil
{
if
s
.
tempUnschedCache
!=
nil
{
if
err
:=
s
.
tempUnschedCache
.
SetTempUnsched
(
ctx
,
accountID
,
state
);
err
!=
nil
{
if
err
:=
s
.
tempUnschedCache
.
SetTempUnsched
(
ctx
,
accountID
,
state
);
err
!=
nil
{
log
.
Printf
(
"SetT
emp
U
nsched
failed
for
account
%d: %v
"
,
accountID
,
err
)
s
log
.
Warn
(
"t
emp
_u
nsched
_cache_set_
failed
"
,
"
account
_id
"
,
accountID
,
"error"
,
err
)
}
}
}
}
...
@@ -563,17 +629,17 @@ func (s *RateLimitService) triggerTempUnschedulable(ctx context.Context, account
...
@@ -563,17 +629,17 @@ func (s *RateLimitService) triggerTempUnschedulable(ctx context.Context, account
}
}
if
err
:=
s
.
accountRepo
.
SetTempUnschedulable
(
ctx
,
account
.
ID
,
until
,
reason
);
err
!=
nil
{
if
err
:=
s
.
accountRepo
.
SetTempUnschedulable
(
ctx
,
account
.
ID
,
until
,
reason
);
err
!=
nil
{
log
.
Printf
(
"SetT
emp
U
nsched
ulable
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"t
emp
_u
nsched
_set_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"error"
,
err
)
return
false
return
false
}
}
if
s
.
tempUnschedCache
!=
nil
{
if
s
.
tempUnschedCache
!=
nil
{
if
err
:=
s
.
tempUnschedCache
.
SetTempUnsched
(
ctx
,
account
.
ID
,
state
);
err
!=
nil
{
if
err
:=
s
.
tempUnschedCache
.
SetTempUnsched
(
ctx
,
account
.
ID
,
state
);
err
!=
nil
{
log
.
Printf
(
"SetT
emp
U
nsched
cache
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"t
emp
_u
nsched
_
cache
_set_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"error"
,
err
)
}
}
}
}
log
.
Printf
(
"
A
ccount
%d
temp
unschedulable
until %v (rule %d, code %d)
"
,
account
.
ID
,
until
,
ruleIndex
,
statusCode
)
s
log
.
Info
(
"
a
ccount
_
temp
_
unschedulable
"
,
"account_id
"
,
account
.
ID
,
"
until
"
,
until
,
"rule_index"
,
ruleIndex
,
"status_code"
,
statusCode
)
return
true
return
true
}
}
...
@@ -597,13 +663,13 @@ func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Acc
...
@@ -597,13 +663,13 @@ func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Acc
// 获取系统设置
// 获取系统设置
if
s
.
settingService
==
nil
{
if
s
.
settingService
==
nil
{
log
.
Printf
(
"[S
tream
T
imeout
]
setting
S
ervice
not configured, skipping timeout handling for
account
%
d"
,
account
.
ID
)
s
log
.
Warn
(
"s
tream
_t
imeout
_
setting
_s
ervice
_missing"
,
"
account
_i
d"
,
account
.
ID
)
return
false
return
false
}
}
settings
,
err
:=
s
.
settingService
.
GetStreamTimeoutSettings
(
ctx
)
settings
,
err
:=
s
.
settingService
.
GetStreamTimeoutSettings
(
ctx
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Printf
(
"[S
tream
T
imeout
] Failed to
get
settings
: %v
"
,
err
)
s
log
.
Warn
(
"s
tream
_t
imeout
_
get
_
settings
_failed"
,
"account_id"
,
account
.
ID
,
"error
"
,
err
)
return
false
return
false
}
}
...
@@ -620,14 +686,13 @@ func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Acc
...
@@ -620,14 +686,13 @@ func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Acc
if
s
.
timeoutCounterCache
!=
nil
{
if
s
.
timeoutCounterCache
!=
nil
{
count
,
err
=
s
.
timeoutCounterCache
.
IncrementTimeoutCount
(
ctx
,
account
.
ID
,
settings
.
ThresholdWindowMinutes
)
count
,
err
=
s
.
timeoutCounterCache
.
IncrementTimeoutCount
(
ctx
,
account
.
ID
,
settings
.
ThresholdWindowMinutes
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Printf
(
"[S
tream
T
imeout
] Failed to increment timeout count for account %d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"s
tream
_t
imeout
_increment_count_failed"
,
"account_id
"
,
account
.
ID
,
"error"
,
err
)
// 继续处理,使用 count=1
// 继续处理,使用 count=1
count
=
1
count
=
1
}
}
}
}
log
.
Printf
(
"[StreamTimeout] Account %d timeout count: %d/%d (window: %d min, model: %s)"
,
slog
.
Info
(
"stream_timeout_count"
,
"account_id"
,
account
.
ID
,
"count"
,
count
,
"threshold"
,
settings
.
ThresholdCount
,
"window_minutes"
,
settings
.
ThresholdWindowMinutes
,
"model"
,
model
)
account
.
ID
,
count
,
settings
.
ThresholdCount
,
settings
.
ThresholdWindowMinutes
,
model
)
// 检查是否达到阈值
// 检查是否达到阈值
if
count
<
int64
(
settings
.
ThresholdCount
)
{
if
count
<
int64
(
settings
.
ThresholdCount
)
{
...
@@ -668,24 +733,24 @@ func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context,
...
@@ -668,24 +733,24 @@ func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context,
}
}
if
err
:=
s
.
accountRepo
.
SetTempUnschedulable
(
ctx
,
account
.
ID
,
until
,
reason
);
err
!=
nil
{
if
err
:=
s
.
accountRepo
.
SetTempUnschedulable
(
ctx
,
account
.
ID
,
until
,
reason
);
err
!=
nil
{
log
.
Printf
(
"[S
tream
T
imeout
] SetT
emp
U
nsched
ulable
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"s
tream
_t
imeout
_set_t
emp
_u
nsched
_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"error"
,
err
)
return
false
return
false
}
}
if
s
.
tempUnschedCache
!=
nil
{
if
s
.
tempUnschedCache
!=
nil
{
if
err
:=
s
.
tempUnschedCache
.
SetTempUnsched
(
ctx
,
account
.
ID
,
state
);
err
!=
nil
{
if
err
:=
s
.
tempUnschedCache
.
SetTempUnsched
(
ctx
,
account
.
ID
,
state
);
err
!=
nil
{
log
.
Printf
(
"[S
tream
T
imeout
] SetT
emp
U
nsched
cache
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"s
tream
_t
imeout
_set_t
emp
_u
nsched
_
cache
_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"error"
,
err
)
}
}
}
}
// 重置超时计数
// 重置超时计数
if
s
.
timeoutCounterCache
!=
nil
{
if
s
.
timeoutCounterCache
!=
nil
{
if
err
:=
s
.
timeoutCounterCache
.
ResetTimeoutCount
(
ctx
,
account
.
ID
);
err
!=
nil
{
if
err
:=
s
.
timeoutCounterCache
.
ResetTimeoutCount
(
ctx
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"[S
tream
T
imeout
] R
eset
TimeoutC
ount
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"s
tream
_t
imeout
_r
eset
_c
ount
_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"error"
,
err
)
}
}
}
}
log
.
Printf
(
"[S
tream
T
imeout
] Account %d marked as
temp
unschedulable
until %v (model: %s)
"
,
account
.
ID
,
until
,
model
)
s
log
.
Info
(
"s
tream
_t
imeout
_
temp
_
unschedulable
"
,
"account_id
"
,
account
.
ID
,
"
until
"
,
until
,
"model"
,
model
)
return
true
return
true
}
}
...
@@ -694,17 +759,17 @@ func (s *RateLimitService) triggerStreamTimeoutError(ctx context.Context, accoun
...
@@ -694,17 +759,17 @@ func (s *RateLimitService) triggerStreamTimeoutError(ctx context.Context, accoun
errorMsg
:=
"Stream data interval timeout (repeated failures) for model: "
+
model
errorMsg
:=
"Stream data interval timeout (repeated failures) for model: "
+
model
if
err
:=
s
.
accountRepo
.
SetError
(
ctx
,
account
.
ID
,
errorMsg
);
err
!=
nil
{
if
err
:=
s
.
accountRepo
.
SetError
(
ctx
,
account
.
ID
,
errorMsg
);
err
!=
nil
{
log
.
Printf
(
"[S
tream
T
imeout
] SetE
rror
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"s
tream
_t
imeout
_set_e
rror
_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"error"
,
err
)
return
false
return
false
}
}
// 重置超时计数
// 重置超时计数
if
s
.
timeoutCounterCache
!=
nil
{
if
s
.
timeoutCounterCache
!=
nil
{
if
err
:=
s
.
timeoutCounterCache
.
ResetTimeoutCount
(
ctx
,
account
.
ID
);
err
!=
nil
{
if
err
:=
s
.
timeoutCounterCache
.
ResetTimeoutCount
(
ctx
,
account
.
ID
);
err
!=
nil
{
log
.
Printf
(
"[S
tream
T
imeout
] R
eset
TimeoutC
ount
failed
for
account
%d: %v
"
,
account
.
ID
,
err
)
s
log
.
Warn
(
"s
tream
_t
imeout
_r
eset
_c
ount
_
failed
"
,
"
account
_id
"
,
account
.
ID
,
"error"
,
err
)
}
}
}
}
log
.
Printf
(
"[S
tream
T
imeout
] A
ccount
%d marked as error (model: %s)
"
,
account
.
ID
,
model
)
s
log
.
Warn
(
"s
tream
_t
imeout
_a
ccount
_error"
,
"account_id
"
,
account
.
ID
,
"model"
,
model
)
return
true
return
true
}
}
backend/internal/service/ratelimit_service_401_test.go
0 → 100644
View file @
6901b64f
//go:build unit
package
service
import
(
"context"
"errors"
"net/http"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// rateLimitAccountRepoStub records the SetError / SetTempUnschedulable calls
// made by RateLimitService under test. It embeds mockAccountRepoForGemini to
// satisfy the remaining AccountRepository methods.
type rateLimitAccountRepoStub struct {
	mockAccountRepoForGemini
	setErrorCalls int    // number of SetError invocations
	tempCalls     int    // number of SetTempUnschedulable invocations
	lastErrorMsg  string // errorMsg passed to the most recent SetError call
}

// SetError counts the invocation and captures the message; always succeeds.
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
	r.setErrorCalls++
	r.lastErrorMsg = errorMsg
	return nil
}

// SetTempUnschedulable counts the invocation; always succeeds.
func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
	r.tempCalls++
	return nil
}
// tokenCacheInvalidatorRecorder is a TokenCacheInvalidator test double that
// records every account passed to InvalidateToken and returns a configurable
// error.
type tokenCacheInvalidatorRecorder struct {
	accounts []*Account // accounts received, in call order
	err      error      // error returned by every InvalidateToken call
}

// InvalidateToken appends the account to the log and returns the stub error.
func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error {
	r.accounts = append(r.accounts, account)
	return r.err
}
// TestRateLimitService_HandleUpstreamError_OAuth401MarksError verifies that a
// 401 response for an OAuth account on Gemini/Antigravity marks the account as
// errored (SetError) rather than temporarily unschedulable — even when a
// matching custom temp-unschedulable rule for code 401 is configured — and
// that the token cache is invalidated exactly once.
func TestRateLimitService_HandleUpstreamError_OAuth401MarksError(t *testing.T) {
	tests := []struct {
		name     string
		platform string
	}{
		{name: "gemini", platform: PlatformGemini},
		{name: "antigravity", platform: PlatformAntigravity},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			repo := &rateLimitAccountRepoStub{}
			invalidator := &tokenCacheInvalidatorRecorder{}
			service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
			service.SetTokenCacheInvalidator(invalidator)
			// The custom rule for 401 should be bypassed for OAuth accounts:
			// authentication failures must disable the account outright.
			account := &Account{
				ID:       100,
				Platform: tt.platform,
				Type:     AccountTypeOAuth,
				Credentials: map[string]any{
					"temp_unschedulable_enabled": true,
					"temp_unschedulable_rules": []any{
						map[string]any{
							"error_code":       401,
							"keywords":         []any{"unauthorized"},
							"duration_minutes": 30,
							"description":      "custom rule",
						},
					},
				},
			}
			shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
			require.True(t, shouldDisable)
			require.Equal(t, 1, repo.setErrorCalls)
			require.Equal(t, 0, repo.tempCalls)
			require.Contains(t, repo.lastErrorMsg, "Authentication failed (401)")
			require.Len(t, invalidator.accounts, 1)
		})
	}
}
// TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError verifies
// that a failing token-cache invalidator does not prevent the account from
// being marked errored: SetError still runs and the invalidator is still
// called once.
func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) {
	repo := &rateLimitAccountRepoStub{}
	invalidator := &tokenCacheInvalidatorRecorder{err: errors.New("boom")}
	service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
	service.SetTokenCacheInvalidator(invalidator)
	account := &Account{
		ID:       101,
		Platform: PlatformGemini,
		Type:     AccountTypeOAuth,
	}
	shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
	require.True(t, shouldDisable)
	require.Equal(t, 1, repo.setErrorCalls)
	require.Len(t, invalidator.accounts, 1)
}
// TestRateLimitService_HandleUpstreamError_NonOAuth401 verifies that a 401 on
// a non-OAuth (API-key) account still disables the account but does NOT touch
// the token cache — there is no cached OAuth token to invalidate.
func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
	repo := &rateLimitAccountRepoStub{}
	invalidator := &tokenCacheInvalidatorRecorder{}
	service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
	service.SetTokenCacheInvalidator(invalidator)
	account := &Account{
		ID:       102,
		Platform: PlatformOpenAI,
		Type:     AccountTypeAPIKey,
	}
	shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
	require.True(t, shouldDisable)
	require.Equal(t, 1, repo.setErrorCalls)
	require.Empty(t, invalidator.accounts)
}
backend/internal/service/session_limit_cache.go
0 → 100644
View file @
6901b64f
package
service
import
(
"context"
"time"
)
// SessionLimitCache tracks active sessions per account.
// It enforces the concurrent-session limit for Anthropic OAuth/SetupToken accounts.
//
// Key format: session_limit:account:{accountID}
// Data structure: sorted set (member=sessionUUID, score=timestamp)
//
// Sessions expire automatically after the idle timeout; no manual cleanup is needed.
type SessionLimitCache interface {
	// RegisterSession records session activity.
	//   - If the session already exists, its timestamp is refreshed and true is returned.
	//   - If the session is new and active sessions < maxSessions, it is added and true is returned.
	//   - If the session is new and active sessions >= maxSessions, false is returned (rejected).
	//
	// Parameters:
	//   accountID:   account ID
	//   sessionUUID: session UUID extracted from metadata.user_id
	//   maxSessions: maximum number of concurrent sessions
	//   idleTimeout: session idle expiry duration
	//
	// Returns:
	//   allowed: true when permitted (within the limit, or the session already exists);
	//            false when rejected (over the limit and the session is new)
	//   err:     operation error
	RegisterSession(ctx context.Context, accountID int64, sessionUUID string, maxSessions int, idleTimeout time.Duration) (allowed bool, err error)

	// RefreshSession refreshes an existing session's timestamp,
	// keeping an active session alive.
	RefreshSession(ctx context.Context, accountID int64, sessionUUID string, idleTimeout time.Duration) error

	// GetActiveSessionCount returns the number of non-expired sessions for the account.
	GetActiveSessionCount(ctx context.Context, accountID int64) (int, error)

	// GetActiveSessionCountBatch returns active-session counts for multiple accounts.
	// The result maps accountID to count; accounts whose lookup failed are absent.
	GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)

	// IsSessionActive reports whether the given session is still active (not expired).
	IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error)

	// ========== 5h-window cost cache ==========
	// Key format: window_cost:account:{accountID}
	// Caches an account's standard cost within the current 5h window to reduce
	// database aggregation-query pressure.

	// GetWindowCost returns the cached window cost.
	// Returns (cost, true, nil) on a cache hit,
	// (0, false, nil) on a cache miss,
	// and (0, false, err) on error.
	GetWindowCost(ctx context.Context, accountID int64) (cost float64, hit bool, err error)

	// SetWindowCost stores the window cost in the cache.
	SetWindowCost(ctx context.Context, accountID int64, cost float64) error

	// GetWindowCostBatch returns cached window costs for multiple accounts.
	// The result maps accountID to cost; cache misses are absent from the map.
	GetWindowCostBatch(ctx context.Context, accountIDs []int64) (map[int64]float64, error)
}
backend/internal/service/timing_wheel_service.go
View file @
6901b64f
package
service
package
service
import
(
import
(
"fmt"
"log"
"log"
"sync"
"sync"
"time"
"time"
...
@@ -8,6 +9,8 @@ import (
...
@@ -8,6 +9,8 @@ import (
"github.com/zeromicro/go-zero/core/collection"
"github.com/zeromicro/go-zero/core/collection"
)
)
var
newTimingWheel
=
collection
.
NewTimingWheel
// TimingWheelService wraps go-zero's TimingWheel for task scheduling
// TimingWheelService wraps go-zero's TimingWheel for task scheduling
type
TimingWheelService
struct
{
type
TimingWheelService
struct
{
tw
*
collection
.
TimingWheel
tw
*
collection
.
TimingWheel
...
@@ -15,18 +18,18 @@ type TimingWheelService struct {
...
@@ -15,18 +18,18 @@ type TimingWheelService struct {
}
}
// NewTimingWheelService creates a new TimingWheelService instance
// NewTimingWheelService creates a new TimingWheelService instance
func
NewTimingWheelService
()
*
TimingWheelService
{
func
NewTimingWheelService
()
(
*
TimingWheelService
,
error
)
{
// 1 second tick, 3600 slots = supports up to 1 hour delay
// 1 second tick, 3600 slots = supports up to 1 hour delay
// execute function: runs func() type tasks
// execute function: runs func() type tasks
tw
,
err
:=
collection
.
N
ewTimingWheel
(
1
*
time
.
Second
,
3600
,
func
(
key
,
value
any
)
{
tw
,
err
:=
n
ewTimingWheel
(
1
*
time
.
Second
,
3600
,
func
(
key
,
value
any
)
{
if
fn
,
ok
:=
value
.
(
func
());
ok
{
if
fn
,
ok
:=
value
.
(
func
());
ok
{
fn
()
fn
()
}
}
})
})
if
err
!=
nil
{
if
err
!=
nil
{
panic
(
err
)
return
nil
,
fmt
.
Errorf
(
"创建 timing wheel 失败: %w"
,
err
)
}
}
return
&
TimingWheelService
{
tw
:
tw
}
return
&
TimingWheelService
{
tw
:
tw
}
,
nil
}
}
// Start starts the timing wheel
// Start starts the timing wheel
...
...
backend/internal/service/timing_wheel_service_test.go
0 → 100644
View file @
6901b64f
package
service
import
(
"errors"
"sync/atomic"
"testing"
"time"
"github.com/zeromicro/go-zero/core/collection"
)
// TestNewTimingWheelService_InitFail_NoPanicAndReturnError verifies that when
// the underlying timing-wheel constructor fails, NewTimingWheelService returns
// (nil, error) instead of panicking. The package-level newTimingWheel hook is
// swapped out and restored via t.Cleanup.
func TestNewTimingWheelService_InitFail_NoPanicAndReturnError(t *testing.T) {
	original := newTimingWheel
	t.Cleanup(func() { newTimingWheel = original })
	newTimingWheel = func(_ time.Duration, _ int, _ collection.Execute) (*collection.TimingWheel, error) {
		return nil, errors.New("boom")
	}
	svc, err := NewTimingWheelService()
	if err == nil {
		t.Fatalf("期望返回 error,但得到 nil")
	}
	if svc != nil {
		t.Fatalf("期望返回 nil svc,但得到非空")
	}
}

// TestNewTimingWheelService_Success verifies the happy path: a non-nil service
// and no error from the real go-zero timing wheel.
func TestNewTimingWheelService_Success(t *testing.T) {
	svc, err := NewTimingWheelService()
	if err != nil {
		t.Fatalf("期望 err 为 nil,但得到: %v", err)
	}
	if svc == nil {
		t.Fatalf("期望 svc 非空,但得到 nil")
	}
	svc.Stop()
}

// TestNewTimingWheelService_ExecuteCallbackRunsFunc captures the execute
// callback passed to the timing wheel and verifies it invokes func()-typed
// task values.
func TestNewTimingWheelService_ExecuteCallbackRunsFunc(t *testing.T) {
	original := newTimingWheel
	t.Cleanup(func() { newTimingWheel = original })
	var captured collection.Execute
	newTimingWheel = func(interval time.Duration, numSlots int, execute collection.Execute) (*collection.TimingWheel, error) {
		captured = execute
		return original(interval, numSlots, execute)
	}
	svc, err := NewTimingWheelService()
	if err != nil {
		t.Fatalf("期望 err 为 nil,但得到: %v", err)
	}
	if captured == nil {
		t.Fatalf("期望 captured 非空,但得到 nil")
	}
	called := false
	captured("k", func() { called = true })
	if !called {
		t.Fatalf("期望 execute 回调触发传入函数执行")
	}
	svc.Stop()
}
// TestTimingWheelService_Schedule_ExecutesOnce schedules a one-shot task on a
// fast (10ms tick) wheel and verifies it fires exactly once.
func TestTimingWheelService_Schedule_ExecutesOnce(t *testing.T) {
	original := newTimingWheel
	t.Cleanup(func() { newTimingWheel = original })
	// Shrink the tick so the test completes quickly.
	newTimingWheel = func(_ time.Duration, _ int, execute collection.Execute) (*collection.TimingWheel, error) {
		return original(10*time.Millisecond, 128, execute)
	}
	svc, err := NewTimingWheelService()
	if err != nil {
		t.Fatalf("期望 err 为 nil,但得到: %v", err)
	}
	defer svc.Stop()
	ch := make(chan struct{}, 1)
	svc.Schedule("once", 30*time.Millisecond, func() { ch <- struct{}{} })
	select {
	case <-ch:
	case <-time.After(500 * time.Millisecond):
		t.Fatalf("等待任务执行超时")
	}
	// A one-shot task must not fire a second time.
	select {
	case <-ch:
		t.Fatalf("任务不应重复执行")
	case <-time.After(80 * time.Millisecond):
	}
}

// TestTimingWheelService_Cancel_PreventsExecution verifies that cancelling a
// scheduled task before its delay elapses prevents it from running.
func TestTimingWheelService_Cancel_PreventsExecution(t *testing.T) {
	original := newTimingWheel
	t.Cleanup(func() { newTimingWheel = original })
	newTimingWheel = func(_ time.Duration, _ int, execute collection.Execute) (*collection.TimingWheel, error) {
		return original(10*time.Millisecond, 128, execute)
	}
	svc, err := NewTimingWheelService()
	if err != nil {
		t.Fatalf("期望 err 为 nil,但得到: %v", err)
	}
	defer svc.Stop()
	ch := make(chan struct{}, 1)
	svc.Schedule("cancel", 80*time.Millisecond, func() { ch <- struct{}{} })
	svc.Cancel("cancel")
	select {
	case <-ch:
		t.Fatalf("任务已取消,不应执行")
	case <-time.After(200 * time.Millisecond):
	}
}

// TestTimingWheelService_ScheduleRecurring_ExecutesMultipleTimes verifies that
// a recurring task fires at least twice within the polling deadline.
func TestTimingWheelService_ScheduleRecurring_ExecutesMultipleTimes(t *testing.T) {
	original := newTimingWheel
	t.Cleanup(func() { newTimingWheel = original })
	newTimingWheel = func(_ time.Duration, _ int, execute collection.Execute) (*collection.TimingWheel, error) {
		return original(10*time.Millisecond, 128, execute)
	}
	svc, err := NewTimingWheelService()
	if err != nil {
		t.Fatalf("期望 err 为 nil,但得到: %v", err)
	}
	defer svc.Stop()
	var count int32
	svc.ScheduleRecurring("rec", 30*time.Millisecond, func() { atomic.AddInt32(&count, 1) })
	deadline := time.Now().Add(500 * time.Millisecond)
	for atomic.LoadInt32(&count) < 2 && time.Now().Before(deadline) {
		time.Sleep(10 * time.Millisecond)
	}
	if atomic.LoadInt32(&count) < 2 {
		t.Fatalf("期望周期任务至少执行 2 次,但只执行了 %d 次", atomic.LoadInt32(&count))
	}
}
backend/internal/service/token_cache_invalidator.go
0 → 100644
View file @
6901b64f
package
service
import
"context"
// TokenCacheInvalidator removes a cached access token for an account, forcing
// the next request to fetch a fresh token from the upstream provider.
type TokenCacheInvalidator interface {
	InvalidateToken(ctx context.Context, account *Account) error
}

// CompositeTokenCacheInvalidator invalidates cached OAuth access tokens for
// all supported platforms (Gemini, Antigravity, OpenAI, Anthropic).
type CompositeTokenCacheInvalidator struct {
	cache GeminiTokenCache // single shared cache interface; cache-key prefixes distinguish platforms
}
func
NewCompositeTokenCacheInvalidator
(
cache
GeminiTokenCache
)
*
CompositeTokenCacheInvalidator
{
return
&
CompositeTokenCacheInvalidator
{
cache
:
cache
,
}
}
func
(
c
*
CompositeTokenCacheInvalidator
)
InvalidateToken
(
ctx
context
.
Context
,
account
*
Account
)
error
{
if
c
==
nil
||
c
.
cache
==
nil
||
account
==
nil
{
return
nil
}
if
account
.
Type
!=
AccountTypeOAuth
{
return
nil
}
var
cacheKey
string
switch
account
.
Platform
{
case
PlatformGemini
:
cacheKey
=
GeminiTokenCacheKey
(
account
)
case
PlatformAntigravity
:
cacheKey
=
AntigravityTokenCacheKey
(
account
)
case
PlatformOpenAI
:
cacheKey
=
OpenAITokenCacheKey
(
account
)
case
PlatformAnthropic
:
cacheKey
=
ClaudeTokenCacheKey
(
account
)
default
:
return
nil
}
return
c
.
cache
.
DeleteAccessToken
(
ctx
,
cacheKey
)
}
backend/internal/service/token_cache_invalidator_test.go
0 → 100644
View file @
6901b64f
//go:build unit
package
service
import
(
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// geminiTokenCacheStub is a GeminiTokenCache test double that records the keys
// passed to DeleteAccessToken and can be configured to fail deletion.
type geminiTokenCacheStub struct {
	deletedKeys []string // keys passed to DeleteAccessToken, in call order
	deleteErr   error    // error returned by every DeleteAccessToken call
}

// GetAccessToken always reports an empty token with no error.
func (s *geminiTokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
	return "", nil
}

// SetAccessToken is a no-op.
func (s *geminiTokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
	return nil
}

// DeleteAccessToken records the key and returns the configured error.
func (s *geminiTokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error {
	s.deletedKeys = append(s.deletedKeys, cacheKey)
	return s.deleteErr
}

// AcquireRefreshLock always succeeds.
func (s *geminiTokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
	return true, nil
}

// ReleaseRefreshLock is a no-op.
func (s *geminiTokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
	return nil
}
// TestCompositeTokenCacheInvalidator_Gemini verifies the Gemini cache key
// ("gemini:{project_id}") is deleted for a Gemini OAuth account.
func TestCompositeTokenCacheInvalidator_Gemini(t *testing.T) {
	cache := &geminiTokenCacheStub{}
	invalidator := NewCompositeTokenCacheInvalidator(cache)
	account := &Account{
		ID:       10,
		Platform: PlatformGemini,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"project_id": "project-x",
		},
	}
	err := invalidator.InvalidateToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, []string{"gemini:project-x"}, cache.deletedKeys)
}

// TestCompositeTokenCacheInvalidator_Antigravity verifies the Antigravity
// cache key ("ag:{project_id}") is deleted.
func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) {
	cache := &geminiTokenCacheStub{}
	invalidator := NewCompositeTokenCacheInvalidator(cache)
	account := &Account{
		ID:       99,
		Platform: PlatformAntigravity,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"project_id": "ag-project",
		},
	}
	err := invalidator.InvalidateToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, []string{"ag:ag-project"}, cache.deletedKeys)
}

// TestCompositeTokenCacheInvalidator_OpenAI verifies the OpenAI cache key
// ("openai:account:{id}") is deleted.
func TestCompositeTokenCacheInvalidator_OpenAI(t *testing.T) {
	cache := &geminiTokenCacheStub{}
	invalidator := NewCompositeTokenCacheInvalidator(cache)
	account := &Account{
		ID:       500,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "openai-token",
		},
	}
	err := invalidator.InvalidateToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, []string{"openai:account:500"}, cache.deletedKeys)
}

// TestCompositeTokenCacheInvalidator_Claude verifies the Anthropic cache key
// ("claude:account:{id}") is deleted.
func TestCompositeTokenCacheInvalidator_Claude(t *testing.T) {
	cache := &geminiTokenCacheStub{}
	invalidator := NewCompositeTokenCacheInvalidator(cache)
	account := &Account{
		ID:       600,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "claude-token",
		},
	}
	err := invalidator.InvalidateToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, []string{"claude:account:600"}, cache.deletedKeys)
}
// TestCompositeTokenCacheInvalidator_SkipNonOAuth verifies that non-OAuth
// accounts (API-key and setup-token types) never trigger a cache deletion.
func TestCompositeTokenCacheInvalidator_SkipNonOAuth(t *testing.T) {
	cache := &geminiTokenCacheStub{}
	invalidator := NewCompositeTokenCacheInvalidator(cache)
	tests := []struct {
		name    string
		account *Account
	}{
		{
			name:    "gemini_api_key",
			account: &Account{ID: 1, Platform: PlatformGemini, Type: AccountTypeAPIKey},
		},
		{
			name:    "openai_api_key",
			account: &Account{ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey},
		},
		{
			name:    "claude_api_key",
			account: &Account{ID: 3, Platform: PlatformAnthropic, Type: AccountTypeAPIKey},
		},
		{
			name:    "claude_setup_token",
			account: &Account{ID: 4, Platform: PlatformAnthropic, Type: AccountTypeSetupToken},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cache.deletedKeys = nil // reset shared stub between sub-tests
			err := invalidator.InvalidateToken(context.Background(), tt.account)
			require.NoError(t, err)
			require.Empty(t, cache.deletedKeys)
		})
	}
}

// TestCompositeTokenCacheInvalidator_SkipUnsupportedPlatform verifies that an
// unknown platform is silently ignored.
func TestCompositeTokenCacheInvalidator_SkipUnsupportedPlatform(t *testing.T) {
	cache := &geminiTokenCacheStub{}
	invalidator := NewCompositeTokenCacheInvalidator(cache)
	account := &Account{
		ID:       100,
		Platform: "unknown-platform",
		Type:     AccountTypeOAuth,
	}
	err := invalidator.InvalidateToken(context.Background(), account)
	require.NoError(t, err)
	require.Empty(t, cache.deletedKeys)
}

// TestCompositeTokenCacheInvalidator_NilCache verifies a nil cache is a no-op.
func TestCompositeTokenCacheInvalidator_NilCache(t *testing.T) {
	invalidator := NewCompositeTokenCacheInvalidator(nil)
	account := &Account{
		ID:       2,
		Platform: PlatformGemini,
		Type:     AccountTypeOAuth,
	}
	err := invalidator.InvalidateToken(context.Background(), account)
	require.NoError(t, err)
}

// TestCompositeTokenCacheInvalidator_NilAccount verifies a nil account is a no-op.
func TestCompositeTokenCacheInvalidator_NilAccount(t *testing.T) {
	cache := &geminiTokenCacheStub{}
	invalidator := NewCompositeTokenCacheInvalidator(cache)
	err := invalidator.InvalidateToken(context.Background(), nil)
	require.NoError(t, err)
	require.Empty(t, cache.deletedKeys)
}

// TestCompositeTokenCacheInvalidator_NilInvalidator verifies calling
// InvalidateToken on a nil receiver is safe.
func TestCompositeTokenCacheInvalidator_NilInvalidator(t *testing.T) {
	var invalidator *CompositeTokenCacheInvalidator
	account := &Account{
		ID:       5,
		Platform: PlatformGemini,
		Type:     AccountTypeOAuth,
	}
	err := invalidator.InvalidateToken(context.Background(), account)
	require.NoError(t, err)
}
// TestCompositeTokenCacheInvalidator_DeleteError verifies that a cache
// deletion failure is propagated unchanged to the caller.
func TestCompositeTokenCacheInvalidator_DeleteError(t *testing.T) {
	expectedErr := errors.New("redis connection failed")
	cache := &geminiTokenCacheStub{deleteErr: expectedErr}
	invalidator := NewCompositeTokenCacheInvalidator(cache)
	tests := []struct {
		name    string
		account *Account
	}{
		{
			name:    "openai_delete_error",
			account: &Account{ID: 700, Platform: PlatformOpenAI, Type: AccountTypeOAuth},
		},
		{
			name:    "claude_delete_error",
			account: &Account{ID: 800, Platform: PlatformAnthropic, Type: AccountTypeOAuth},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := invalidator.InvalidateToken(context.Background(), tt.account)
			require.Error(t, err)
			require.Equal(t, expectedErr, err)
		})
	}
}

// TestCompositeTokenCacheInvalidator_AllPlatformsIntegration exercises
// cache-key generation and deletion across every supported platform.
func TestCompositeTokenCacheInvalidator_AllPlatformsIntegration(t *testing.T) {
	// Verify key generation and deletion for all platforms in one pass.
	cache := &geminiTokenCacheStub{}
	invalidator := NewCompositeTokenCacheInvalidator(cache)
	accounts := []*Account{
		{ID: 1, Platform: PlatformGemini, Type: AccountTypeOAuth, Credentials: map[string]any{"project_id": "gemini-proj"}},
		{ID: 2, Platform: PlatformAntigravity, Type: AccountTypeOAuth, Credentials: map[string]any{"project_id": "ag-proj"}},
		{ID: 3, Platform: PlatformOpenAI, Type: AccountTypeOAuth},
		{ID: 4, Platform: PlatformAnthropic, Type: AccountTypeOAuth},
	}
	expectedKeys := []string{
		"gemini:gemini-proj",
		"ag:ag-proj",
		"openai:account:3",
		"claude:account:4",
	}
	for _, acc := range accounts {
		err := invalidator.InvalidateToken(context.Background(), acc)
		require.NoError(t, err)
	}
	require.Equal(t, expectedKeys, cache.deletedKeys)
}
backend/internal/service/token_cache_key.go
0 → 100644
View file @
6901b64f
package
service
import
"strconv"
// OpenAITokenCacheKey 生成 OpenAI OAuth 账号的缓存键
// 格式: "openai:account:{account_id}"
func
OpenAITokenCacheKey
(
account
*
Account
)
string
{
return
"openai:account:"
+
strconv
.
FormatInt
(
account
.
ID
,
10
)
}
// ClaudeTokenCacheKey 生成 Claude (Anthropic) OAuth 账号的缓存键
// 格式: "claude:account:{account_id}"
func
ClaudeTokenCacheKey
(
account
*
Account
)
string
{
return
"claude:account:"
+
strconv
.
FormatInt
(
account
.
ID
,
10
)
}
backend/internal/service/token_cache_key_test.go
0 → 100644
View file @
6901b64f
//go:build unit
package
service
import
(
"testing"
"github.com/stretchr/testify/require"
)
// TestGeminiTokenCacheKey verifies the Gemini cache key uses the trimmed
// project_id when present and falls back to "gemini:account:{id}" otherwise.
func TestGeminiTokenCacheKey(t *testing.T) {
	tests := []struct {
		name     string
		account  *Account
		expected string
	}{
		{
			name: "with_project_id",
			account: &Account{
				ID: 100,
				Credentials: map[string]any{
					"project_id": "my-project-123",
				},
			},
			expected: "gemini:my-project-123",
		},
		{
			name: "project_id_with_whitespace",
			account: &Account{
				ID: 101,
				Credentials: map[string]any{
					"project_id": " project-with-spaces ",
				},
			},
			expected: "gemini:project-with-spaces",
		},
		{
			name: "empty_project_id_fallback_to_account_id",
			account: &Account{
				ID: 102,
				Credentials: map[string]any{
					"project_id": "",
				},
			},
			expected: "gemini:account:102",
		},
		{
			name: "whitespace_only_project_id_fallback_to_account_id",
			account: &Account{
				ID: 103,
				Credentials: map[string]any{
					"project_id": " ",
				},
			},
			expected: "gemini:account:103",
		},
		{
			name: "no_project_id_key_fallback_to_account_id",
			account: &Account{
				ID:          104,
				Credentials: map[string]any{},
			},
			expected: "gemini:account:104",
		},
		{
			name: "nil_credentials_fallback_to_account_id",
			account: &Account{
				ID:          105,
				Credentials: nil,
			},
			expected: "gemini:account:105",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := GeminiTokenCacheKey(tt.account)
			require.Equal(t, tt.expected, result)
		})
	}
}

// TestAntigravityTokenCacheKey verifies the Antigravity cache key uses the
// trimmed project_id when present and falls back to "ag:account:{id}".
func TestAntigravityTokenCacheKey(t *testing.T) {
	tests := []struct {
		name     string
		account  *Account
		expected string
	}{
		{
			name: "with_project_id",
			account: &Account{
				ID: 200,
				Credentials: map[string]any{
					"project_id": "ag-project-456",
				},
			},
			expected: "ag:ag-project-456",
		},
		{
			name: "project_id_with_whitespace",
			account: &Account{
				ID: 201,
				Credentials: map[string]any{
					"project_id": " ag-project-spaces ",
				},
			},
			expected: "ag:ag-project-spaces",
		},
		{
			name: "empty_project_id_fallback_to_account_id",
			account: &Account{
				ID: 202,
				Credentials: map[string]any{
					"project_id": "",
				},
			},
			expected: "ag:account:202",
		},
		{
			name: "whitespace_only_project_id_fallback_to_account_id",
			account: &Account{
				ID: 203,
				Credentials: map[string]any{
					"project_id": " ",
				},
			},
			expected: "ag:account:203",
		},
		{
			name: "no_project_id_key_fallback_to_account_id",
			account: &Account{
				ID:          204,
				Credentials: map[string]any{},
			},
			expected: "ag:account:204",
		},
		{
			name: "nil_credentials_fallback_to_account_id",
			account: &Account{
				ID:          205,
				Credentials: nil,
			},
			expected: "ag:account:205",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := AntigravityTokenCacheKey(tt.account)
			require.Equal(t, tt.expected, result)
		})
	}
}
// TestOpenAITokenCacheKey verifies the OpenAI key is always
// "openai:account:{id}", independent of credentials.
func TestOpenAITokenCacheKey(t *testing.T) {
	tests := []struct {
		name     string
		account  *Account
		expected string
	}{
		{
			name: "basic_account",
			account: &Account{
				ID: 300,
			},
			expected: "openai:account:300",
		},
		{
			name: "account_with_credentials",
			account: &Account{
				ID: 301,
				Credentials: map[string]any{
					"access_token": "test-token",
				},
			},
			expected: "openai:account:301",
		},
		{
			name: "account_id_zero",
			account: &Account{
				ID: 0,
			},
			expected: "openai:account:0",
		},
		{
			name: "large_account_id",
			account: &Account{
				ID: 9999999999,
			},
			expected: "openai:account:9999999999",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := OpenAITokenCacheKey(tt.account)
			require.Equal(t, tt.expected, result)
		})
	}
}

// TestClaudeTokenCacheKey verifies the Claude key is always
// "claude:account:{id}", independent of credentials.
func TestClaudeTokenCacheKey(t *testing.T) {
	tests := []struct {
		name     string
		account  *Account
		expected string
	}{
		{
			name: "basic_account",
			account: &Account{
				ID: 400,
			},
			expected: "claude:account:400",
		},
		{
			name: "account_with_credentials",
			account: &Account{
				ID: 401,
				Credentials: map[string]any{
					"access_token": "claude-token",
				},
			},
			expected: "claude:account:401",
		},
		{
			name: "account_id_zero",
			account: &Account{
				ID: 0,
			},
			expected: "claude:account:0",
		},
		{
			name: "large_account_id",
			account: &Account{
				ID: 9999999999,
			},
			expected: "claude:account:9999999999",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := ClaudeTokenCacheKey(tt.account)
			require.Equal(t, tt.expected, result)
		})
	}
}
// TestCacheKeyUniqueness verifies that cache keys from different platforms
// cannot collide for the same account ID.
func TestCacheKeyUniqueness(t *testing.T) {
	// Ensure cache keys for different platforms do not conflict.
	account := &Account{ID: 123}
	openaiKey := OpenAITokenCacheKey(account)
	claudeKey := ClaudeTokenCacheKey(account)
	require.NotEqual(t, openaiKey, claudeKey, "OpenAI and Claude cache keys should be different")
	require.Contains(t, openaiKey, "openai:")
	require.Contains(t, claudeKey, "claude:")
}
backend/internal/service/token_refresh_service.go
View file @
6901b64f
...
@@ -14,9 +14,10 @@ import (
...
@@ -14,9 +14,10 @@ import (
// TokenRefreshService OAuth token自动刷新服务
// TokenRefreshService OAuth token自动刷新服务
// 定期检查并刷新即将过期的token
// 定期检查并刷新即将过期的token
type
TokenRefreshService
struct
{
type
TokenRefreshService
struct
{
accountRepo
AccountRepository
accountRepo
AccountRepository
refreshers
[]
TokenRefresher
refreshers
[]
TokenRefresher
cfg
*
config
.
TokenRefreshConfig
cfg
*
config
.
TokenRefreshConfig
cacheInvalidator
TokenCacheInvalidator
stopCh
chan
struct
{}
stopCh
chan
struct
{}
wg
sync
.
WaitGroup
wg
sync
.
WaitGroup
...
@@ -29,12 +30,14 @@ func NewTokenRefreshService(
...
@@ -29,12 +30,14 @@ func NewTokenRefreshService(
openaiOAuthService
*
OpenAIOAuthService
,
openaiOAuthService
*
OpenAIOAuthService
,
geminiOAuthService
*
GeminiOAuthService
,
geminiOAuthService
*
GeminiOAuthService
,
antigravityOAuthService
*
AntigravityOAuthService
,
antigravityOAuthService
*
AntigravityOAuthService
,
cacheInvalidator
TokenCacheInvalidator
,
cfg
*
config
.
Config
,
cfg
*
config
.
Config
,
)
*
TokenRefreshService
{
)
*
TokenRefreshService
{
s
:=
&
TokenRefreshService
{
s
:=
&
TokenRefreshService
{
accountRepo
:
accountRepo
,
accountRepo
:
accountRepo
,
cfg
:
&
cfg
.
TokenRefresh
,
cfg
:
&
cfg
.
TokenRefresh
,
stopCh
:
make
(
chan
struct
{}),
cacheInvalidator
:
cacheInvalidator
,
stopCh
:
make
(
chan
struct
{}),
}
}
// 注册平台特定的刷新器
// 注册平台特定的刷新器
...
@@ -169,6 +172,14 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
...
@@ -169,6 +172,14 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
if
err
:=
s
.
accountRepo
.
Update
(
ctx
,
account
);
err
!=
nil
{
if
err
:=
s
.
accountRepo
.
Update
(
ctx
,
account
);
err
!=
nil
{
return
fmt
.
Errorf
(
"failed to save credentials: %w"
,
err
)
return
fmt
.
Errorf
(
"failed to save credentials: %w"
,
err
)
}
}
// 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理)
if
s
.
cacheInvalidator
!=
nil
&&
account
.
Type
==
AccountTypeOAuth
{
if
err
:=
s
.
cacheInvalidator
.
InvalidateToken
(
ctx
,
account
);
err
!=
nil
{
log
.
Printf
(
"[TokenRefresh] Failed to invalidate token cache for account %d: %v"
,
account
.
ID
,
err
)
}
else
{
log
.
Printf
(
"[TokenRefresh] Token cache invalidated for account %d"
,
account
.
ID
)
}
}
return
nil
return
nil
}
}
...
...
Prev
1
2
3
4
5
6
7
8
9
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment