Commit 94895314 authored by alfadb's avatar alfadb
Browse files

fix(gateway): return 404 instead of fake 200 for unsupported count_tokens endpoint

PR #635 returned HTTP 200 with {"input_tokens": 0} when upstream doesn't
support count_tokens (404). This caused Claude Code CLI to trust the zero
value, believing context uses 0 tokens, so auto-compression never triggers.

Fix: return 404 with proper error body so CLI falls back to its local
tokenizer for accurate estimation. Return nil (not error) to avoid
polluting ops error metrics with expected 404s.

Affected paths:
- Passthrough APIKey accounts: upstream 404 now passed through as 404
- Antigravity accounts: same fix (was also returning fake 200)
parent 4ac57b4e
...@@ -262,44 +262,44 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo ...@@ -262,44 +262,44 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
require.Empty(t, rec.Header().Get("Set-Cookie")) require.Empty(t, rec.Header().Get("Set-Cookie"))
} }
func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokensFallbackOn404(t *testing.T) { func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokens404PassthroughNotError(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
tests := []struct { tests := []struct {
name string name string
statusCode int statusCode int
respBody string respBody string
wantFallback bool wantPassthrough bool
}{ }{
{ {
name: "404 endpoint not found triggers fallback", name: "404 endpoint not found passes through as 404",
statusCode: http.StatusNotFound, statusCode: http.StatusNotFound,
respBody: `{"error":{"message":"Not found: /v1/messages/count_tokens","type":"not_found_error"}}`, respBody: `{"error":{"message":"Not found: /v1/messages/count_tokens","type":"not_found_error"}}`,
wantFallback: true, wantPassthrough: true,
}, },
{ {
name: "404 generic not found triggers fallback", name: "404 generic not found passes through as 404",
statusCode: http.StatusNotFound, statusCode: http.StatusNotFound,
respBody: `{"error":{"message":"resource not found","type":"not_found_error"}}`, respBody: `{"error":{"message":"resource not found","type":"not_found_error"}}`,
wantFallback: true, wantPassthrough: true,
}, },
{ {
name: "400 Invalid URL does not fallback", name: "400 Invalid URL does not passthrough",
statusCode: http.StatusBadRequest, statusCode: http.StatusBadRequest,
respBody: `{"error":{"message":"Invalid URL (POST /v1/messages/count_tokens)","type":"invalid_request_error"}}`, respBody: `{"error":{"message":"Invalid URL (POST /v1/messages/count_tokens)","type":"invalid_request_error"}}`,
wantFallback: false, wantPassthrough: false,
}, },
{ {
name: "400 model error does not fallback", name: "400 model error does not passthrough",
statusCode: http.StatusBadRequest, statusCode: http.StatusBadRequest,
respBody: `{"error":{"message":"model not found: claude-unknown","type":"invalid_request_error"}}`, respBody: `{"error":{"message":"model not found: claude-unknown","type":"invalid_request_error"}}`,
wantFallback: false, wantPassthrough: false,
}, },
{ {
name: "500 internal error does not fallback", name: "500 internal error does not passthrough",
statusCode: http.StatusInternalServerError, statusCode: http.StatusInternalServerError,
respBody: `{"error":{"message":"internal error","type":"api_error"}}`, respBody: `{"error":{"message":"internal error","type":"api_error"}}`,
wantFallback: false, wantPassthrough: false,
}, },
} }
...@@ -345,10 +345,10 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokensFallbackOn404(t *t ...@@ -345,10 +345,10 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokensFallbackOn404(t *t
err := svc.ForwardCountTokens(context.Background(), c, account, parsed) err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
if tt.wantFallback { if tt.wantPassthrough {
// 404 透传:返回 nil(不记录为错误),但 HTTP 状态码是 404
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, http.StatusOK, rec.Code) require.Equal(t, http.StatusNotFound, rec.Code)
require.JSONEq(t, `{"input_tokens":0}`, rec.Body.String())
} else { } else {
require.Error(t, err) require.Error(t, err)
require.Equal(t, tt.statusCode, rec.Code) require.Equal(t, tt.statusCode, rec.Code)
......
...@@ -6015,9 +6015,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, ...@@ -6015,9 +6015,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts) body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
} }
// Antigravity 账户不支持 count_tokens 转发,直接返回空值 // Antigravity 账户不支持 count_tokens,返回 404 让客户端 fallback 到本地估算。
// 返回 nil 避免 handler 层记录为错误,也不设置 ops 上游错误上下文。
if account.Platform == PlatformAntigravity { if account.Platform == PlatformAntigravity {
c.JSON(http.StatusOK, gin.H{"input_tokens": 0}) s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported for this platform")
return nil return nil
} }
...@@ -6222,12 +6223,13 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex ...@@ -6222,12 +6223,13 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
// 中转站不支持 count_tokens 端点时(404),降级返回空值,客户端会 fallback 到本地估算。 // 中转站不支持 count_tokens 端点时(404),透传 404 让客户端 fallback 到本地估算。
// 返回 nil 避免 handler 层记录为错误,也不设置 ops 上游错误上下文。
if resp.StatusCode == http.StatusNotFound { if resp.StatusCode == http.StatusNotFound {
logger.LegacyPrintf("service.gateway", logger.LegacyPrintf("service.gateway",
"[count_tokens] Upstream does not support count_tokens (404), returning fallback: account=%d name=%s msg=%s", "[count_tokens] Upstream does not support count_tokens (404), passing through: account=%d name=%s msg=%s",
account.ID, account.Name, truncateString(upstreamMsg, 512)) account.ID, account.Name, truncateString(upstreamMsg, 512))
c.JSON(http.StatusOK, gin.H{"input_tokens": 0}) s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported by upstream")
return nil return nil
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment