From abf5ab340ed76792214ae80c62df7abe0ad1b8a8 Mon Sep 17 00:00:00 2001
From: Matthew Holt <mholt@users.noreply.github.com>
Date: Tue, 15 Oct 2019 14:07:10 -0600
Subject: [PATCH] caddyhttp: Improve ResponseRecorder to buffer headers

---
 modules/caddyhttp/caddyhttp.go            |  13 +++
 modules/caddyhttp/httpcache/httpcache.go  |   5 +-
 modules/caddyhttp/markdown/markdown.go    |   6 +-
 modules/caddyhttp/responsewriter.go       | 119 ++++++++++++++++------
 modules/caddyhttp/templates/templates.go  |  18 ++--
 modules/caddyhttp/templates/tplcontext.go |   2 +-
 6 files changed, 116 insertions(+), 47 deletions(-)

diff --git a/modules/caddyhttp/caddyhttp.go b/modules/caddyhttp/caddyhttp.go
index 5631b300b..29a5ab071 100644
--- a/modules/caddyhttp/caddyhttp.go
+++ b/modules/caddyhttp/caddyhttp.go
@@ -593,6 +593,19 @@ func (ws WeakString) String() string {
 	return string(ws)
 }
 
+// CopyHeader copies HTTP headers by completely
+// replacing dest with src. (This allows deletions
+// to be propagated, assuming src started as a
+// consistent copy of dest.)
+func CopyHeader(dest, src http.Header) {
+	for field := range dest {
+		delete(dest, field)
+	}
+	for field, val := range src {
+		dest[field] = val
+	}
+}
+
 // StatusCodeMatches returns true if a real HTTP status code matches
 // the configured status code, which may be either a real HTTP status
 // code or an integer representing a class of codes (e.g. 4 for all
diff --git a/modules/caddyhttp/httpcache/httpcache.go b/modules/caddyhttp/httpcache/httpcache.go
index 1b2cfd2e9..0b49c7ee0 100644
--- a/modules/caddyhttp/httpcache/httpcache.go
+++ b/modules/caddyhttp/httpcache/httpcache.go
@@ -130,8 +130,7 @@ func (c *Cache) getter(ctx groupcache.Context, key string, dest groupcache.Sink)
 
 	// we need to record the response if we are to cache it; only cache if
 	// request is successful (TODO: there's probably much more nuance needed here)
-	var rr caddyhttp.ResponseRecorder
-	rr = caddyhttp.NewResponseRecorder(combo.rw, buf, func(status int) bool {
+	rr := caddyhttp.NewResponseRecorder(combo.rw, buf, func(status int, header http.Header) bool {
 		shouldBuf := status < 300
 
 		if shouldBuf {
@@ -141,7 +140,7 @@ func (c *Cache) getter(ctx groupcache.Context, key string, dest groupcache.Sink)
 			// the rest will be the body, which will be written
 			// implicitly for us by the recorder
 			err := gob.NewEncoder(buf).Encode(headerAndStatus{
-				Header: rr.Header(),
+				Header: header,
 				Status: status,
 			})
 			if err != nil {
diff --git a/modules/caddyhttp/markdown/markdown.go b/modules/caddyhttp/markdown/markdown.go
index 122aad6e4..5ff18b88e 100644
--- a/modules/caddyhttp/markdown/markdown.go
+++ b/modules/caddyhttp/markdown/markdown.go
@@ -48,8 +48,8 @@ func (m Markdown) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
 	buf.Reset()
 	defer bufPool.Put(buf)
 
-	shouldBuf := func(status int) bool {
-		return strings.HasPrefix(w.Header().Get("Content-Type"), "text/")
+	shouldBuf := func(status int, header http.Header) bool {
+		return strings.HasPrefix(header.Get("Content-Type"), "text/")
 	}
 
 	rec := caddyhttp.NewResponseRecorder(w, buf, shouldBuf)
@@ -62,6 +62,8 @@ func (m Markdown) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
 		return nil
 	}
 
+	caddyhttp.CopyHeader(w.Header(), rec.Header())
+
 	output := blackfriday.Run(buf.Bytes())
 
 	w.Header().Set("Content-Length", strconv.Itoa(len(output)))
diff --git a/modules/caddyhttp/responsewriter.go b/modules/caddyhttp/responsewriter.go
index 344298f2f..5beb40ea3 100644
--- a/modules/caddyhttp/responsewriter.go
+++ b/modules/caddyhttp/responsewriter.go
@@ -18,6 +18,7 @@ import (
 	"bufio"
 	"bytes"
 	"fmt"
+	"io"
 	"net"
 	"net/http"
 )
@@ -78,52 +79,89 @@ type responseRecorder struct {
 	wroteHeader  bool
 	statusCode   int
 	buf          *bytes.Buffer
-	shouldBuffer func(status int) bool
+	shouldBuffer ShouldBufferFunc
 	stream       bool
 	size         int
+	header       http.Header
 }
 
 // NewResponseRecorder returns a new ResponseRecorder that can be
-// used instead of a real http.ResponseWriter. The recorder is useful
-// for middlewares which need to buffer a responder's response and
-// process it in its entirety before actually allowing the response to
-// be written. Of course, this has a performance overhead, but
-// sometimes there is no way to avoid buffering the whole response.
-// Still, if at all practical, middlewares should strive to stream
+// used instead of a standard http.ResponseWriter. The recorder is
+// useful for middlewares which need to buffer a response and
+// potentially process its entire body before actually writing the
+// response to the underlying writer. Of course, buffering the entire
+// body has a memory overhead, but sometimes there is no way to avoid
+// buffering the whole response, hence the existence of this type.
+// Still, if at all practical, handlers should strive to stream
 // responses by wrapping Write and WriteHeader methods instead of
 // buffering whole response bodies.
 //
-// Recorders optionally buffer the response. When the headers are
-// to be written, shouldBuffer will be called with the status
-// code that is being written. The rest of the headers can be read
-// from w.Header(). If shouldBuffer returns true, the response
-// will be buffered. You can know the response was buffered if
-// the Buffered() method returns true. If the response was not
-// buffered, Buffered() will return false and that means the
-// response bypassed the recorder and was written directly to the
-// underlying writer. If shouldBuffer is nil, the response will
-// never be buffered (it will always be streamed directly), and
-// buf can also safely be nil.
+// Buffering is actually optional. The shouldBuffer function will
+// be called just before the headers are written. If it returns
+// true, the headers and body will be buffered by this recorder
+// and not written to the underlying writer; if false, the headers
+// will be written immediately and the body will be streamed out
+// directly to the underlying writer. If shouldBuffer is nil,
+// the response will never be buffered and will always be streamed
+// directly to the writer.
 //
-// Before calling this function in a middleware handler, make a
-// new buffer or obtain one from a pool (use the sync.Pool) type.
-// Using a pool is generally recommended for performance gains;
-// do profiling to ensure this is the case. If using a pool, be
-// sure to reset the buffer before using it.
+// You can know if shouldBuffer returned true by calling Buffered().
 //
-// The returned recorder can be used in place of w when calling
-// the next handler in the chain. When that handler returns, you
-// can read the status code from the recorder's Status() method.
-// The response body fills buf if it was buffered, and the headers
-// are available via w.Header().
-func NewResponseRecorder(w http.ResponseWriter, buf *bytes.Buffer, shouldBuffer func(status int) bool) ResponseRecorder {
+// The provided buffer buf should be obtained from a pool for best
+// performance (see the sync.Pool type).
+//
+// Proper usage of a recorder looks like this:
+//
+//     rec := caddyhttp.NewResponseRecorder(w, buf, shouldBuffer)
+//     err := next.ServeHTTP(rec, req)
+//     if err != nil {
+//         return err
+//     }
+//     if !rec.Buffered() {
+//         return nil
+//     }
+//     // process the buffered response here
+//
+// After a response has been buffered, remember that any upstream header
+// manipulations are only manifest in the recorder's Header(), not the
+// Header() of the underlying ResponseWriter. Thus if you wish to inspect
+// or change response headers, you either need to use rec.Header(), or
+// copy rec.Header() into w.Header() first (see caddyhttp.CopyHeader).
+//
+// Once you are ready to write the response, there are two ways you can do
+// it. The easier way is to have the recorder do it:
+//
+//     rec.WriteResponse()
+//
+// This writes the recorded response headers as well as the buffered body.
+// Or, you may wish to do it yourself, especially if you manipulated the
+// buffered body. First you will need to copy the recorded headers, then
+// write the headers with the recorded status code, then write the body
+// (this example writes the recorder's body buffer, but you might have
+// your own body to write instead):
+//
+//     caddyhttp.CopyHeader(w.Header(), rec.Header())
+//     w.WriteHeader(rec.Status())
+//     io.Copy(w, rec.Buffer())
+//
+func NewResponseRecorder(w http.ResponseWriter, buf *bytes.Buffer, shouldBuffer ShouldBufferFunc) ResponseRecorder {
+	// copy the current response header into this buffer so
+	// that any header manipulations on the buffered header
+	// are consistent with what would be written out
+	hdr := make(http.Header)
+	CopyHeader(hdr, w.Header())
 	return &responseRecorder{
 		ResponseWriterWrapper: &ResponseWriterWrapper{ResponseWriter: w},
 		buf:                   buf,
 		shouldBuffer:          shouldBuffer,
+		header:                hdr,
 	}
 }
 
+func (rr *responseRecorder) Header() http.Header {
+	return rr.header
+}
+
 func (rr *responseRecorder) WriteHeader(statusCode int) {
 	if rr.wroteHeader {
 		return
@@ -135,9 +173,12 @@ func (rr *responseRecorder) WriteHeader(statusCode int) {
 	if rr.shouldBuffer == nil {
 		rr.stream = true
 	} else {
-		rr.stream = !rr.shouldBuffer(rr.statusCode)
+		rr.stream = !rr.shouldBuffer(rr.statusCode, rr.header)
 	}
+
+	// if not buffered, immediately write header
 	if rr.stream {
+		CopyHeader(rr.ResponseWriterWrapper.Header(), rr.header)
 		rr.ResponseWriterWrapper.WriteHeader(rr.statusCode)
 	}
 }
@@ -179,16 +220,32 @@ func (rr *responseRecorder) Buffered() bool {
 	return !rr.stream
 }
 
+func (rr *responseRecorder) WriteResponse() error {
+	if rr.stream {
+		return nil
+	}
+	CopyHeader(rr.ResponseWriterWrapper.Header(), rr.header)
+	rr.ResponseWriterWrapper.WriteHeader(rr.statusCode)
+	_, err := io.Copy(rr.ResponseWriterWrapper, rr.buf)
+	return err
+}
 // ResponseRecorder is a http.ResponseWriter that records
-// responses instead of writing them to the client.
+// responses instead of writing them to the client. See
+// docs for NewResponseRecorder for proper usage.
 type ResponseRecorder interface {
 	HTTPInterfaces
 	Status() int
 	Buffer() *bytes.Buffer
 	Buffered() bool
 	Size() int
+	WriteResponse() error
 }
 
+// ShouldBufferFunc is a function that returns true if the
+// response should be buffered, given the pending HTTP status
+// code and response headers.
+type ShouldBufferFunc func(status int, header http.Header) bool
+
 // Interface guards
 var (
 	_ HTTPInterfaces   = (*ResponseWriterWrapper)(nil)
diff --git a/modules/caddyhttp/templates/templates.go b/modules/caddyhttp/templates/templates.go
index 05a2f633a..e9c1da819 100644
--- a/modules/caddyhttp/templates/templates.go
+++ b/modules/caddyhttp/templates/templates.go
@@ -17,7 +17,6 @@ package templates
 import (
 	"bytes"
 	"fmt"
-	"io"
 	"net/http"
 	"strconv"
 	"strings"
@@ -71,8 +70,8 @@ func (t *Templates) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddy
 
 	// shouldBuf determines whether to execute templates on this response,
 	// since generally we will not want to execute for images or CSS, etc.
-	shouldBuf := func(status int) bool {
-		ct := w.Header().Get("Content-Type")
+	shouldBuf := func(status int, header http.Header) bool {
+		ct := header.Get("Content-Type")
 		for _, mt := range t.MIMETypes {
 			if strings.Contains(ct, mt) {
 				return true
@@ -96,18 +95,17 @@ func (t *Templates) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddy
 		return err
 	}
 
-	w.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
-	w.Header().Del("Accept-Ranges") // we don't know ranges for dynamically-created content
-	w.Header().Del("Last-Modified") // useless for dynamic content since it's always changing
+	rec.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
+	rec.Header().Del("Accept-Ranges") // we don't know ranges for dynamically-created content
+	rec.Header().Del("Last-Modified") // useless for dynamic content since it's always changing
 
 	// we don't know a way to guickly generate etag for dynamic content,
 	// but we can convert this to a weak etag to kind of indicate that
-	if etag := w.Header().Get("ETag"); etag != "" {
-		w.Header().Set("ETag", "W/"+etag)
+	if etag := rec.Header().Get("Etag"); etag != "" {
+		rec.Header().Set("Etag", "W/"+etag)
 	}
 
-	w.WriteHeader(rec.Status())
-	io.Copy(w, buf)
+	rec.WriteResponse()
 
 	return nil
 }
diff --git a/modules/caddyhttp/templates/tplcontext.go b/modules/caddyhttp/templates/tplcontext.go
index 5b74623e5..40d137079 100644
--- a/modules/caddyhttp/templates/tplcontext.go
+++ b/modules/caddyhttp/templates/tplcontext.go
@@ -80,7 +80,7 @@ func (c templateContext) Include(filename string, args ...interface{}) (template
 // If it is not trusted, be sure to use escaping functions yourself.
 func (c templateContext) HTTPInclude(uri string) (template.HTML, error) {
 	if c.Req.Header.Get(recursionPreventionHeader) == "1" {
-		return "", fmt.Errorf("virtual include cycle")
+		return "", fmt.Errorf("virtual request cycle")
 	}
 
 	buf := bufPool.Get().(*bytes.Buffer)