"git@web.lueluesay.top:chenxi/sub2api.git" did not exist on "c95a8649759efab218beb7b9acefb52e2674ffc3"
Commit 45456fa2 authored by erio's avatar erio
Browse files

fix: restore OAuth 401 temp-unschedulable for Gemini, update Antigravity tests

The 403 detection PR changed the 401 handler condition from
`account.Type == AccountTypeOAuth` to
`account.Type == AccountTypeOAuth && account.Platform == PlatformOpenAI`,
which accidentally excluded Gemini OAuth from the temp-unschedulable path.

Fix: use `!= PlatformAntigravity` instead, preserving Gemini behavior
while correctly excluding Antigravity (whose 401 is handled by
applyErrorPolicy's temp_unschedulable_rules).

Update tests to reflect Antigravity's new 401 semantics:
- HandleUpstreamError: Antigravity OAuth 401 now uses SetError
- CheckErrorPolicy: Antigravity 401 second hit stays TempUnscheduled
- DB fallback: split into Gemini (escalates) and Antigravity (stays temp)
parent 6344fa2a
...@@ -110,7 +110,9 @@ func TestCheckErrorPolicy(t *testing.T) { ...@@ -110,7 +110,9 @@ func TestCheckErrorPolicy(t *testing.T) {
expected: ErrorPolicyTempUnscheduled, expected: ErrorPolicyTempUnscheduled,
}, },
{ {
name: "temp_unschedulable_401_second_hit_upgrades_to_none", // Antigravity 401 不走升级逻辑(由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制),
// second hit 仍然返回 TempUnscheduled。
name: "temp_unschedulable_401_second_hit_antigravity_stays_temp",
account: &Account{ account: &Account{
ID: 15, ID: 15,
Type: AccountTypeOAuth, Type: AccountTypeOAuth,
...@@ -129,7 +131,7 @@ func TestCheckErrorPolicy(t *testing.T) { ...@@ -129,7 +131,7 @@ func TestCheckErrorPolicy(t *testing.T) {
}, },
statusCode: 401, statusCode: 401,
body: []byte(`unauthorized`), body: []byte(`unauthorized`),
expected: ErrorPolicyNone, expected: ErrorPolicyTempUnscheduled,
}, },
{ {
name: "temp_unschedulable_body_miss_returns_none", name: "temp_unschedulable_body_miss_returns_none",
......
...@@ -149,9 +149,9 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc ...@@ -149,9 +149,9 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
} }
// 其他 400 错误(如参数问题)不处理,不禁用账号 // 其他 400 错误(如参数问题)不处理,不禁用账号
case 401: case 401:
// OpenAI OAuth 账号在 401 错误时临时不可调度;其他平台 OAuth 账号保持原有 SetError 行为 // OAuth 账号在 401 错误时临时不可调度(给 token 刷新窗口);非 OAuth 账号保持原有 SetError 行为
// Antigravity 主流程不走此路径,其 401 由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制 // Antigravity 除外:其 401 由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制
if account.Type == AccountTypeOAuth && account.Platform == PlatformOpenAI { if account.Type == AccountTypeOAuth && account.Platform != PlatformAntigravity {
// 1. 失效缓存 // 1. 失效缓存
if s.tokenCacheInvalidator != nil { if s.tokenCacheInvalidator != nil {
if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil { if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
...@@ -183,7 +183,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc ...@@ -183,7 +183,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
} }
shouldDisable = true shouldDisable = true
} else { } else {
// 非 OAuth 账号(APIKey):保持原有 SetError 行为 // 非 OAuth / Antigravity OAuth:保持 SetError 行为
msg := "Authentication failed (401): invalid or expired credentials" msg := "Authentication failed (401): invalid or expired credentials"
if upstreamMsg != "" { if upstreamMsg != "" {
msg = "Authentication failed (401): " + upstreamMsg msg = "Authentication failed (401): " + upstreamMsg
......
...@@ -27,7 +27,40 @@ func (r *dbFallbackRepoStub) GetByID(ctx context.Context, id int64) (*Account, e ...@@ -27,7 +27,40 @@ func (r *dbFallbackRepoStub) GetByID(ctx context.Context, id int64) (*Account, e
func TestCheckErrorPolicy_401_DBFallback_Escalates(t *testing.T) { func TestCheckErrorPolicy_401_DBFallback_Escalates(t *testing.T) {
// Scenario: cache account has empty TempUnschedulableReason (cache miss), // Scenario: cache account has empty TempUnschedulableReason (cache miss),
// but DB account has a previous 401 record → should escalate to ErrorPolicyNone. // but DB account has a previous 401 record.
// Non-Antigravity: should escalate to ErrorPolicyNone (second 401 = permanent error).
// Antigravity: skips escalation logic (401 handled by applyErrorPolicy rules).
t.Run("gemini_escalates", func(t *testing.T) {
repo := &dbFallbackRepoStub{
dbAccount: &Account{
ID: 20,
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
},
}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 20,
Type: AccountTypeOAuth,
Platform: PlatformGemini,
TempUnschedulableReason: "",
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(401),
"keywords": []any{"unauthorized"},
"duration_minutes": float64(10),
},
},
},
}
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
require.Equal(t, ErrorPolicyNone, result, "gemini 401 with DB fallback showing previous 401 should escalate")
})
t.Run("antigravity_stays_temp", func(t *testing.T) {
repo := &dbFallbackRepoStub{ repo := &dbFallbackRepoStub{
dbAccount: &Account{ dbAccount: &Account{
ID: 20, ID: 20,
...@@ -40,7 +73,7 @@ func TestCheckErrorPolicy_401_DBFallback_Escalates(t *testing.T) { ...@@ -40,7 +73,7 @@ func TestCheckErrorPolicy_401_DBFallback_Escalates(t *testing.T) {
ID: 20, ID: 20,
Type: AccountTypeOAuth, Type: AccountTypeOAuth,
Platform: PlatformAntigravity, Platform: PlatformAntigravity,
TempUnschedulableReason: "", // cache miss — reason is empty TempUnschedulableReason: "",
Credentials: map[string]any{ Credentials: map[string]any{
"temp_unschedulable_enabled": true, "temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{ "temp_unschedulable_rules": []any{
...@@ -54,7 +87,8 @@ func TestCheckErrorPolicy_401_DBFallback_Escalates(t *testing.T) { ...@@ -54,7 +87,8 @@ func TestCheckErrorPolicy_401_DBFallback_Escalates(t *testing.T) {
} }
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`)) result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
require.Equal(t, ErrorPolicyNone, result, "401 with DB fallback showing previous 401 should escalate to ErrorPolicyNone") require.Equal(t, ErrorPolicyTempUnscheduled, result, "antigravity 401 skips escalation, stays temp-unscheduled")
})
} }
func TestCheckErrorPolicy_401_DBFallback_NoDBRecord_FirstHit(t *testing.T) { func TestCheckErrorPolicy_401_DBFallback_NoDBRecord_FirstHit(t *testing.T) {
......
...@@ -42,23 +42,14 @@ func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, acc ...@@ -42,23 +42,14 @@ func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, acc
} }
func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *testing.T) { func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *testing.T) {
tests := []struct { t.Run("gemini", func(t *testing.T) {
name string
platform string
}{
{name: "gemini", platform: PlatformGemini},
{name: "antigravity", platform: PlatformAntigravity},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &rateLimitAccountRepoStub{} repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{} invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator) service.SetTokenCacheInvalidator(invalidator)
account := &Account{ account := &Account{
ID: 100, ID: 100,
Platform: tt.platform, Platform: PlatformGemini,
Type: AccountTypeOAuth, Type: AccountTypeOAuth,
Credentials: map[string]any{ Credentials: map[string]any{
"temp_unschedulable_enabled": true, "temp_unschedulable_enabled": true,
...@@ -80,7 +71,27 @@ func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *t ...@@ -80,7 +71,27 @@ func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *t
require.Equal(t, 1, repo.tempCalls) require.Equal(t, 1, repo.tempCalls)
require.Len(t, invalidator.accounts, 1) require.Len(t, invalidator.accounts, 1)
}) })
t.Run("antigravity_401_uses_SetError", func(t *testing.T) {
// Antigravity 401 由 applyErrorPolicy 的 temp_unschedulable_rules 控制,
// HandleUpstreamError 中走 SetError 路径。
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 100,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
} }
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.setErrorCalls)
require.Equal(t, 0, repo.tempCalls)
require.Empty(t, invalidator.accounts)
})
} }
func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) { func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment