Commit bac40804 authored by Ethan0x0000's avatar Ethan0x0000
Browse files

fix(provider): preserve requested model in antigravity and sora

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent

)
Co-authored-by: default avatarSisyphus <clio-agent@sisyphuslabs.ai>
parent 4edcfe1f
...@@ -1742,7 +1742,8 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, ...@@ -1742,7 +1742,8 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return &ForwardResult{ return &ForwardResult{
RequestID: requestID, RequestID: requestID,
Usage: *usage, Usage: *usage,
Model: billingModel, // 使用映射模型用于计费和日志 Model: originalModel,
UpstreamModel: billingModel,
Stream: claudeReq.Stream, Stream: claudeReq.Stream,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
...@@ -2435,7 +2436,8 @@ handleSuccess: ...@@ -2435,7 +2436,8 @@ handleSuccess:
return &ForwardResult{ return &ForwardResult{
RequestID: requestID, RequestID: requestID,
Usage: *usage, Usage: *usage,
Model: billingModel, Model: originalModel,
UpstreamModel: billingModel,
Stream: stream, Stream: stream,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
......
...@@ -542,7 +542,8 @@ func TestAntigravityGatewayService_Forward_BillsWithMappedModel(t *testing.T) { ...@@ -542,7 +542,8 @@ func TestAntigravityGatewayService_Forward_BillsWithMappedModel(t *testing.T) {
result, err := svc.Forward(context.Background(), c, account, body, false) result, err := svc.Forward(context.Background(), c, account, body, false)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.Equal(t, mappedModel, result.Model) require.Equal(t, "claude-sonnet-4-5", result.Model)
require.Equal(t, mappedModel, result.UpstreamModel)
} }
// TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel // TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel
...@@ -594,7 +595,8 @@ func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing ...@@ -594,7 +595,8 @@ func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing
result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", true, body, false) result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", true, body, false)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.Equal(t, mappedModel, result.Model) require.Equal(t, "gemini-2.5-flash", result.Model)
require.Equal(t, mappedModel, result.UpstreamModel)
} }
func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) { func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignature(t *testing.T) {
...@@ -664,7 +666,8 @@ func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignatur ...@@ -664,7 +666,8 @@ func TestAntigravityGatewayService_ForwardGemini_RetriesCorruptedThoughtSignatur
result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, false) result, err := svc.ForwardGemini(context.Background(), c, account, originalModel, "streamGenerateContent", true, body, false)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, result) require.NotNil(t, result)
require.Equal(t, mappedModel, result.Model) require.Equal(t, originalModel, result.Model)
require.Equal(t, mappedModel, result.UpstreamModel)
require.Len(t, upstream.requestBodies, 2, "signature error should trigger exactly one retry") require.Len(t, upstream.requestBodies, 2, "signature error should trigger exactly one retry")
firstReq := string(upstream.requestBodies[0]) firstReq := string(upstream.requestBodies[0])
......
...@@ -148,10 +148,13 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun ...@@ -148,10 +148,13 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "model is required", clientStream) s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "model is required", clientStream)
return nil, errors.New("model is required") return nil, errors.New("model is required")
} }
originalModel := reqModel
mappedModel := account.GetMappedModel(reqModel) mappedModel := account.GetMappedModel(reqModel)
var upstreamModel string
if mappedModel != "" && mappedModel != reqModel { if mappedModel != "" && mappedModel != reqModel {
reqModel = mappedModel reqModel = mappedModel
upstreamModel = mappedModel
} }
modelCfg, ok := GetSoraModelConfig(reqModel) modelCfg, ok := GetSoraModelConfig(reqModel)
...@@ -214,7 +217,8 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun ...@@ -214,7 +217,8 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
} }
return &ForwardResult{ return &ForwardResult{
RequestID: "", RequestID: "",
Model: reqModel, Model: originalModel,
UpstreamModel: upstreamModel,
Stream: clientStream, Stream: clientStream,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
...@@ -270,7 +274,8 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun ...@@ -270,7 +274,8 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
} }
return &ForwardResult{ return &ForwardResult{
RequestID: "", RequestID: "",
Model: reqModel, Model: originalModel,
UpstreamModel: upstreamModel,
Stream: clientStream, Stream: clientStream,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
...@@ -420,7 +425,8 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun ...@@ -420,7 +425,8 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
return &ForwardResult{ return &ForwardResult{
RequestID: taskID, RequestID: taskID,
Model: reqModel, Model: originalModel,
UpstreamModel: upstreamModel,
Stream: clientStream, Stream: clientStream,
Duration: time.Since(startTime), Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs, FirstTokenMs: firstTokenMs,
......
...@@ -144,6 +144,11 @@ func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) { ...@@ -144,6 +144,11 @@ func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) {
ID: 1, ID: 1,
Platform: PlatformSora, Platform: PlatformSora,
Status: StatusActive, Status: StatusActive,
Credentials: map[string]any{
"model_mapping": map[string]any{
"prompt-enhance-short-10s": "prompt-enhance-short-15s",
},
},
} }
body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`) body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`)
...@@ -152,6 +157,7 @@ func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) { ...@@ -152,6 +157,7 @@ func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) {
require.NotNil(t, result) require.NotNil(t, result)
require.Equal(t, "prompt", result.MediaType) require.Equal(t, "prompt", result.MediaType)
require.Equal(t, "prompt-enhance-short-10s", result.Model) require.Equal(t, "prompt-enhance-short-10s", result.Model)
require.Equal(t, "prompt-enhance-short-15s", result.UpstreamModel)
} }
func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) { func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment