Commit 45456fa2 authored by erio's avatar erio
Browse files

fix: restore OAuth 401 temp-unschedulable for Gemini, update Antigravity tests

The 403 detection PR changed the 401 handler condition from
`account.Type == AccountTypeOAuth` to
`account.Type == AccountTypeOAuth && account.Platform == PlatformOpenAI`,
which accidentally excluded Gemini OAuth from the temp-unschedulable path.

Fix: use `!= PlatformAntigravity` instead, preserving Gemini behavior
while correctly excluding Antigravity (whose 401 is handled by
applyErrorPolicy's temp_unschedulable_rules).

Update tests to reflect Antigravity's new 401 semantics:
- HandleUpstreamError: Antigravity OAuth 401 now uses SetError
- CheckErrorPolicy: Antigravity 401 second hit stays TempUnscheduled
- DB fallback: split into Gemini (escalates) and Antigravity (stays temp)
parent 6344fa2a
...@@ -110,7 +110,9 @@ func TestCheckErrorPolicy(t *testing.T) { ...@@ -110,7 +110,9 @@ func TestCheckErrorPolicy(t *testing.T) {
expected: ErrorPolicyTempUnscheduled, expected: ErrorPolicyTempUnscheduled,
}, },
{ {
name: "temp_unschedulable_401_second_hit_upgrades_to_none", // Antigravity 401 不走升级逻辑(由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制),
// second hit 仍然返回 TempUnscheduled。
name: "temp_unschedulable_401_second_hit_antigravity_stays_temp",
account: &Account{ account: &Account{
ID: 15, ID: 15,
Type: AccountTypeOAuth, Type: AccountTypeOAuth,
...@@ -129,7 +131,7 @@ func TestCheckErrorPolicy(t *testing.T) { ...@@ -129,7 +131,7 @@ func TestCheckErrorPolicy(t *testing.T) {
}, },
statusCode: 401, statusCode: 401,
body: []byte(`unauthorized`), body: []byte(`unauthorized`),
expected: ErrorPolicyNone, expected: ErrorPolicyTempUnscheduled,
}, },
{ {
name: "temp_unschedulable_body_miss_returns_none", name: "temp_unschedulable_body_miss_returns_none",
......
...@@ -149,9 +149,9 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc ...@@ -149,9 +149,9 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
} }
// 其他 400 错误(如参数问题)不处理,不禁用账号 // 其他 400 错误(如参数问题)不处理,不禁用账号
case 401: case 401:
// OpenAI OAuth 账号在 401 错误时临时不可调度;其他平台 OAuth 账号保持原有 SetError 行为 // OAuth 账号在 401 错误时临时不可调度(给 token 刷新窗口);非 OAuth 账号保持原有 SetError 行为
// Antigravity 主流程不走此路径,其 401 由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制 // Antigravity 除外:其 401 由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制
if account.Type == AccountTypeOAuth && account.Platform == PlatformOpenAI { if account.Type == AccountTypeOAuth && account.Platform != PlatformAntigravity {
// 1. 失效缓存 // 1. 失效缓存
if s.tokenCacheInvalidator != nil { if s.tokenCacheInvalidator != nil {
if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil { if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
...@@ -183,7 +183,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc ...@@ -183,7 +183,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
} }
shouldDisable = true shouldDisable = true
} else { } else {
// 非 OAuth 账号(APIKey):保持原有 SetError 行为 // 非 OAuth / Antigravity OAuth:保持 SetError 行为
msg := "Authentication failed (401): invalid or expired credentials" msg := "Authentication failed (401): invalid or expired credentials"
if upstreamMsg != "" { if upstreamMsg != "" {
msg = "Authentication failed (401): " + upstreamMsg msg = "Authentication failed (401): " + upstreamMsg
......
...@@ -27,34 +27,68 @@ func (r *dbFallbackRepoStub) GetByID(ctx context.Context, id int64) (*Account, e ...@@ -27,34 +27,68 @@ func (r *dbFallbackRepoStub) GetByID(ctx context.Context, id int64) (*Account, e
func TestCheckErrorPolicy_401_DBFallback_Escalates(t *testing.T) { func TestCheckErrorPolicy_401_DBFallback_Escalates(t *testing.T) {
// Scenario: cache account has empty TempUnschedulableReason (cache miss), // Scenario: cache account has empty TempUnschedulableReason (cache miss),
// but DB account has a previous 401 record → should escalate to ErrorPolicyNone. // but DB account has a previous 401 record.
repo := &dbFallbackRepoStub{ // Non-Antigravity: should escalate to ErrorPolicyNone (second 401 = permanent error).
dbAccount: &Account{ // Antigravity: skips escalation logic (401 handled by applyErrorPolicy rules).
t.Run("gemini_escalates", func(t *testing.T) {
repo := &dbFallbackRepoStub{
dbAccount: &Account{
ID: 20,
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
},
}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 20, ID: 20,
TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`, Type: AccountTypeOAuth,
}, Platform: PlatformGemini,
} TempUnschedulableReason: "",
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(401),
"keywords": []any{"unauthorized"},
"duration_minutes": float64(10),
},
},
},
}
account := &Account{ result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
ID: 20, require.Equal(t, ErrorPolicyNone, result, "gemini 401 with DB fallback showing previous 401 should escalate")
Type: AccountTypeOAuth, })
Platform: PlatformAntigravity,
TempUnschedulableReason: "", // cache miss — reason is empty t.Run("antigravity_stays_temp", func(t *testing.T) {
Credentials: map[string]any{ repo := &dbFallbackRepoStub{
"temp_unschedulable_enabled": true, dbAccount: &Account{
"temp_unschedulable_rules": []any{ ID: 20,
map[string]any{ TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`,
"error_code": float64(401), },
"keywords": []any{"unauthorized"}, }
"duration_minutes": float64(10), svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 20,
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
TempUnschedulableReason: "",
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": float64(401),
"keywords": []any{"unauthorized"},
"duration_minutes": float64(10),
},
}, },
}, },
}, }
}
result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`)) result := svc.CheckErrorPolicy(context.Background(), account, http.StatusUnauthorized, []byte(`unauthorized`))
require.Equal(t, ErrorPolicyNone, result, "401 with DB fallback showing previous 401 should escalate to ErrorPolicyNone") require.Equal(t, ErrorPolicyTempUnscheduled, result, "antigravity 401 skips escalation, stays temp-unscheduled")
})
} }
func TestCheckErrorPolicy_401_DBFallback_NoDBRecord_FirstHit(t *testing.T) { func TestCheckErrorPolicy_401_DBFallback_NoDBRecord_FirstHit(t *testing.T) {
......
...@@ -42,45 +42,56 @@ func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, acc ...@@ -42,45 +42,56 @@ func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, acc
} }
func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *testing.T) { func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *testing.T) {
tests := []struct { t.Run("gemini", func(t *testing.T) {
name string repo := &rateLimitAccountRepoStub{}
platform string invalidator := &tokenCacheInvalidatorRecorder{}
}{ service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
{name: "gemini", platform: PlatformGemini}, service.SetTokenCacheInvalidator(invalidator)
{name: "antigravity", platform: PlatformAntigravity}, account := &Account{
} ID: 100,
Platform: PlatformGemini,
for _, tt := range tests { Type: AccountTypeOAuth,
t.Run(tt.name, func(t *testing.T) { Credentials: map[string]any{
repo := &rateLimitAccountRepoStub{} "temp_unschedulable_enabled": true,
invalidator := &tokenCacheInvalidatorRecorder{} "temp_unschedulable_rules": []any{
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) map[string]any{
service.SetTokenCacheInvalidator(invalidator) "error_code": 401,
account := &Account{ "keywords": []any{"unauthorized"},
ID: 100, "duration_minutes": 30,
Platform: tt.platform, "description": "custom rule",
Type: AccountTypeOAuth,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": 401,
"keywords": []any{"unauthorized"},
"duration_minutes": 30,
"description": "custom rule",
},
}, },
}, },
} },
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 0, repo.setErrorCalls) require.True(t, shouldDisable)
require.Equal(t, 1, repo.tempCalls) require.Equal(t, 0, repo.setErrorCalls)
require.Len(t, invalidator.accounts, 1) require.Equal(t, 1, repo.tempCalls)
}) require.Len(t, invalidator.accounts, 1)
} })
t.Run("antigravity_401_uses_SetError", func(t *testing.T) {
// Antigravity 401 由 applyErrorPolicy 的 temp_unschedulable_rules 控制,
// HandleUpstreamError 中走 SetError 路径。
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 100,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.setErrorCalls)
require.Equal(t, 0, repo.tempCalls)
require.Empty(t, invalidator.accounts)
})
} }
func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) { func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment