diff --git a/caddyhttp/proxy/proxy_test.go b/caddyhttp/proxy/proxy_test.go index 1b955839..64c2a7be 100644 --- a/caddyhttp/proxy/proxy_test.go +++ b/caddyhttp/proxy/proxy_test.go @@ -44,32 +44,62 @@ func TestReverseProxy(t *testing.T) { log.SetOutput(ioutil.Discard) defer log.SetOutput(os.Stderr) - verifyHeaders := func(headers http.Header, trailers http.Header) { - if headers.Get("X-Header") != "header-value" { - t.Error("Expected header 'X-Header' to be proxied properly") + testHeaderValue := []string{"header-value"} + testHeaders := http.Header{ + "X-Header-1": testHeaderValue, + "X-Header-2": testHeaderValue, + "X-Header-3": testHeaderValue, + } + testTrailerValue := []string{"trailer-value"} + testTrailers := http.Header{ + "X-Trailer-1": testTrailerValue, + "X-Trailer-2": testTrailerValue, + "X-Trailer-3": testTrailerValue, + } + verifyHeaderValues := func(actual http.Header, expected http.Header) bool { + if actual == nil { + t.Error("Expected headers") + return true } - if trailers == nil { - t.Error("Expected to receive trailers") + for k := range expected { + if expected.Get(k) != actual.Get(k) { + t.Errorf("Expected header '%s' to be proxied properly", k) + return true + } } - if trailers.Get("X-Trailer") != "trailer-value" { - t.Error("Expected header 'X-Trailer' to be proxied properly") + + return false + } + verifyHeadersTrailers := func(headers http.Header, trailers http.Header) { + if verifyHeaderValues(headers, testHeaders) || verifyHeaderValues(trailers, testTrailers) { + t.FailNow() } } - var requestReceived bool + requestReceived := false backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // read the body (even if it's empty) to make Go parse trailers io.Copy(ioutil.Discard, r.Body) - verifyHeaders(r.Header, r.Trailer) + verifyHeadersTrailers(r.Header, r.Trailer) requestReceived = true - w.Header().Set("Trailer", "X-Trailer") - w.Header().Set("X-Header", "header-value") + // Set headers. + copyHeader(w.Header(), testHeaders) + + // Only announce one of the trailers to test wether + // unannounced trailers are proxied correctly. + for k := range testTrailers { + w.Header().Set("Trailer", k) + break + } + w.WriteHeader(http.StatusOK) w.Write([]byte("Hello, client")) - w.Header().Set("X-Trailer", "trailer-value") + + // Set trailers. + shallowCopyTrailers(w.Header(), testTrailers, true) })) defer backend.Close() @@ -79,24 +109,37 @@ func TestReverseProxy(t *testing.T) { Upstreams: []Upstream{newFakeUpstream(backend.URL, false)}, } - // create request and response recorder - r := httptest.NewRequest("GET", "/", strings.NewReader("test")) - w := httptest.NewRecorder() - - r.ContentLength = -1 // force chunked encoding (required for trailers) - r.Header.Set("X-Header", "header-value") - r.Trailer = map[string][]string{ - "X-Trailer": {"trailer-value"}, + // Create the fake request body. + // This will copy "trailersToSet" to r.Trailer right before it is closed and + // thus test for us wether unannounced client trailers are proxied correctly. + body := &trailerTestStringReader{ + Reader: *strings.NewReader("test"), + trailersToSet: testTrailers, } + // Create the fake request with the above body. + r := httptest.NewRequest("GET", "/", body) + r.Trailer = make(http.Header) + body.request = r + + copyHeader(r.Header, testHeaders) + + // Only announce one of the trailers to test wether + // unannounced trailers are proxied correctly. + for k, v := range testTrailers { + r.Trailer[k] = v + break + } + + w := httptest.NewRecorder() p.ServeHTTP(w, r) + res := w.Result() if !requestReceived { t.Error("Expected backend to receive request, but it didn't") } - res := w.Result() - verifyHeaders(res.Header, res.Trailer) + verifyHeadersTrailers(res.Header, res.Trailer) // Make sure {upstream} placeholder is set r.Body = ioutil.NopCloser(strings.NewReader("test")) @@ -112,6 +155,21 @@ func TestReverseProxy(t *testing.T) { } } +// trailerTestStringReader is used to test unannounced trailers coming +// from a client which should properly be proxied to the upstream. +type trailerTestStringReader struct { + strings.Reader + request *http.Request + trailersToSet http.Header +} + +var _ io.ReadCloser = &trailerTestStringReader{} + +func (r *trailerTestStringReader) Close() error { + copyHeader(r.request.Trailer, r.trailersToSet) + return nil +} + func TestReverseProxyInsecureSkipVerify(t *testing.T) { log.SetOutput(ioutil.Discard) defer log.SetOutput(os.Stderr) diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go index 23deefe5..5fc4aed9 100644 --- a/caddyhttp/proxy/reverseproxy.go +++ b/caddyhttp/proxy/reverseproxy.go @@ -318,30 +318,61 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, } pooledIoCopy(backendConn, conn) } else { + // NOTE: + // Closing the Body involves acquiring a mutex, which is a + // unnecessarily heavy operation, considering that this defer will + // pretty much never be executed with the Body still unclosed. + bodyOpen := true + closeBody := func() { + if bodyOpen { + res.Body.Close() + bodyOpen = false + } + } + defer closeBody() + + // Copy all headers over. + // res.Header does not include the "Trailer" header, + // which means we will have to do that manually below. copyHeader(rw.Header(), res.Header) - // The "Trailer" header isn't included in the Transport's response, - // at least for *http.Transport. Build it up from Trailer. - if len(res.Trailer) > 0 { - trailerKeys := make([]string, 0, len(res.Trailer)) + // The "Trailer" header isn't included in res' Header map, which + // is why we have to build one ourselves from res.Trailer. + // + // But res.Trailer does not necessarily contain all trailer keys at this + // point yet. The HTTP spec allows one to send "unannounced trailers" + // after a request and certain systems like gRPC make use of that. + announcedTrailerKeyCount := len(res.Trailer) + if announcedTrailerKeyCount > 0 { + vv := make([]string, 0, announcedTrailerKeyCount) for k := range res.Trailer { - trailerKeys = append(trailerKeys, k) + vv = append(vv, k) } - rw.Header().Add("Trailer", strings.Join(trailerKeys, ", ")) + rw.Header()["Trailer"] = vv } + // Now copy over the status code as well as the response body. rw.WriteHeader(res.StatusCode) - if len(res.Trailer) > 0 { + if announcedTrailerKeyCount > 0 { // Force chunking if we saw a response trailer. - // This prevents net/http from calculating the length for short - // bodies and adding a Content-Length. + // This prevents net/http from calculating the length + // for short bodies and adding a Content-Length. if fl, ok := rw.(http.Flusher); ok { fl.Flush() } } rp.copyResponse(rw, res.Body) - res.Body.Close() // close now, instead of defer, to populate res.Trailer - copyHeader(rw.Header(), res.Trailer) + + // Now close the body to fully populate res.Trailer. + closeBody() + + // Since Go does not remove keys from res.Trailer we + // can safely do a length comparison to check wether + // we received further, unannounced trailers. + // + // Most of the time forceSetTrailers should be false. + forceSetTrailers := len(res.Trailer) != announcedTrailerKeyCount + shallowCopyTrailers(rw.Header(), res.Trailer, forceSetTrailers) } return nil @@ -391,6 +422,22 @@ func copyHeader(dst, src http.Header) { } } +// shallowCopyTrailers copies all headers from srcTrailer to dstHeader. +// +// If forceSetTrailers is set to true, the http.TrailerPrefix will be added to +// all srcTrailer key names. Otherwise the Go stdlib will ignore all keys +// which weren't listed in the Trailer map before submitting the Response. +// +// WARNING: Only a shallow copy will be created! +func shallowCopyTrailers(dstHeader, srcTrailer http.Header, forceSetTrailers bool) { + for k, vv := range srcTrailer { + if forceSetTrailers { + k = http.TrailerPrefix + k + } + dstHeader[k] = vv + } +} + // Hop-by-hop headers. These are removed when sent to the backend. // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html var hopHeaders = []string{