Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
a67d9337
Commit
a67d9337
authored
Feb 09, 2026
by
erio
Browse files
feat: integrate CheckErrorPolicy into Gemini error handling paths
parent
2f1182e8
Changes
2
Show whitespace changes
Inline
Side-by-side
backend/internal/service/gemini_error_policy_test.go
0 → 100644
View file @
a67d9337
//go:build unit
package
service
import
(
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// TestShouldFailoverGeminiUpstreamError — verifies the failover decision
// for the ErrorPolicyNone path (original logic preserved).
// ---------------------------------------------------------------------------
func
TestShouldFailoverGeminiUpstreamError
(
t
*
testing
.
T
)
{
svc
:=
&
GeminiMessagesCompatService
{}
tests
:=
[]
struct
{
name
string
statusCode
int
expected
bool
}{
{
"401_failover"
,
401
,
true
},
{
"403_failover"
,
403
,
true
},
{
"429_failover"
,
429
,
true
},
{
"529_failover"
,
529
,
true
},
{
"500_failover"
,
500
,
true
},
{
"502_failover"
,
502
,
true
},
{
"503_failover"
,
503
,
true
},
{
"400_no_failover"
,
400
,
false
},
{
"404_no_failover"
,
404
,
false
},
{
"422_no_failover"
,
422
,
false
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
got
:=
svc
.
shouldFailoverGeminiUpstreamError
(
tt
.
statusCode
)
require
.
Equal
(
t
,
tt
.
expected
,
got
)
})
}
}
// ---------------------------------------------------------------------------
// TestCheckErrorPolicy_GeminiAccounts — verifies CheckErrorPolicy works
// correctly for Gemini platform accounts (API Key type).
// ---------------------------------------------------------------------------
func
TestCheckErrorPolicy_GeminiAccounts
(
t
*
testing
.
T
)
{
tests
:=
[]
struct
{
name
string
account
*
Account
statusCode
int
body
[]
byte
expected
ErrorPolicyResult
}{
{
name
:
"gemini_apikey_custom_codes_hit"
,
account
:
&
Account
{
ID
:
100
,
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformGemini
,
Credentials
:
map
[
string
]
any
{
"custom_error_codes_enabled"
:
true
,
"custom_error_codes"
:
[]
any
{
float64
(
429
),
float64
(
500
)},
},
},
statusCode
:
429
,
body
:
[]
byte
(
`{"error":"rate limited"}`
),
expected
:
ErrorPolicyMatched
,
},
{
name
:
"gemini_apikey_custom_codes_miss"
,
account
:
&
Account
{
ID
:
101
,
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformGemini
,
Credentials
:
map
[
string
]
any
{
"custom_error_codes_enabled"
:
true
,
"custom_error_codes"
:
[]
any
{
float64
(
429
)},
},
},
statusCode
:
500
,
body
:
[]
byte
(
`{"error":"internal"}`
),
expected
:
ErrorPolicySkipped
,
},
{
name
:
"gemini_apikey_no_custom_codes_returns_none"
,
account
:
&
Account
{
ID
:
102
,
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformGemini
,
},
statusCode
:
500
,
body
:
[]
byte
(
`{"error":"internal"}`
),
expected
:
ErrorPolicyNone
,
},
{
name
:
"gemini_apikey_temp_unschedulable_hit"
,
account
:
&
Account
{
ID
:
103
,
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformGemini
,
Credentials
:
map
[
string
]
any
{
"temp_unschedulable_enabled"
:
true
,
"temp_unschedulable_rules"
:
[]
any
{
map
[
string
]
any
{
"error_code"
:
float64
(
503
),
"keywords"
:
[]
any
{
"overloaded"
},
"duration_minutes"
:
float64
(
10
),
},
},
},
},
statusCode
:
503
,
body
:
[]
byte
(
`overloaded service`
),
expected
:
ErrorPolicyTempUnscheduled
,
},
{
name
:
"gemini_custom_codes_override_temp_unschedulable"
,
account
:
&
Account
{
ID
:
104
,
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformGemini
,
Credentials
:
map
[
string
]
any
{
"custom_error_codes_enabled"
:
true
,
"custom_error_codes"
:
[]
any
{
float64
(
503
)},
"temp_unschedulable_enabled"
:
true
,
"temp_unschedulable_rules"
:
[]
any
{
map
[
string
]
any
{
"error_code"
:
float64
(
503
),
"keywords"
:
[]
any
{
"overloaded"
},
"duration_minutes"
:
float64
(
10
),
},
},
},
},
statusCode
:
503
,
body
:
[]
byte
(
`overloaded`
),
expected
:
ErrorPolicyMatched
,
// custom codes take precedence
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
errorPolicyRepoStub
{}
svc
:=
NewRateLimitService
(
repo
,
nil
,
&
config
.
Config
{},
nil
,
nil
)
result
:=
svc
.
CheckErrorPolicy
(
context
.
Background
(),
tt
.
account
,
tt
.
statusCode
,
tt
.
body
)
require
.
Equal
(
t
,
tt
.
expected
,
result
)
})
}
}
// ---------------------------------------------------------------------------
// TestGeminiErrorPolicyIntegration — verifies the Gemini error handling
// paths produce the correct behavior for each ErrorPolicyResult.
//
// These tests simulate the inline error policy switch in handleClaudeCompat
// and forwardNativeGemini by calling the same methods in the same order.
// ---------------------------------------------------------------------------
func
TestGeminiErrorPolicyIntegration
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
tests
:=
[]
struct
{
name
string
account
*
Account
statusCode
int
respBody
[]
byte
expectFailover
bool
// expect UpstreamFailoverError
expectHandleError
bool
// expect handleGeminiUpstreamError to be called
expectShouldFailover
bool
// for None path, whether shouldFailover triggers
}{
{
name
:
"custom_codes_matched_429_failover"
,
account
:
&
Account
{
ID
:
200
,
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformGemini
,
Credentials
:
map
[
string
]
any
{
"custom_error_codes_enabled"
:
true
,
"custom_error_codes"
:
[]
any
{
float64
(
429
)},
},
},
statusCode
:
429
,
respBody
:
[]
byte
(
`{"error":"rate limited"}`
),
expectFailover
:
true
,
expectHandleError
:
true
,
},
{
name
:
"custom_codes_skipped_500_no_failover"
,
account
:
&
Account
{
ID
:
201
,
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformGemini
,
Credentials
:
map
[
string
]
any
{
"custom_error_codes_enabled"
:
true
,
"custom_error_codes"
:
[]
any
{
float64
(
429
)},
},
},
statusCode
:
500
,
respBody
:
[]
byte
(
`{"error":"internal"}`
),
expectFailover
:
false
,
expectHandleError
:
false
,
},
{
name
:
"temp_unschedulable_matched_failover"
,
account
:
&
Account
{
ID
:
202
,
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformGemini
,
Credentials
:
map
[
string
]
any
{
"temp_unschedulable_enabled"
:
true
,
"temp_unschedulable_rules"
:
[]
any
{
map
[
string
]
any
{
"error_code"
:
float64
(
503
),
"keywords"
:
[]
any
{
"overloaded"
},
"duration_minutes"
:
float64
(
10
),
},
},
},
},
statusCode
:
503
,
respBody
:
[]
byte
(
`overloaded`
),
expectFailover
:
true
,
expectHandleError
:
true
,
},
{
name
:
"no_policy_429_failover_via_shouldFailover"
,
account
:
&
Account
{
ID
:
203
,
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformGemini
,
},
statusCode
:
429
,
respBody
:
[]
byte
(
`{"error":"rate limited"}`
),
expectFailover
:
true
,
expectHandleError
:
true
,
expectShouldFailover
:
true
,
},
{
name
:
"no_policy_400_no_failover"
,
account
:
&
Account
{
ID
:
204
,
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformGemini
,
},
statusCode
:
400
,
respBody
:
[]
byte
(
`{"error":"bad request"}`
),
expectFailover
:
false
,
expectHandleError
:
true
,
},
}
for
_
,
tt
:=
range
tests
{
t
.
Run
(
tt
.
name
,
func
(
t
*
testing
.
T
)
{
repo
:=
&
geminiErrorPolicyRepo
{}
rlSvc
:=
NewRateLimitService
(
repo
,
nil
,
&
config
.
Config
{},
nil
,
nil
)
svc
:=
&
GeminiMessagesCompatService
{
accountRepo
:
repo
,
rateLimitService
:
rlSvc
,
}
writer
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
writer
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/v1/messages"
,
nil
)
// Simulate the Claude compat error handling path (same logic as native).
// This mirrors the inline switch in handleClaudeCompat.
var
handleErrorCalled
bool
var
gotFailover
bool
ctx
:=
context
.
Background
()
statusCode
:=
tt
.
statusCode
respBody
:=
tt
.
respBody
account
:=
tt
.
account
headers
:=
http
.
Header
{}
if
svc
.
rateLimitService
!=
nil
{
switch
svc
.
rateLimitService
.
CheckErrorPolicy
(
ctx
,
account
,
statusCode
,
respBody
)
{
case
ErrorPolicySkipped
:
// Skipped → return error directly (no handleGeminiUpstreamError, no failover)
gotFailover
=
false
handleErrorCalled
=
false
goto
verify
case
ErrorPolicyMatched
,
ErrorPolicyTempUnscheduled
:
svc
.
handleGeminiUpstreamError
(
ctx
,
account
,
statusCode
,
headers
,
respBody
)
handleErrorCalled
=
true
gotFailover
=
true
goto
verify
}
}
// ErrorPolicyNone → original logic
svc
.
handleGeminiUpstreamError
(
ctx
,
account
,
statusCode
,
headers
,
respBody
)
handleErrorCalled
=
true
if
svc
.
shouldFailoverGeminiUpstreamError
(
statusCode
)
{
gotFailover
=
true
}
verify
:
require
.
Equal
(
t
,
tt
.
expectFailover
,
gotFailover
,
"failover mismatch"
)
require
.
Equal
(
t
,
tt
.
expectHandleError
,
handleErrorCalled
,
"handleGeminiUpstreamError call mismatch"
)
if
tt
.
expectShouldFailover
{
require
.
True
(
t
,
svc
.
shouldFailoverGeminiUpstreamError
(
statusCode
),
"shouldFailoverGeminiUpstreamError should return true for status %d"
,
statusCode
)
}
})
}
}
// ---------------------------------------------------------------------------
// TestGeminiErrorPolicy_NilRateLimitService — verifies nil safety
// ---------------------------------------------------------------------------
func
TestGeminiErrorPolicy_NilRateLimitService
(
t
*
testing
.
T
)
{
svc
:=
&
GeminiMessagesCompatService
{
rateLimitService
:
nil
,
}
// When rateLimitService is nil, error policy is skipped → falls through to
// shouldFailoverGeminiUpstreamError (original logic).
// Verify this doesn't panic and follows expected behavior.
ctx
:=
context
.
Background
()
account
:=
&
Account
{
ID
:
300
,
Type
:
AccountTypeAPIKey
,
Platform
:
PlatformGemini
,
Credentials
:
map
[
string
]
any
{
"custom_error_codes_enabled"
:
true
,
"custom_error_codes"
:
[]
any
{
float64
(
429
)},
},
}
// The nil check should prevent CheckErrorPolicy from being called
if
svc
.
rateLimitService
!=
nil
{
t
.
Fatal
(
"rateLimitService should be nil for this test"
)
}
// shouldFailoverGeminiUpstreamError still works
require
.
True
(
t
,
svc
.
shouldFailoverGeminiUpstreamError
(
429
))
require
.
False
(
t
,
svc
.
shouldFailoverGeminiUpstreamError
(
400
))
// handleGeminiUpstreamError should not panic with nil rateLimitService
require
.
NotPanics
(
t
,
func
()
{
svc
.
handleGeminiUpstreamError
(
ctx
,
account
,
500
,
http
.
Header
{},
[]
byte
(
`error`
))
})
}
// ---------------------------------------------------------------------------
// geminiErrorPolicyRepo — minimal AccountRepository stub for Gemini error
// policy tests. Embeds mockAccountRepoForGemini and adds tracking.
// ---------------------------------------------------------------------------
type
geminiErrorPolicyRepo
struct
{
mockAccountRepoForGemini
setErrorCalls
int
setRateLimitedCalls
int
setTempCalls
int
}
func
(
r
*
geminiErrorPolicyRepo
)
SetError
(
_
context
.
Context
,
_
int64
,
_
string
)
error
{
r
.
setErrorCalls
++
return
nil
}
func
(
r
*
geminiErrorPolicyRepo
)
SetRateLimited
(
_
context
.
Context
,
_
int64
,
_
time
.
Time
)
error
{
r
.
setRateLimitedCalls
++
return
nil
}
func
(
r
*
geminiErrorPolicyRepo
)
SetTempUnschedulable
(
_
context
.
Context
,
_
int64
,
_
time
.
Time
,
_
string
)
error
{
r
.
setTempCalls
++
return
nil
}
backend/internal/service/gemini_messages_compat_service.go
View file @
a67d9337
...
@@ -831,12 +831,17 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
...
@@ -831,12 +831,17 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
if
resp
.
StatusCode
>=
400
{
if
resp
.
StatusCode
>=
400
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
tempMatched
:=
false
// 统一错误策略:自定义错误码 + 临时不可调度
if
s
.
rateLimitService
!=
nil
{
if
s
.
rateLimitService
!=
nil
{
tempMatched
=
s
.
rateLimitService
.
HandleTempUnschedulable
(
ctx
,
account
,
resp
.
StatusCode
,
respBody
)
switch
s
.
rateLimitService
.
CheckErrorPolicy
(
ctx
,
account
,
resp
.
StatusCode
,
respBody
)
{
case
ErrorPolicySkipped
:
upstreamReqID
:=
resp
.
Header
.
Get
(
requestIDHeader
)
if
upstreamReqID
==
""
{
upstreamReqID
=
resp
.
Header
.
Get
(
"x-goog-request-id"
)
}
}
return
nil
,
s
.
writeGeminiMappedError
(
c
,
account
,
resp
.
StatusCode
,
upstreamReqID
,
respBody
)
case
ErrorPolicyMatched
,
ErrorPolicyTempUnscheduled
:
s
.
handleGeminiUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
)
s
.
handleGeminiUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
)
if
tempMatched
{
upstreamReqID
:=
resp
.
Header
.
Get
(
requestIDHeader
)
upstreamReqID
:=
resp
.
Header
.
Get
(
requestIDHeader
)
if
upstreamReqID
==
""
{
if
upstreamReqID
==
""
{
upstreamReqID
=
resp
.
Header
.
Get
(
"x-goog-request-id"
)
upstreamReqID
=
resp
.
Header
.
Get
(
"x-goog-request-id"
)
...
@@ -863,6 +868,10 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
...
@@ -863,6 +868,10 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
})
})
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
resp
.
StatusCode
,
ResponseBody
:
respBody
}
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
resp
.
StatusCode
,
ResponseBody
:
respBody
}
}
}
}
// ErrorPolicyNone → 原有逻辑
s
.
handleGeminiUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
)
if
s
.
shouldFailoverGeminiUpstreamError
(
resp
.
StatusCode
)
{
if
s
.
shouldFailoverGeminiUpstreamError
(
resp
.
StatusCode
)
{
upstreamReqID
:=
resp
.
Header
.
Get
(
requestIDHeader
)
upstreamReqID
:=
resp
.
Header
.
Get
(
requestIDHeader
)
if
upstreamReqID
==
""
{
if
upstreamReqID
==
""
{
...
@@ -1249,14 +1258,9 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
...
@@ -1249,14 +1258,9 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
if
resp
.
StatusCode
>=
400
{
if
resp
.
StatusCode
>=
400
{
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
respBody
,
_
:=
io
.
ReadAll
(
io
.
LimitReader
(
resp
.
Body
,
2
<<
20
))
tempMatched
:=
false
if
s
.
rateLimitService
!=
nil
{
tempMatched
=
s
.
rateLimitService
.
HandleTempUnschedulable
(
ctx
,
account
,
resp
.
StatusCode
,
respBody
)
}
s
.
handleGeminiUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
)
// Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens.
// Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens.
// This avoids Gemini SDKs failing hard during preflight token counting.
// This avoids Gemini SDKs failing hard during preflight token counting.
// Checked before error policy so it always works regardless of custom error codes.
if
action
==
"countTokens"
&&
isOAuth
&&
isGeminiInsufficientScope
(
resp
.
Header
,
respBody
)
{
if
action
==
"countTokens"
&&
isOAuth
&&
isGeminiInsufficientScope
(
resp
.
Header
,
respBody
)
{
estimated
:=
estimateGeminiCountTokens
(
body
)
estimated
:=
estimateGeminiCountTokens
(
body
)
c
.
JSON
(
http
.
StatusOK
,
map
[
string
]
any
{
"totalTokens"
:
estimated
})
c
.
JSON
(
http
.
StatusOK
,
map
[
string
]
any
{
"totalTokens"
:
estimated
})
...
@@ -1270,7 +1274,19 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
...
@@ -1270,7 +1274,19 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
},
nil
},
nil
}
}
if
tempMatched
{
// 统一错误策略:自定义错误码 + 临时不可调度
if
s
.
rateLimitService
!=
nil
{
switch
s
.
rateLimitService
.
CheckErrorPolicy
(
ctx
,
account
,
resp
.
StatusCode
,
respBody
)
{
case
ErrorPolicySkipped
:
respBody
=
unwrapIfNeeded
(
isOAuth
,
respBody
)
contentType
:=
resp
.
Header
.
Get
(
"Content-Type"
)
if
contentType
==
""
{
contentType
=
"application/json"
}
c
.
Data
(
resp
.
StatusCode
,
contentType
,
respBody
)
return
nil
,
fmt
.
Errorf
(
"gemini upstream error: %d (skipped by error policy)"
,
resp
.
StatusCode
)
case
ErrorPolicyMatched
,
ErrorPolicyTempUnscheduled
:
s
.
handleGeminiUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
)
evBody
:=
unwrapIfNeeded
(
isOAuth
,
respBody
)
evBody
:=
unwrapIfNeeded
(
isOAuth
,
respBody
)
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
evBody
))
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
evBody
))
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
upstreamMsg
=
sanitizeUpstreamErrorMessage
(
upstreamMsg
)
...
@@ -1294,6 +1310,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
...
@@ -1294,6 +1310,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
})
})
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
resp
.
StatusCode
,
ResponseBody
:
respBody
}
return
nil
,
&
UpstreamFailoverError
{
StatusCode
:
resp
.
StatusCode
,
ResponseBody
:
respBody
}
}
}
}
// ErrorPolicyNone → 原有逻辑
s
.
handleGeminiUpstreamError
(
ctx
,
account
,
resp
.
StatusCode
,
resp
.
Header
,
respBody
)
if
s
.
shouldFailoverGeminiUpstreamError
(
resp
.
StatusCode
)
{
if
s
.
shouldFailoverGeminiUpstreamError
(
resp
.
StatusCode
)
{
evBody
:=
unwrapIfNeeded
(
isOAuth
,
respBody
)
evBody
:=
unwrapIfNeeded
(
isOAuth
,
respBody
)
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
evBody
))
upstreamMsg
:=
strings
.
TrimSpace
(
extractUpstreamErrorMessage
(
evBody
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment