Commit 615557ec authored by gaoren002's avatar gaoren002
Browse files

fix(openai): avoid implicit image sticky sessions

parent c056db74
...@@ -117,12 +117,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) { ...@@ -117,12 +117,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
return return
} }
sessionHash := "" sessionHash := h.gatewayService.GenerateExplicitSessionHash(c, body)
if parsed.Multipart {
sessionHash = h.gatewayService.GenerateSessionHashWithFallback(c, nil, parsed.StickySessionSeed())
} else {
sessionHash = h.gatewayService.GenerateSessionHash(c, body)
}
maxAccountSwitches := h.maxAccountSwitches maxAccountSwitches := h.maxAccountSwitches
switchCount := 0 switchCount := 0
......
...@@ -1125,6 +1125,35 @@ func (s *OpenAIGatewayService) ExtractSessionID(c *gin.Context, body []byte) str ...@@ -1125,6 +1125,35 @@ func (s *OpenAIGatewayService) ExtractSessionID(c *gin.Context, body []byte) str
return sessionID return sessionID
} }
func explicitOpenAISessionID(c *gin.Context, body []byte) string {
if c == nil {
return ""
}
sessionID := strings.TrimSpace(c.GetHeader("session_id"))
if sessionID == "" {
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
}
if sessionID == "" && len(body) > 0 {
sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
}
return sessionID
}
// GenerateExplicitSessionHash generates a sticky-session hash only from explicit
// client session signals. It intentionally skips content-derived fallback and is
// used by stateless endpoints such as /v1/images.
func (s *OpenAIGatewayService) GenerateExplicitSessionHash(c *gin.Context, body []byte) string {
sessionID := explicitOpenAISessionID(c, body)
if sessionID == "" {
return ""
}
currentHash, legacyHash := deriveOpenAISessionHashes(sessionID)
attachOpenAILegacySessionHashToGin(c, legacyHash)
return currentHash
}
// GenerateSessionHash generates a sticky-session hash for OpenAI requests. // GenerateSessionHash generates a sticky-session hash for OpenAI requests.
// //
// Priority: // Priority:
...@@ -1137,13 +1166,7 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte) ...@@ -1137,13 +1166,7 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte)
return "" return ""
} }
sessionID := strings.TrimSpace(c.GetHeader("session_id")) sessionID := explicitOpenAISessionID(c, body)
if sessionID == "" {
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
}
if sessionID == "" && len(body) > 0 {
sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
}
if sessionID == "" && len(body) > 0 { if sessionID == "" && len(body) > 0 {
sessionID = deriveOpenAIContentSessionSeed(body) sessionID = deriveOpenAIContentSessionSeed(body)
} }
......
...@@ -227,6 +227,41 @@ func TestOpenAIGatewayService_GenerateSessionHash_AttachesLegacyHashToContext(t ...@@ -227,6 +227,41 @@ func TestOpenAIGatewayService_GenerateSessionHash_AttachesLegacyHashToContext(t
require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context())) require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
} }
func TestOpenAIGatewayService_GenerateExplicitSessionHash_SkipsContentFallback(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := &OpenAIGatewayService{}
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat"}`)
t.Run("stateless image body stays unstuck", func(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
require.Empty(t, svc.GenerateExplicitSessionHash(c, body))
require.Empty(t, openAILegacySessionHashFromContext(c.Request.Context()))
})
t.Run("prompt_cache_key is explicit", func(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
got := svc.GenerateExplicitSessionHash(c, []byte(`{"model":"gpt-image-2","prompt_cache_key":"image-session"}`))
require.Equal(t, fmt.Sprintf("%016x", xxhash.Sum64String("image-session")), got)
require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context()))
})
t.Run("header overrides body", func(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
c.Request.Header.Set("session_id", "header-session")
got := svc.GenerateExplicitSessionHash(c, []byte(`{"prompt_cache_key":"body-session"}`))
require.Equal(t, fmt.Sprintf("%016x", xxhash.Sum64String("header-session")), got)
})
}
func TestOpenAIGatewayService_GenerateSessionHashWithFallback(t *testing.T) { func TestOpenAIGatewayService_GenerateSessionHashWithFallback(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment