Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
7b156489
Commit
7b156489
authored
Feb 07, 2026
by
shaw
Browse files
fix: make error passthrough effective for non-failover upstream errors
parent
76d242e0
Changes
10
Hide whitespace changes
Inline
Side-by-side
backend/internal/handler/gateway_handler.go
View file @
7b156489
...
@@ -135,6 +135,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
...
@@ -135,6 +135,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// Track if we've started streaming (for error handling)
// Track if we've started streaming (for error handling)
streamStarted
:=
false
streamStarted
:=
false
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
if
h
.
errorPassthroughService
!=
nil
{
service
.
BindErrorPassthroughService
(
c
,
h
.
errorPassthroughService
)
}
// 获取订阅信息(可能为nil)- 提前获取用于后续检查
// 获取订阅信息(可能为nil)- 提前获取用于后续检查
subscription
,
_
:=
middleware2
.
GetSubscriptionFromContext
(
c
)
subscription
,
_
:=
middleware2
.
GetSubscriptionFromContext
(
c
)
...
...
backend/internal/handler/gemini_v1beta_handler.go
View file @
7b156489
...
@@ -207,6 +207,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
...
@@ -207,6 +207,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 1) user concurrency slot
// 1) user concurrency slot
streamStarted
:=
false
streamStarted
:=
false
if
h
.
errorPassthroughService
!=
nil
{
service
.
BindErrorPassthroughService
(
c
,
h
.
errorPassthroughService
)
}
userReleaseFunc
,
err
:=
geminiConcurrency
.
AcquireUserSlotWithWait
(
c
,
authSubject
.
UserID
,
authSubject
.
Concurrency
,
stream
,
&
streamStarted
)
userReleaseFunc
,
err
:=
geminiConcurrency
.
AcquireUserSlotWithWait
(
c
,
authSubject
.
UserID
,
authSubject
.
Concurrency
,
stream
,
&
streamStarted
)
if
err
!=
nil
{
if
err
!=
nil
{
googleError
(
c
,
http
.
StatusTooManyRequests
,
err
.
Error
())
googleError
(
c
,
http
.
StatusTooManyRequests
,
err
.
Error
())
...
...
backend/internal/handler/openai_gateway_handler.go
View file @
7b156489
...
@@ -149,6 +149,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
...
@@ -149,6 +149,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Track if we've started streaming (for error handling)
// Track if we've started streaming (for error handling)
streamStarted
:=
false
streamStarted
:=
false
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
if
h
.
errorPassthroughService
!=
nil
{
service
.
BindErrorPassthroughService
(
c
,
h
.
errorPassthroughService
)
}
// Get subscription info (may be nil)
// Get subscription info (may be nil)
subscription
,
_
:=
middleware2
.
GetSubscriptionFromContext
(
c
)
subscription
,
_
:=
middleware2
.
GetSubscriptionFromContext
(
c
)
...
...
backend/internal/service/error_passthrough_runtime.go
0 → 100644
View file @
7b156489
package
service
import
"github.com/gin-gonic/gin"
const
errorPassthroughServiceContextKey
=
"error_passthrough_service"
// BindErrorPassthroughService 将错误透传服务绑定到请求上下文,供 service 层在非 failover 场景下复用规则。
func
BindErrorPassthroughService
(
c
*
gin
.
Context
,
svc
*
ErrorPassthroughService
)
{
if
c
==
nil
||
svc
==
nil
{
return
}
c
.
Set
(
errorPassthroughServiceContextKey
,
svc
)
}
func
getBoundErrorPassthroughService
(
c
*
gin
.
Context
)
*
ErrorPassthroughService
{
if
c
==
nil
{
return
nil
}
v
,
ok
:=
c
.
Get
(
errorPassthroughServiceContextKey
)
if
!
ok
{
return
nil
}
svc
,
ok
:=
v
.
(
*
ErrorPassthroughService
)
if
!
ok
{
return
nil
}
return
svc
}
// applyErrorPassthroughRule 按规则改写错误响应;未命中时返回默认响应参数。
func
applyErrorPassthroughRule
(
c
*
gin
.
Context
,
platform
string
,
upstreamStatus
int
,
responseBody
[]
byte
,
defaultStatus
int
,
defaultErrType
string
,
defaultErrMsg
string
,
)
(
status
int
,
errType
string
,
errMsg
string
,
matched
bool
)
{
status
=
defaultStatus
errType
=
defaultErrType
errMsg
=
defaultErrMsg
svc
:=
getBoundErrorPassthroughService
(
c
)
if
svc
==
nil
{
return
status
,
errType
,
errMsg
,
false
}
rule
:=
svc
.
MatchRule
(
platform
,
upstreamStatus
,
responseBody
)
if
rule
==
nil
{
return
status
,
errType
,
errMsg
,
false
}
status
=
upstreamStatus
if
!
rule
.
PassthroughCode
&&
rule
.
ResponseCode
!=
nil
{
status
=
*
rule
.
ResponseCode
}
errMsg
=
ExtractUpstreamErrorMessage
(
responseBody
)
if
!
rule
.
PassthroughBody
&&
rule
.
CustomMessage
!=
nil
{
errMsg
=
*
rule
.
CustomMessage
}
// 与现有 failover 场景保持一致:命中规则时统一返回 upstream_error。
errType
=
"upstream_error"
return
status
,
errType
,
errMsg
,
true
}
backend/internal/service/error_passthrough_runtime_test.go
0 → 100644
View file @
7b156489
package
service
import
(
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func
TestApplyErrorPassthroughRule_NoBoundService
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
status
,
errType
,
errMsg
,
matched
:=
applyErrorPassthroughRule
(
c
,
PlatformAnthropic
,
http
.
StatusUnprocessableEntity
,
[]
byte
(
`{"error":{"message":"invalid schema"}}`
),
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream request failed"
,
)
assert
.
False
(
t
,
matched
)
assert
.
Equal
(
t
,
http
.
StatusBadGateway
,
status
)
assert
.
Equal
(
t
,
"upstream_error"
,
errType
)
assert
.
Equal
(
t
,
"Upstream request failed"
,
errMsg
)
}
func
TestGatewayHandleErrorResponse_NoRuleKeepsDefault
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
svc
:=
&
GatewayService
{}
respBody
:=
[]
byte
(
`{"error":{"message":"Invalid schema for field messages"}}`
)
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusUnprocessableEntity
,
Body
:
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
)),
Header
:
http
.
Header
{},
}
account
:=
&
Account
{
ID
:
11
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
}
_
,
err
:=
svc
.
handleErrorResponse
(
context
.
Background
(),
resp
,
c
,
account
)
require
.
Error
(
t
,
err
)
assert
.
Equal
(
t
,
http
.
StatusBadGateway
,
rec
.
Code
)
var
payload
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
payload
))
errField
,
ok
:=
payload
[
"error"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
assert
.
Equal
(
t
,
"upstream_error"
,
errField
[
"type"
])
assert
.
Equal
(
t
,
"Upstream request failed"
,
errField
[
"message"
])
}
func
TestOpenAIHandleErrorResponse_NoRuleKeepsDefault
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
svc
:=
&
OpenAIGatewayService
{}
respBody
:=
[]
byte
(
`{"error":{"message":"Invalid schema for field messages"}}`
)
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusUnprocessableEntity
,
Body
:
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
)),
Header
:
http
.
Header
{},
}
account
:=
&
Account
{
ID
:
12
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
}
_
,
err
:=
svc
.
handleErrorResponse
(
context
.
Background
(),
resp
,
c
,
account
)
require
.
Error
(
t
,
err
)
assert
.
Equal
(
t
,
http
.
StatusBadGateway
,
rec
.
Code
)
var
payload
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
payload
))
errField
,
ok
:=
payload
[
"error"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
assert
.
Equal
(
t
,
"upstream_error"
,
errField
[
"type"
])
assert
.
Equal
(
t
,
"Upstream request failed"
,
errField
[
"message"
])
}
func
TestGeminiWriteGeminiMappedError_NoRuleKeepsDefault
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
svc
:=
&
GeminiMessagesCompatService
{}
respBody
:=
[]
byte
(
`{"error":{"code":422,"message":"Invalid schema for field messages","status":"INVALID_ARGUMENT"}}`
)
account
:=
&
Account
{
ID
:
13
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeAPIKey
}
err
:=
svc
.
writeGeminiMappedError
(
c
,
account
,
http
.
StatusUnprocessableEntity
,
"req-2"
,
respBody
)
require
.
Error
(
t
,
err
)
assert
.
Equal
(
t
,
http
.
StatusBadRequest
,
rec
.
Code
)
var
payload
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
payload
))
errField
,
ok
:=
payload
[
"error"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
assert
.
Equal
(
t
,
"invalid_request_error"
,
errField
[
"type"
])
assert
.
Equal
(
t
,
"Upstream request failed"
,
errField
[
"message"
])
}
func
TestGatewayHandleErrorResponse_AppliesRuleFor422
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
ruleSvc
:=
&
ErrorPassthroughService
{}
ruleSvc
.
setLocalCache
([]
*
model
.
ErrorPassthroughRule
{
newNonFailoverPassthroughRule
(
http
.
StatusUnprocessableEntity
,
"invalid schema"
,
http
.
StatusTeapot
,
"上游请求失败"
)})
BindErrorPassthroughService
(
c
,
ruleSvc
)
svc
:=
&
GatewayService
{}
respBody
:=
[]
byte
(
`{"error":{"message":"Invalid schema for field messages"}}`
)
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusUnprocessableEntity
,
Body
:
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
)),
Header
:
http
.
Header
{},
}
account
:=
&
Account
{
ID
:
1
,
Platform
:
PlatformAnthropic
,
Type
:
AccountTypeAPIKey
}
_
,
err
:=
svc
.
handleErrorResponse
(
context
.
Background
(),
resp
,
c
,
account
)
require
.
Error
(
t
,
err
)
assert
.
Equal
(
t
,
http
.
StatusTeapot
,
rec
.
Code
)
var
payload
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
payload
))
errField
,
ok
:=
payload
[
"error"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
assert
.
Equal
(
t
,
"upstream_error"
,
errField
[
"type"
])
assert
.
Equal
(
t
,
"上游请求失败"
,
errField
[
"message"
])
}
func
TestOpenAIHandleErrorResponse_AppliesRuleFor422
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
ruleSvc
:=
&
ErrorPassthroughService
{}
ruleSvc
.
setLocalCache
([]
*
model
.
ErrorPassthroughRule
{
newNonFailoverPassthroughRule
(
http
.
StatusUnprocessableEntity
,
"invalid schema"
,
http
.
StatusTeapot
,
"OpenAI上游失败"
)})
BindErrorPassthroughService
(
c
,
ruleSvc
)
svc
:=
&
OpenAIGatewayService
{}
respBody
:=
[]
byte
(
`{"error":{"message":"Invalid schema for field messages"}}`
)
resp
:=
&
http
.
Response
{
StatusCode
:
http
.
StatusUnprocessableEntity
,
Body
:
io
.
NopCloser
(
bytes
.
NewReader
(
respBody
)),
Header
:
http
.
Header
{},
}
account
:=
&
Account
{
ID
:
2
,
Platform
:
PlatformOpenAI
,
Type
:
AccountTypeAPIKey
}
_
,
err
:=
svc
.
handleErrorResponse
(
context
.
Background
(),
resp
,
c
,
account
)
require
.
Error
(
t
,
err
)
assert
.
Equal
(
t
,
http
.
StatusTeapot
,
rec
.
Code
)
var
payload
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
payload
))
errField
,
ok
:=
payload
[
"error"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
assert
.
Equal
(
t
,
"upstream_error"
,
errField
[
"type"
])
assert
.
Equal
(
t
,
"OpenAI上游失败"
,
errField
[
"message"
])
}
func
TestGeminiWriteGeminiMappedError_AppliesRuleFor422
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
ruleSvc
:=
&
ErrorPassthroughService
{}
ruleSvc
.
setLocalCache
([]
*
model
.
ErrorPassthroughRule
{
newNonFailoverPassthroughRule
(
http
.
StatusUnprocessableEntity
,
"invalid schema"
,
http
.
StatusTeapot
,
"Gemini上游失败"
)})
BindErrorPassthroughService
(
c
,
ruleSvc
)
svc
:=
&
GeminiMessagesCompatService
{}
respBody
:=
[]
byte
(
`{"error":{"code":422,"message":"Invalid schema for field messages","status":"INVALID_ARGUMENT"}}`
)
account
:=
&
Account
{
ID
:
3
,
Platform
:
PlatformGemini
,
Type
:
AccountTypeAPIKey
}
err
:=
svc
.
writeGeminiMappedError
(
c
,
account
,
http
.
StatusUnprocessableEntity
,
"req-1"
,
respBody
)
require
.
Error
(
t
,
err
)
assert
.
Equal
(
t
,
http
.
StatusTeapot
,
rec
.
Code
)
var
payload
map
[
string
]
any
require
.
NoError
(
t
,
json
.
Unmarshal
(
rec
.
Body
.
Bytes
(),
&
payload
))
errField
,
ok
:=
payload
[
"error"
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
assert
.
Equal
(
t
,
"upstream_error"
,
errField
[
"type"
])
assert
.
Equal
(
t
,
"Gemini上游失败"
,
errField
[
"message"
])
}
func
newNonFailoverPassthroughRule
(
statusCode
int
,
keyword
string
,
respCode
int
,
customMessage
string
)
*
model
.
ErrorPassthroughRule
{
return
&
model
.
ErrorPassthroughRule
{
ID
:
1
,
Name
:
"non-failover-rule"
,
Enabled
:
true
,
Priority
:
1
,
ErrorCodes
:
[]
int
{
statusCode
},
Keywords
:
[]
string
{
keyword
},
MatchMode
:
model
.
MatchModeAll
,
PassthroughCode
:
false
,
ResponseCode
:
&
respCode
,
PassthroughBody
:
false
,
CustomMessage
:
&
customMessage
,
}
}
backend/internal/service/error_passthrough_service.go
View file @
7b156489
...
@@ -6,6 +6,7 @@ import (
...
@@ -6,6 +6,7 @@ import (
"sort"
"sort"
"strings"
"strings"
"sync"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/model"
)
)
...
@@ -60,8 +61,11 @@ func NewErrorPassthroughService(
...
@@ -60,8 +61,11 @@ func NewErrorPassthroughService(
// 启动时加载规则到本地缓存
// 启动时加载规则到本地缓存
ctx
:=
context
.
Background
()
ctx
:=
context
.
Background
()
if
err
:=
svc
.
refreshLocalCache
(
ctx
);
err
!=
nil
{
if
err
:=
svc
.
reloadRulesFromDB
(
ctx
);
err
!=
nil
{
log
.
Printf
(
"[ErrorPassthroughService] Failed to load rules on startup: %v"
,
err
)
log
.
Printf
(
"[ErrorPassthroughService] Failed to load rules from DB on startup: %v"
,
err
)
if
fallbackErr
:=
svc
.
refreshLocalCache
(
ctx
);
fallbackErr
!=
nil
{
log
.
Printf
(
"[ErrorPassthroughService] Failed to load rules from cache fallback on startup: %v"
,
fallbackErr
)
}
}
}
// 订阅缓存更新通知
// 订阅缓存更新通知
...
@@ -98,7 +102,9 @@ func (s *ErrorPassthroughService) Create(ctx context.Context, rule *model.ErrorP
...
@@ -98,7 +102,9 @@ func (s *ErrorPassthroughService) Create(ctx context.Context, rule *model.ErrorP
}
}
// 刷新缓存
// 刷新缓存
s
.
invalidateAndNotify
(
ctx
)
refreshCtx
,
cancel
:=
s
.
newCacheRefreshContext
()
defer
cancel
()
s
.
invalidateAndNotify
(
refreshCtx
)
return
created
,
nil
return
created
,
nil
}
}
...
@@ -115,7 +121,9 @@ func (s *ErrorPassthroughService) Update(ctx context.Context, rule *model.ErrorP
...
@@ -115,7 +121,9 @@ func (s *ErrorPassthroughService) Update(ctx context.Context, rule *model.ErrorP
}
}
// 刷新缓存
// 刷新缓存
s
.
invalidateAndNotify
(
ctx
)
refreshCtx
,
cancel
:=
s
.
newCacheRefreshContext
()
defer
cancel
()
s
.
invalidateAndNotify
(
refreshCtx
)
return
updated
,
nil
return
updated
,
nil
}
}
...
@@ -127,7 +135,9 @@ func (s *ErrorPassthroughService) Delete(ctx context.Context, id int64) error {
...
@@ -127,7 +135,9 @@ func (s *ErrorPassthroughService) Delete(ctx context.Context, id int64) error {
}
}
// 刷新缓存
// 刷新缓存
s
.
invalidateAndNotify
(
ctx
)
refreshCtx
,
cancel
:=
s
.
newCacheRefreshContext
()
defer
cancel
()
s
.
invalidateAndNotify
(
refreshCtx
)
return
nil
return
nil
}
}
...
@@ -189,7 +199,12 @@ func (s *ErrorPassthroughService) refreshLocalCache(ctx context.Context) error {
...
@@ -189,7 +199,12 @@ func (s *ErrorPassthroughService) refreshLocalCache(ctx context.Context) error {
}
}
}
}
// 从数据库加载(repo.List 已按 priority 排序)
return
s
.
reloadRulesFromDB
(
ctx
)
}
// 从数据库加载(repo.List 已按 priority 排序)
// 注意:该方法会绕过 cache.Get,确保拿到数据库最新值。
func
(
s
*
ErrorPassthroughService
)
reloadRulesFromDB
(
ctx
context
.
Context
)
error
{
rules
,
err
:=
s
.
repo
.
List
(
ctx
)
rules
,
err
:=
s
.
repo
.
List
(
ctx
)
if
err
!=
nil
{
if
err
!=
nil
{
return
err
return
err
...
@@ -222,11 +237,32 @@ func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughR
...
@@ -222,11 +237,32 @@ func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughR
s
.
localCacheMu
.
Unlock
()
s
.
localCacheMu
.
Unlock
()
}
}
// clearLocalCache 清空本地缓存,避免刷新失败时继续命中陈旧规则。
func
(
s
*
ErrorPassthroughService
)
clearLocalCache
()
{
s
.
localCacheMu
.
Lock
()
s
.
localCache
=
nil
s
.
localCacheMu
.
Unlock
()
}
// newCacheRefreshContext 为写路径缓存同步创建独立上下文,避免受请求取消影响。
func
(
s
*
ErrorPassthroughService
)
newCacheRefreshContext
()
(
context
.
Context
,
context
.
CancelFunc
)
{
return
context
.
WithTimeout
(
context
.
Background
(),
3
*
time
.
Second
)
}
// invalidateAndNotify 使缓存失效并通知其他实例
// invalidateAndNotify 使缓存失效并通知其他实例
func
(
s
*
ErrorPassthroughService
)
invalidateAndNotify
(
ctx
context
.
Context
)
{
func
(
s
*
ErrorPassthroughService
)
invalidateAndNotify
(
ctx
context
.
Context
)
{
// 先失效缓存,避免后续刷新读到陈旧规则。
if
s
.
cache
!=
nil
{
if
err
:=
s
.
cache
.
Invalidate
(
ctx
);
err
!=
nil
{
log
.
Printf
(
"[ErrorPassthroughService] Failed to invalidate cache: %v"
,
err
)
}
}
// 刷新本地缓存
// 刷新本地缓存
if
err
:=
s
.
re
freshLocalCache
(
ctx
);
err
!=
nil
{
if
err
:=
s
.
re
loadRulesFromDB
(
ctx
);
err
!=
nil
{
log
.
Printf
(
"[ErrorPassthroughService] Failed to refresh local cache: %v"
,
err
)
log
.
Printf
(
"[ErrorPassthroughService] Failed to refresh local cache: %v"
,
err
)
// 刷新失败时清空本地缓存,避免继续使用陈旧规则。
s
.
clearLocalCache
()
}
}
// 通知其他实例
// 通知其他实例
...
...
backend/internal/service/error_passthrough_service_test.go
View file @
7b156489
...
@@ -4,6 +4,7 @@ package service
...
@@ -4,6 +4,7 @@ package service
import
(
import
(
"context"
"context"
"errors"
"strings"
"strings"
"testing"
"testing"
...
@@ -14,14 +15,81 @@ import (
...
@@ -14,14 +15,81 @@ import (
// mockErrorPassthroughRepo 用于测试的 mock repository
// mockErrorPassthroughRepo 用于测试的 mock repository
type
mockErrorPassthroughRepo
struct
{
type
mockErrorPassthroughRepo
struct
{
rules
[]
*
model
.
ErrorPassthroughRule
rules
[]
*
model
.
ErrorPassthroughRule
listErr
error
getErr
error
createErr
error
updateErr
error
deleteErr
error
}
type
mockErrorPassthroughCache
struct
{
rules
[]
*
model
.
ErrorPassthroughRule
hasData
bool
getCalled
int
setCalled
int
invalidateCalled
int
notifyCalled
int
}
func
newMockErrorPassthroughCache
(
rules
[]
*
model
.
ErrorPassthroughRule
,
hasData
bool
)
*
mockErrorPassthroughCache
{
return
&
mockErrorPassthroughCache
{
rules
:
cloneRules
(
rules
),
hasData
:
hasData
,
}
}
func
(
m
*
mockErrorPassthroughCache
)
Get
(
ctx
context
.
Context
)
([]
*
model
.
ErrorPassthroughRule
,
bool
)
{
m
.
getCalled
++
if
!
m
.
hasData
{
return
nil
,
false
}
return
cloneRules
(
m
.
rules
),
true
}
func
(
m
*
mockErrorPassthroughCache
)
Set
(
ctx
context
.
Context
,
rules
[]
*
model
.
ErrorPassthroughRule
)
error
{
m
.
setCalled
++
m
.
rules
=
cloneRules
(
rules
)
m
.
hasData
=
true
return
nil
}
func
(
m
*
mockErrorPassthroughCache
)
Invalidate
(
ctx
context
.
Context
)
error
{
m
.
invalidateCalled
++
m
.
rules
=
nil
m
.
hasData
=
false
return
nil
}
func
(
m
*
mockErrorPassthroughCache
)
NotifyUpdate
(
ctx
context
.
Context
)
error
{
m
.
notifyCalled
++
return
nil
}
func
(
m
*
mockErrorPassthroughCache
)
SubscribeUpdates
(
ctx
context
.
Context
,
handler
func
())
{
// 单测中无需订阅行为
}
func
cloneRules
(
rules
[]
*
model
.
ErrorPassthroughRule
)
[]
*
model
.
ErrorPassthroughRule
{
if
rules
==
nil
{
return
nil
}
out
:=
make
([]
*
model
.
ErrorPassthroughRule
,
len
(
rules
))
copy
(
out
,
rules
)
return
out
}
}
func
(
m
*
mockErrorPassthroughRepo
)
List
(
ctx
context
.
Context
)
([]
*
model
.
ErrorPassthroughRule
,
error
)
{
func
(
m
*
mockErrorPassthroughRepo
)
List
(
ctx
context
.
Context
)
([]
*
model
.
ErrorPassthroughRule
,
error
)
{
if
m
.
listErr
!=
nil
{
return
nil
,
m
.
listErr
}
return
m
.
rules
,
nil
return
m
.
rules
,
nil
}
}
func
(
m
*
mockErrorPassthroughRepo
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
ErrorPassthroughRule
,
error
)
{
func
(
m
*
mockErrorPassthroughRepo
)
GetByID
(
ctx
context
.
Context
,
id
int64
)
(
*
model
.
ErrorPassthroughRule
,
error
)
{
if
m
.
getErr
!=
nil
{
return
nil
,
m
.
getErr
}
for
_
,
r
:=
range
m
.
rules
{
for
_
,
r
:=
range
m
.
rules
{
if
r
.
ID
==
id
{
if
r
.
ID
==
id
{
return
r
,
nil
return
r
,
nil
...
@@ -31,12 +99,18 @@ func (m *mockErrorPassthroughRepo) GetByID(ctx context.Context, id int64) (*mode
...
@@ -31,12 +99,18 @@ func (m *mockErrorPassthroughRepo) GetByID(ctx context.Context, id int64) (*mode
}
}
func
(
m
*
mockErrorPassthroughRepo
)
Create
(
ctx
context
.
Context
,
rule
*
model
.
ErrorPassthroughRule
)
(
*
model
.
ErrorPassthroughRule
,
error
)
{
func
(
m
*
mockErrorPassthroughRepo
)
Create
(
ctx
context
.
Context
,
rule
*
model
.
ErrorPassthroughRule
)
(
*
model
.
ErrorPassthroughRule
,
error
)
{
if
m
.
createErr
!=
nil
{
return
nil
,
m
.
createErr
}
rule
.
ID
=
int64
(
len
(
m
.
rules
)
+
1
)
rule
.
ID
=
int64
(
len
(
m
.
rules
)
+
1
)
m
.
rules
=
append
(
m
.
rules
,
rule
)
m
.
rules
=
append
(
m
.
rules
,
rule
)
return
rule
,
nil
return
rule
,
nil
}
}
func
(
m
*
mockErrorPassthroughRepo
)
Update
(
ctx
context
.
Context
,
rule
*
model
.
ErrorPassthroughRule
)
(
*
model
.
ErrorPassthroughRule
,
error
)
{
func
(
m
*
mockErrorPassthroughRepo
)
Update
(
ctx
context
.
Context
,
rule
*
model
.
ErrorPassthroughRule
)
(
*
model
.
ErrorPassthroughRule
,
error
)
{
if
m
.
updateErr
!=
nil
{
return
nil
,
m
.
updateErr
}
for
i
,
r
:=
range
m
.
rules
{
for
i
,
r
:=
range
m
.
rules
{
if
r
.
ID
==
rule
.
ID
{
if
r
.
ID
==
rule
.
ID
{
m
.
rules
[
i
]
=
rule
m
.
rules
[
i
]
=
rule
...
@@ -47,6 +121,9 @@ func (m *mockErrorPassthroughRepo) Update(ctx context.Context, rule *model.Error
...
@@ -47,6 +121,9 @@ func (m *mockErrorPassthroughRepo) Update(ctx context.Context, rule *model.Error
}
}
func
(
m
*
mockErrorPassthroughRepo
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
func
(
m
*
mockErrorPassthroughRepo
)
Delete
(
ctx
context
.
Context
,
id
int64
)
error
{
if
m
.
deleteErr
!=
nil
{
return
m
.
deleteErr
}
for
i
,
r
:=
range
m
.
rules
{
for
i
,
r
:=
range
m
.
rules
{
if
r
.
ID
==
id
{
if
r
.
ID
==
id
{
m
.
rules
=
append
(
m
.
rules
[
:
i
],
m
.
rules
[
i
+
1
:
]
...
)
m
.
rules
=
append
(
m
.
rules
[
:
i
],
m
.
rules
[
i
+
1
:
]
...
)
...
@@ -750,6 +827,158 @@ func TestErrorPassthroughRule_Validate(t *testing.T) {
...
@@ -750,6 +827,158 @@ func TestErrorPassthroughRule_Validate(t *testing.T) {
}
}
}
}
// =============================================================================
// 测试写路径缓存刷新(Create/Update/Delete)
// =============================================================================
func
TestCreate_ForceRefreshCacheAfterWrite
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
staleRule
:=
newPassthroughRuleForWritePathTest
(
99
,
"service temporarily unavailable after multiple"
,
"旧缓存消息"
)
repo
:=
&
mockErrorPassthroughRepo
{
rules
:
[]
*
model
.
ErrorPassthroughRule
{}}
cache
:=
newMockErrorPassthroughCache
([]
*
model
.
ErrorPassthroughRule
{
staleRule
},
true
)
svc
:=
&
ErrorPassthroughService
{
repo
:
repo
,
cache
:
cache
}
svc
.
setLocalCache
([]
*
model
.
ErrorPassthroughRule
{
staleRule
})
newRule
:=
newPassthroughRuleForWritePathTest
(
0
,
"service temporarily unavailable after multiple"
,
"上游请求失败"
)
created
,
err
:=
svc
.
Create
(
ctx
,
newRule
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
created
)
body
:=
[]
byte
(
`{"message":"Service temporarily unavailable after multiple retries, please try again later"}`
)
matched
:=
svc
.
MatchRule
(
"anthropic"
,
503
,
body
)
require
.
NotNil
(
t
,
matched
)
assert
.
Equal
(
t
,
created
.
ID
,
matched
.
ID
)
if
assert
.
NotNil
(
t
,
matched
.
CustomMessage
)
{
assert
.
Equal
(
t
,
"上游请求失败"
,
*
matched
.
CustomMessage
)
}
assert
.
Equal
(
t
,
0
,
cache
.
getCalled
,
"写路径刷新不应依赖 cache.Get"
)
assert
.
Equal
(
t
,
1
,
cache
.
invalidateCalled
)
assert
.
Equal
(
t
,
1
,
cache
.
setCalled
)
assert
.
Equal
(
t
,
1
,
cache
.
notifyCalled
)
}
func
TestUpdate_ForceRefreshCacheAfterWrite
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
originalRule
:=
newPassthroughRuleForWritePathTest
(
1
,
"old keyword"
,
"旧消息"
)
repo
:=
&
mockErrorPassthroughRepo
{
rules
:
[]
*
model
.
ErrorPassthroughRule
{
originalRule
}}
cache
:=
newMockErrorPassthroughCache
([]
*
model
.
ErrorPassthroughRule
{
originalRule
},
true
)
svc
:=
&
ErrorPassthroughService
{
repo
:
repo
,
cache
:
cache
}
svc
.
setLocalCache
([]
*
model
.
ErrorPassthroughRule
{
originalRule
})
updatedRule
:=
newPassthroughRuleForWritePathTest
(
1
,
"new keyword"
,
"新消息"
)
_
,
err
:=
svc
.
Update
(
ctx
,
updatedRule
)
require
.
NoError
(
t
,
err
)
oldBody
:=
[]
byte
(
`{"message":"old keyword"}`
)
oldMatched
:=
svc
.
MatchRule
(
"anthropic"
,
503
,
oldBody
)
assert
.
Nil
(
t
,
oldMatched
,
"更新后旧关键词不应继续命中"
)
newBody
:=
[]
byte
(
`{"message":"new keyword"}`
)
newMatched
:=
svc
.
MatchRule
(
"anthropic"
,
503
,
newBody
)
require
.
NotNil
(
t
,
newMatched
)
if
assert
.
NotNil
(
t
,
newMatched
.
CustomMessage
)
{
assert
.
Equal
(
t
,
"新消息"
,
*
newMatched
.
CustomMessage
)
}
assert
.
Equal
(
t
,
0
,
cache
.
getCalled
,
"写路径刷新不应依赖 cache.Get"
)
assert
.
Equal
(
t
,
1
,
cache
.
invalidateCalled
)
assert
.
Equal
(
t
,
1
,
cache
.
setCalled
)
assert
.
Equal
(
t
,
1
,
cache
.
notifyCalled
)
}
func
TestDelete_ForceRefreshCacheAfterWrite
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
rule
:=
newPassthroughRuleForWritePathTest
(
1
,
"to be deleted"
,
"删除前消息"
)
repo
:=
&
mockErrorPassthroughRepo
{
rules
:
[]
*
model
.
ErrorPassthroughRule
{
rule
}}
cache
:=
newMockErrorPassthroughCache
([]
*
model
.
ErrorPassthroughRule
{
rule
},
true
)
svc
:=
&
ErrorPassthroughService
{
repo
:
repo
,
cache
:
cache
}
svc
.
setLocalCache
([]
*
model
.
ErrorPassthroughRule
{
rule
})
err
:=
svc
.
Delete
(
ctx
,
1
)
require
.
NoError
(
t
,
err
)
body
:=
[]
byte
(
`{"message":"to be deleted"}`
)
matched
:=
svc
.
MatchRule
(
"anthropic"
,
503
,
body
)
assert
.
Nil
(
t
,
matched
,
"删除后规则不应再命中"
)
assert
.
Equal
(
t
,
0
,
cache
.
getCalled
,
"写路径刷新不应依赖 cache.Get"
)
assert
.
Equal
(
t
,
1
,
cache
.
invalidateCalled
)
assert
.
Equal
(
t
,
1
,
cache
.
setCalled
)
assert
.
Equal
(
t
,
1
,
cache
.
notifyCalled
)
}
func
TestNewService_StartupReloadFromDBToHealStaleCache
(
t
*
testing
.
T
)
{
staleRule
:=
newPassthroughRuleForWritePathTest
(
99
,
"stale keyword"
,
"旧缓存消息"
)
latestRule
:=
newPassthroughRuleForWritePathTest
(
1
,
"fresh keyword"
,
"最新消息"
)
repo
:=
&
mockErrorPassthroughRepo
{
rules
:
[]
*
model
.
ErrorPassthroughRule
{
latestRule
}}
cache
:=
newMockErrorPassthroughCache
([]
*
model
.
ErrorPassthroughRule
{
staleRule
},
true
)
svc
:=
NewErrorPassthroughService
(
repo
,
cache
)
matchedFresh
:=
svc
.
MatchRule
(
"anthropic"
,
503
,
[]
byte
(
`{"message":"fresh keyword"}`
))
require
.
NotNil
(
t
,
matchedFresh
)
assert
.
Equal
(
t
,
int64
(
1
),
matchedFresh
.
ID
)
matchedStale
:=
svc
.
MatchRule
(
"anthropic"
,
503
,
[]
byte
(
`{"message":"stale keyword"}`
))
assert
.
Nil
(
t
,
matchedStale
,
"启动后应以 DB 最新规则覆盖旧缓存"
)
assert
.
Equal
(
t
,
0
,
cache
.
getCalled
,
"启动强制 DB 刷新不应依赖 cache.Get"
)
assert
.
Equal
(
t
,
1
,
cache
.
setCalled
,
"启动后应回写缓存,覆盖陈旧缓存"
)
}
func
TestUpdate_RefreshFailureShouldNotKeepStaleEnabledRule
(
t
*
testing
.
T
)
{
ctx
:=
context
.
Background
()
staleRule
:=
newPassthroughRuleForWritePathTest
(
1
,
"service temporarily unavailable after multiple"
,
"旧缓存消息"
)
repo
:=
&
mockErrorPassthroughRepo
{
rules
:
[]
*
model
.
ErrorPassthroughRule
{
staleRule
},
listErr
:
errors
.
New
(
"db list failed"
),
}
cache
:=
newMockErrorPassthroughCache
([]
*
model
.
ErrorPassthroughRule
{
staleRule
},
true
)
svc
:=
&
ErrorPassthroughService
{
repo
:
repo
,
cache
:
cache
}
svc
.
setLocalCache
([]
*
model
.
ErrorPassthroughRule
{
staleRule
})
disabledRule
:=
*
staleRule
disabledRule
.
Enabled
=
false
_
,
err
:=
svc
.
Update
(
ctx
,
&
disabledRule
)
require
.
NoError
(
t
,
err
)
body
:=
[]
byte
(
`{"message":"Service temporarily unavailable after multiple retries, please try again later"}`
)
matched
:=
svc
.
MatchRule
(
"anthropic"
,
503
,
body
)
assert
.
Nil
(
t
,
matched
,
"刷新失败时不应继续命中旧的启用规则"
)
svc
.
localCacheMu
.
RLock
()
assert
.
Nil
(
t
,
svc
.
localCache
,
"刷新失败后应清空本地缓存,避免误命中"
)
svc
.
localCacheMu
.
RUnlock
()
}
func
newPassthroughRuleForWritePathTest
(
id
int64
,
keyword
,
customMsg
string
)
*
model
.
ErrorPassthroughRule
{
responseCode
:=
503
rule
:=
&
model
.
ErrorPassthroughRule
{
ID
:
id
,
Name
:
"write-path-cache-refresh"
,
Enabled
:
true
,
Priority
:
1
,
ErrorCodes
:
[]
int
{
503
},
Keywords
:
[]
string
{
keyword
},
MatchMode
:
model
.
MatchModeAll
,
PassthroughCode
:
false
,
ResponseCode
:
&
responseCode
,
PassthroughBody
:
false
,
CustomMessage
:
&
customMsg
,
}
return
rule
}
// Helper functions
// Helper functions
func
testIntPtr
(
i
int
)
*
int
{
return
&
i
}
func
testIntPtr
(
i
int
)
*
int
{
return
&
i
}
func
testStrPtr
(
s
string
)
*
string
{
return
&
s
}
func
testStrPtr
(
s
string
)
*
string
{
return
&
s
}
backend/internal/service/gateway_service.go
View file @
7b156489
...
@@ -3563,6 +3563,34 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
...
@@ -3563,6 +3563,34 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
)
)
}
}
// 非 failover 错误也支持错误透传规则匹配。
if
status
,
errType
,
errMsg
,
matched
:=
applyErrorPassthroughRule
(
c
,
account
.
Platform
,
resp
.
StatusCode
,
body
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream request failed"
,
);
matched
{
c
.
JSON
(
status
,
gin
.
H
{
"type"
:
"error"
,
"error"
:
gin
.
H
{
"type"
:
errType
,
"message"
:
errMsg
,
},
})
summary
:=
upstreamMsg
if
summary
==
""
{
summary
=
errMsg
}
if
summary
==
""
{
return
nil
,
fmt
.
Errorf
(
"upstream error: %d (passthrough rule matched)"
,
resp
.
StatusCode
)
}
return
nil
,
fmt
.
Errorf
(
"upstream error: %d (passthrough rule matched) message=%s"
,
resp
.
StatusCode
,
summary
)
}
// 根据状态码返回适当的自定义错误响应(不透传上游详细信息)
// 根据状态码返回适当的自定义错误响应(不透传上游详细信息)
var
errType
,
errMsg
string
var
errType
,
errMsg
string
var
statusCode
int
var
statusCode
int
...
@@ -3694,6 +3722,33 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
...
@@ -3694,6 +3722,33 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
)
)
}
}
if
status
,
errType
,
errMsg
,
matched
:=
applyErrorPassthroughRule
(
c
,
account
.
Platform
,
resp
.
StatusCode
,
respBody
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream request failed after retries"
,
);
matched
{
c
.
JSON
(
status
,
gin
.
H
{
"type"
:
"error"
,
"error"
:
gin
.
H
{
"type"
:
errType
,
"message"
:
errMsg
,
},
})
summary
:=
upstreamMsg
if
summary
==
""
{
summary
=
errMsg
}
if
summary
==
""
{
return
nil
,
fmt
.
Errorf
(
"upstream error: %d (retries exhausted, passthrough rule matched)"
,
resp
.
StatusCode
)
}
return
nil
,
fmt
.
Errorf
(
"upstream error: %d (retries exhausted, passthrough rule matched) message=%s"
,
resp
.
StatusCode
,
summary
)
}
// 返回统一的重试耗尽错误响应
// 返回统一的重试耗尽错误响应
c
.
JSON
(
http
.
StatusBadGateway
,
gin
.
H
{
c
.
JSON
(
http
.
StatusBadGateway
,
gin
.
H
{
"type"
:
"error"
,
"type"
:
"error"
,
...
...
backend/internal/service/gemini_messages_compat_service.go
View file @
7b156489
...
@@ -1498,6 +1498,28 @@ func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, acc
...
@@ -1498,6 +1498,28 @@ func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, acc
log
.
Printf
(
"[Gemini] upstream error %d: %s"
,
upstreamStatus
,
truncateForLog
(
body
,
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
))
log
.
Printf
(
"[Gemini] upstream error %d: %s"
,
upstreamStatus
,
truncateForLog
(
body
,
s
.
cfg
.
Gateway
.
LogUpstreamErrorBodyMaxBytes
))
}
}
if
status
,
errType
,
errMsg
,
matched
:=
applyErrorPassthroughRule
(
c
,
PlatformGemini
,
upstreamStatus
,
body
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream request failed"
,
);
matched
{
c
.
JSON
(
status
,
gin
.
H
{
"type"
:
"error"
,
"error"
:
gin
.
H
{
"type"
:
errType
,
"message"
:
errMsg
},
})
if
upstreamMsg
==
""
{
upstreamMsg
=
errMsg
}
if
upstreamMsg
==
""
{
return
fmt
.
Errorf
(
"upstream error: %d (passthrough rule matched)"
,
upstreamStatus
)
}
return
fmt
.
Errorf
(
"upstream error: %d (passthrough rule matched) message=%s"
,
upstreamStatus
,
upstreamMsg
)
}
var
statusCode
int
var
statusCode
int
var
errType
,
errMsg
string
var
errType
,
errMsg
string
...
...
backend/internal/service/openai_gateway_service.go
View file @
7b156489
...
@@ -1087,6 +1087,30 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
...
@@ -1087,6 +1087,30 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
)
)
}
}
if
status
,
errType
,
errMsg
,
matched
:=
applyErrorPassthroughRule
(
c
,
PlatformOpenAI
,
resp
.
StatusCode
,
body
,
http
.
StatusBadGateway
,
"upstream_error"
,
"Upstream request failed"
,
);
matched
{
c
.
JSON
(
status
,
gin
.
H
{
"error"
:
gin
.
H
{
"type"
:
errType
,
"message"
:
errMsg
,
},
})
if
upstreamMsg
==
""
{
upstreamMsg
=
errMsg
}
if
upstreamMsg
==
""
{
return
nil
,
fmt
.
Errorf
(
"upstream error: %d (passthrough rule matched)"
,
resp
.
StatusCode
)
}
return
nil
,
fmt
.
Errorf
(
"upstream error: %d (passthrough rule matched) message=%s"
,
resp
.
StatusCode
,
upstreamMsg
)
}
// Check custom error codes
// Check custom error codes
if
!
account
.
ShouldHandleErrorCode
(
resp
.
StatusCode
)
{
if
!
account
.
ShouldHandleErrorCode
(
resp
.
StatusCode
)
{
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
appendOpsUpstreamError
(
c
,
OpsUpstreamErrorEvent
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment