Commit 2c667a15 authored by Ethan0x0000's avatar Ethan0x0000
Browse files

fix(provider): retain upstream model for gemini compat and ws

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent

)
Co-authored-by: default avatarSisyphus <clio-agent@sisyphuslabs.ai>
parent bac40804
...@@ -1028,14 +1028,15 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex ...@@ -1028,14 +1028,15 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
} }
return &ForwardResult{ return &ForwardResult{
RequestID: requestID, RequestID: requestID,
Usage: *usage, Usage: *usage,
Model: originalModel, Model: originalModel,
Stream: req.Stream, UpstreamModel: mappedModel,
Duration: time.Since(startTime), Stream: req.Stream,
FirstTokenMs: firstTokenMs, Duration: time.Since(startTime),
ImageCount: imageCount, FirstTokenMs: firstTokenMs,
ImageSize: imageSize, ImageCount: imageCount,
ImageSize: imageSize,
}, nil }, nil
} }
...@@ -1241,12 +1242,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. ...@@ -1241,12 +1242,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
estimated := estimateGeminiCountTokens(body) estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated}) c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
return &ForwardResult{ return &ForwardResult{
RequestID: "", RequestID: "",
Usage: ClaudeUsage{}, Usage: ClaudeUsage{},
Model: originalModel, Model: originalModel,
Stream: false, UpstreamModel: mappedModel,
Duration: time.Since(startTime), Stream: false,
FirstTokenMs: nil, Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil }, nil
} }
setOpsUpstreamError(c, 0, safeErr, "") setOpsUpstreamError(c, 0, safeErr, "")
...@@ -1310,12 +1312,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. ...@@ -1310,12 +1312,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
estimated := estimateGeminiCountTokens(body) estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated}) c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
return &ForwardResult{ return &ForwardResult{
RequestID: "", RequestID: "",
Usage: ClaudeUsage{}, Usage: ClaudeUsage{},
Model: originalModel, Model: originalModel,
Stream: false, UpstreamModel: mappedModel,
Duration: time.Since(startTime), Stream: false,
FirstTokenMs: nil, Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil }, nil
} }
// Final attempt: surface the upstream error body (passed through below) instead of a generic retry error. // Final attempt: surface the upstream error body (passed through below) instead of a generic retry error.
...@@ -1350,12 +1353,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. ...@@ -1350,12 +1353,13 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
estimated := estimateGeminiCountTokens(body) estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated}) c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
return &ForwardResult{ return &ForwardResult{
RequestID: requestID, RequestID: requestID,
Usage: ClaudeUsage{}, Usage: ClaudeUsage{},
Model: originalModel, Model: originalModel,
Stream: false, UpstreamModel: mappedModel,
Duration: time.Since(startTime), Stream: false,
FirstTokenMs: nil, Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil }, nil
} }
...@@ -1527,14 +1531,15 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. ...@@ -1527,14 +1531,15 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
} }
return &ForwardResult{ return &ForwardResult{
RequestID: requestID, RequestID: requestID,
Usage: *usage, Usage: *usage,
Model: originalModel, Model: originalModel,
Stream: stream, UpstreamModel: mappedModel,
Duration: time.Since(startTime), Stream: stream,
FirstTokenMs: firstTokenMs, Duration: time.Since(startTime),
ImageCount: imageCount, FirstTokenMs: firstTokenMs,
ImageSize: imageSize, ImageCount: imageCount,
ImageSize: imageSize,
}, nil }, nil
} }
......
package service package service
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
...@@ -15,6 +16,30 @@ import ( ...@@ -15,6 +16,30 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
type geminiCompatHTTPUpstreamStub struct {
response *http.Response
err error
calls int
lastReq *http.Request
}
func (s *geminiCompatHTTPUpstreamStub) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
s.calls++
s.lastReq = req
if s.err != nil {
return nil, s.err
}
if s.response == nil {
return nil, fmt.Errorf("missing stub response")
}
resp := *s.response
return &resp, nil
}
func (s *geminiCompatHTTPUpstreamStub) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
return s.Do(req, proxyURL, accountID, accountConcurrency)
}
// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换 // TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换
func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) { func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
tests := []struct { tests := []struct {
...@@ -170,6 +195,42 @@ func TestGeminiHandleNativeNonStreamingResponse_DebugDisabledDoesNotEmitHeaderLo ...@@ -170,6 +195,42 @@ func TestGeminiHandleNativeNonStreamingResponse_DebugDisabledDoesNotEmitHeaderLo
require.False(t, logSink.ContainsMessage("[GeminiAPI]"), "debug 关闭时不应输出 Gemini 响应头日志") require.False(t, logSink.ContainsMessage("[GeminiAPI]"), "debug 关闭时不应输出 Gemini 响应头日志")
} }
func TestGeminiMessagesCompatServiceForward_PreservesRequestedModelAndMappedUpstreamModel(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
httpStub := &geminiCompatHTTPUpstreamStub{
response: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"x-request-id": []string{"gemini-req-1"}},
Body: io.NopCloser(strings.NewReader(`{"candidates":[{"content":{"parts":[{"text":"hello"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5}}`)),
},
}
svc := &GeminiMessagesCompatService{httpUpstream: httpStub, cfg: &config.Config{}}
account := &Account{
ID: 1,
Type: AccountTypeAPIKey,
Credentials: map[string]any{
"api_key": "test-key",
"model_mapping": map[string]any{
"claude-sonnet-4": "claude-sonnet-4-20250514",
},
},
}
body := []byte(`{"model":"claude-sonnet-4","max_tokens":16,"messages":[{"role":"user","content":"hello"}]}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "claude-sonnet-4", result.Model)
require.Equal(t, "claude-sonnet-4-20250514", result.UpstreamModel)
require.Equal(t, 1, httpStub.calls)
require.NotNil(t, httpStub.lastReq)
require.Contains(t, httpStub.lastReq.URL.String(), "/models/claude-sonnet-4-20250514:")
}
func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) { func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) {
claudeReq := map[string]any{ claudeReq := map[string]any{
"model": "claude-haiku-4-5-20251001", "model": "claude-haiku-4-5-20251001",
......
...@@ -2328,6 +2328,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( ...@@ -2328,6 +2328,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
RequestID: responseID, RequestID: responseID,
Usage: *usage, Usage: *usage,
Model: originalModel, Model: originalModel,
UpstreamModel: mappedModel,
ServiceTier: extractOpenAIServiceTier(reqBody), ServiceTier: extractOpenAIServiceTier(reqBody),
ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel), ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel),
Stream: reqStream, Stream: reqStream,
...@@ -2945,6 +2946,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( ...@@ -2945,6 +2946,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
RequestID: responseID, RequestID: responseID,
Usage: usage, Usage: usage,
Model: originalModel, Model: originalModel,
UpstreamModel: mappedModel,
ServiceTier: extractOpenAIServiceTierFromBody(payload), ServiceTier: extractOpenAIServiceTierFromBody(payload),
ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel), ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel),
Stream: reqStream, Stream: reqStream,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment