Unverified Commit f972a2fa authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #1990 from haha1903/feat/zstd-request-decompression

feat(httputil): decode zstd/gzip/deflate request bodies
parents 55a7fa1e 798fd673
......@@ -2,8 +2,15 @@ package httputil
import (
"bytes"
"compress/gzip"
"compress/zlib"
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/klauspost/compress/zstd"
)
const (
......@@ -11,7 +18,9 @@ const (
requestBodyReadMaxInitCap = 1 << 20
)
// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based on content length.
// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based
// on content length, transparently decoding any Content-Encoding the upstream
// client used to compress the body (zstd, gzip, deflate).
func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) {
if req == nil || req.Body == nil {
return nil, nil
......@@ -33,5 +42,49 @@ func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) {
if _, err := io.Copy(buf, req.Body); err != nil {
return nil, err
}
return buf.Bytes(), nil
raw := buf.Bytes()
enc := strings.ToLower(strings.TrimSpace(req.Header.Get("Content-Encoding")))
if enc == "" || enc == "identity" {
return raw, nil
}
decoded, err := decompressRequestBody(enc, raw)
if err != nil {
return nil, fmt.Errorf("decode Content-Encoding %q: %w", enc, err)
}
req.Header.Del("Content-Encoding")
req.Header.Del("Content-Length")
req.ContentLength = int64(len(decoded))
return decoded, nil
}
func decompressRequestBody(encoding string, raw []byte) ([]byte, error) {
switch encoding {
case "zstd":
dec, err := zstd.NewReader(bytes.NewReader(raw))
if err != nil {
return nil, err
}
defer dec.Close()
return io.ReadAll(dec)
case "gzip", "x-gzip":
gr, err := gzip.NewReader(bytes.NewReader(raw))
if err != nil {
return nil, err
}
defer gr.Close()
return io.ReadAll(gr)
case "deflate":
zr, err := zlib.NewReader(bytes.NewReader(raw))
if err != nil {
return nil, err
}
defer zr.Close()
return io.ReadAll(zr)
default:
return nil, errors.New("unsupported Content-Encoding")
}
}
package httputil
import (
"bytes"
"compress/gzip"
"compress/zlib"
"net/http"
"strings"
"testing"
"github.com/klauspost/compress/zstd"
)
const samplePayload = `{"model":"gpt-5.5","input":"hi","stream":false}`
func newRequestWithBody(t *testing.T, body []byte, encoding string) *http.Request {
t.Helper()
req, err := http.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body))
if err != nil {
t.Fatalf("NewRequest: %v", err)
}
if encoding != "" {
req.Header.Set("Content-Encoding", encoding)
}
req.ContentLength = int64(len(body))
return req
}
func TestReadRequestBodyWithPrealloc_PassesThroughIdentity(t *testing.T) {
req := newRequestWithBody(t, []byte(samplePayload), "")
got, err := ReadRequestBodyWithPrealloc(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(got) != samplePayload {
t.Fatalf("body mismatch: got %q", got)
}
}
func TestReadRequestBodyWithPrealloc_DecodesZstd(t *testing.T) {
enc, _ := zstd.NewWriter(nil)
compressed := enc.EncodeAll([]byte(samplePayload), nil)
_ = enc.Close()
req := newRequestWithBody(t, compressed, "zstd")
got, err := ReadRequestBodyWithPrealloc(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(got) != samplePayload {
t.Fatalf("body mismatch: got %q", got)
}
if req.Header.Get("Content-Encoding") != "" {
t.Fatalf("Content-Encoding should be cleared after decoding")
}
if req.ContentLength != int64(len(samplePayload)) {
t.Fatalf("ContentLength not updated: %d", req.ContentLength)
}
}
func TestReadRequestBodyWithPrealloc_DecodesGzip(t *testing.T) {
var buf bytes.Buffer
gw := gzip.NewWriter(&buf)
if _, err := gw.Write([]byte(samplePayload)); err != nil {
t.Fatalf("gzip write: %v", err)
}
if err := gw.Close(); err != nil {
t.Fatalf("gzip close: %v", err)
}
req := newRequestWithBody(t, buf.Bytes(), "gzip")
got, err := ReadRequestBodyWithPrealloc(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(got) != samplePayload {
t.Fatalf("body mismatch: got %q", got)
}
}
func TestReadRequestBodyWithPrealloc_DecodesDeflate(t *testing.T) {
var buf bytes.Buffer
zw := zlib.NewWriter(&buf)
if _, err := zw.Write([]byte(samplePayload)); err != nil {
t.Fatalf("zlib write: %v", err)
}
if err := zw.Close(); err != nil {
t.Fatalf("zlib close: %v", err)
}
req := newRequestWithBody(t, buf.Bytes(), "deflate")
got, err := ReadRequestBodyWithPrealloc(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(got) != samplePayload {
t.Fatalf("body mismatch: got %q", got)
}
}
func TestReadRequestBodyWithPrealloc_RejectsUnsupportedEncoding(t *testing.T) {
req := newRequestWithBody(t, []byte(samplePayload), "br")
_, err := ReadRequestBodyWithPrealloc(req)
if err == nil {
t.Fatal("expected error for unsupported encoding, got nil")
}
if !strings.Contains(err.Error(), "br") {
t.Fatalf("error should mention encoding, got %v", err)
}
}
func TestReadRequestBodyWithPrealloc_RejectsCorruptZstd(t *testing.T) {
req := newRequestWithBody(t, []byte("not actually zstd"), "zstd")
_, err := ReadRequestBodyWithPrealloc(req)
if err == nil {
t.Fatal("expected error for corrupt zstd body, got nil")
}
}
func TestReadRequestBodyWithPrealloc_NilBody(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "/v1/responses", nil)
if err != nil {
t.Fatalf("NewRequest: %v", err)
}
got, err := ReadRequestBodyWithPrealloc(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != nil {
t.Fatalf("expected nil body, got %q", got)
}
}
func TestReadRequestBodyWithPrealloc_RespectsIdentityEncoding(t *testing.T) {
req := newRequestWithBody(t, []byte(samplePayload), "identity")
got, err := ReadRequestBodyWithPrealloc(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(got) != samplePayload {
t.Fatalf("body mismatch: got %q", got)
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment