From 904f149e5b2f0ae9eff09dacb4f1ec41c4a76298 Mon Sep 17 00:00:00 2001 From: Kevin Lin Date: Tue, 4 Aug 2020 10:50:38 +0800 Subject: [PATCH] reverse_proxy: fix bidirectional streams with encodings (fix #3606) (#3620) * reverse_proxy: fix bi-h2stream breaking gzip encode handle(#3606). * reverse_proxy: check http version of both sides to avoid affecting non-h2 upstream. * Minor cleanup; apply review suggestions Co-authored-by: Matthew Holt --- caddytest/integration/stream_test.go | 236 ++++++++++++++++++ .../caddyhttp/reverseproxy/reverseproxy.go | 4 +- modules/caddyhttp/reverseproxy/streaming.go | 24 +- 3 files changed, 259 insertions(+), 5 deletions(-) diff --git a/caddytest/integration/stream_test.go b/caddytest/integration/stream_test.go index c0ab32b5..b6447c68 100644 --- a/caddytest/integration/stream_test.go +++ b/caddytest/integration/stream_test.go @@ -1,7 +1,9 @@ package integration import ( + "compress/gzip" "context" + "crypto/rand" "fmt" "io" "io/ioutil" @@ -199,3 +201,237 @@ func testH2ToH2CStreamServeH2C(t *testing.T) *http.Server { } return server } + +// (see https://github.com/caddyserver/caddy/issues/3606 for use case) +func TestH2ToH1ChunkedResponse(t *testing.T) { + tester := caddytest.NewTester(t) + tester.InitServer(` +{ + "logging": { + "logs": { + "default": { + "level": "DEBUG" + } + } + }, + "apps": { + "http": { + "http_port": 9080, + "https_port": 9443, + "servers": { + "srv0": { + "listen": [ + ":9443" + ], + "routes": [ + { + "handle": [ + { + "handler": "subroute", + "routes": [ + { + "handle": [ + { + "encodings": { + "gzip": {} + }, + "handler": "encode" + } + ] + }, + { + "handle": [ + { + "handler": "reverse_proxy", + "upstreams": [ + { + "dial": "localhost:54321" + } + ] + } + ], + "match": [ + { + "path": [ + "/tov2ray" + ] + } + ] + } + ] + } + ], + "terminal": true + } + ], + "tls_connection_policies": [ + { + "certificate_selection": { + "any_tag": [ + "cert0" + ] + }, + "default_sni": "a.caddy.localhost" + } + ] + } + } + }, 
+ "tls": { + "certificates": { + "load_files": [ + { + "certificate": "/a.caddy.localhost.crt", + "key": "/a.caddy.localhost.key", + "tags": [ + "cert0" + ] + } + ] + } + }, + "pki": { + "certificate_authorities": { + "local": { + "install_trust": false + } + } + } + } +} + `, "json") + + // need a large body here to trigger caddy's compression, larger than gzip.miniLength + expectedBody, err := GenerateRandomString(1024) + if err != nil { + t.Fatalf("generate expected body failed, err: %s", err) + } + + // start the server + server := testH2ToH1ChunkedResponseServeH1(t) + go server.ListenAndServe() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond) + defer cancel() + server.Shutdown(ctx) + }() + + r, w := io.Pipe() + req := &http.Request{ + Method: "PUT", + Body: ioutil.NopCloser(r), + URL: &url.URL{ + Scheme: "https", + Host: "127.0.0.1:9443", + Path: "/tov2ray", + }, + Proto: "HTTP/2", + ProtoMajor: 2, + ProtoMinor: 0, + Header: make(http.Header), + } + // underlying transport will automatically add gzip + // req.Header.Set("Accept-Encoding", "gzip") + go func() { + fmt.Fprint(w, expectedBody) + w.Close() + }() + resp := tester.AssertResponseCode(req, 200) + if 200 != resp.StatusCode { + return + } + + defer resp.Body.Close() + bytes, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("unable to read the response body %s", err) + } + + body := string(bytes) + + if body != expectedBody { + t.Errorf("requesting \"%s\" expected response body \"%s\" but got \"%s\"", req.RequestURI, expectedBody, body) + } + return +} + +func testH2ToH1ChunkedResponseServeH1(t *testing.T) *http.Server { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + if r.Host != "127.0.0.1:9443" { + t.Errorf("r.Host doesn't match, %v!", r.Host) + w.WriteHeader(http.StatusNotFound) + return + } + + if !strings.HasPrefix(r.URL.Path, "/tov2ray") { + w.WriteHeader(http.StatusNotFound) + return + } + + defer 
r.Body.Close() + bytes, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Fatalf("unable to read the response body %s", err) + } + + n := len(bytes) + + var writer io.Writer + if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { + gw, err := gzip.NewWriterLevel(w, 5) + if err != nil { + t.Error("can't return gzip data") + w.WriteHeader(http.StatusInternalServerError) + return + } + defer gw.Close() + writer = gw + w.Header().Set("Content-Encoding", "gzip") + w.Header().Del("Content-Length") + w.WriteHeader(200) + } else { + writer = w + } + if n > 0 { + writer.Write(bytes[:]) + } + }) + + server := &http.Server{ + Addr: "127.0.0.1:54321", + Handler: handler, + } + return server +} + +// GenerateRandomBytes returns securely generated random bytes. +// It will return an error if the system's secure random +// number generator fails to function correctly, in which +// case the caller should not continue. +func GenerateRandomBytes(n int) ([]byte, error) { + b := make([]byte, n) + _, err := rand.Read(b) + // Note that err == nil only if we read len(b) bytes. + if err != nil { + return nil, err + } + + return b, nil +} + +// GenerateRandomString returns a securely generated random string. +// It will return an error if the system's secure random +// number generator fails to function correctly, in which +// case the caller should not continue. 
+func GenerateRandomString(n int) (string, error) { + const letters = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz-" + bytes, err := GenerateRandomBytes(n) + if err != nil { + return "", err + } + for i, b := range bytes { + bytes[i] = letters[b%byte(len(letters))] + } + return string(bytes), nil +} diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index 0a53db4a..7fdf55a7 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -613,8 +613,8 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, di Dia // some apps need the response headers before starting to stream content with http2, // so it's important to explicitly flush the headers to the client before streaming the data. - // (see https://github.com/caddyserver/caddy/issues/3556 for use case) - if req.ProtoMajor == 2 && res.ContentLength == -1 { + // (see https://github.com/caddyserver/caddy/issues/3556 for use case and nuances) + if h.isBidirectionalStream(req, res) { if wf, ok := rw.(http.Flusher); ok { wf.Flush() } diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go index 105ff32b..127c0f0f 100644 --- a/modules/caddyhttp/reverseproxy/streaming.go +++ b/modules/caddyhttp/reverseproxy/streaming.go @@ -96,15 +96,33 @@ func (h Handler) flushInterval(req *http.Request, res *http.Response) time.Durat return -1 // negative means immediately } - // for h2 and h2c upstream streaming data to client (issue #3556) - if req.ProtoMajor == 2 && res.ContentLength == -1 { + // for h2 and h2c upstream streaming data to client (issues #3556 and #3606) + if h.isBidirectionalStream(req, res) { return -1 } - // TODO: more specific cases? e.g. res.ContentLength == -1? (this TODO is from the std lib) + // TODO: more specific cases? e.g. res.ContentLength == -1? 
(this TODO is from the std lib, but + // strangely similar to our isBidirectionalStream function that we implemented ourselves) return time.Duration(h.FlushInterval) } +// isBidirectionalStream returns whether we should work in bi-directional stream mode. +// +// See https://github.com/caddyserver/caddy/pull/3620 for discussion of nuances. +func (h Handler) isBidirectionalStream(req *http.Request, res *http.Response) bool { + // We have to check the encoding here; only flush headers with identity encoding. + // A non-identity encoding might be combined with the "encode" directive, and in that case, + // if the body size is larger than enc.MinLength, the upper-level encode handler might still have + // a Content-Encoding header to write. + // (see https://github.com/caddyserver/caddy/issues/3606 for use case) + ae := req.Header.Get("Accept-Encoding") + + return req.ProtoMajor == 2 && + res.ProtoMajor == 2 && + res.ContentLength == -1 && + (ae == "identity" || ae == "") +} + func (h Handler) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error { if flushInterval != 0 { if wf, ok := dst.(writeFlusher); ok {