From 5300949e0def70411fde307afad4c15b6cc22dfd Mon Sep 17 00:00:00 2001
From: Matthew Holt <mholt@users.noreply.github.com>
Date: Thu, 10 Oct 2019 15:36:28 -0600
Subject: [PATCH] caddyhttp: Make responseRecorder capable of counting body
 size

---
 modules/caddyhttp/responsewriter.go | 28 +++++++++++++++++++++++-----
 1 file changed, 23 insertions(+), 5 deletions(-)

diff --git a/modules/caddyhttp/responsewriter.go b/modules/caddyhttp/responsewriter.go
index db5d0649a..344298f2f 100644
--- a/modules/caddyhttp/responsewriter.go
+++ b/modules/caddyhttp/responsewriter.go
@@ -80,6 +80,7 @@ type responseRecorder struct {
 	buf          *bytes.Buffer
 	shouldBuffer func(status int) bool
 	stream       bool
+	size         int
 }
 
 // NewResponseRecorder returns a new ResponseRecorder that can be
@@ -100,7 +101,9 @@ type responseRecorder struct {
 // the Buffered() method returns true. If the response was not
 // buffered, Buffered() will return false and that means the
 // response bypassed the recorder and was written directly to the
-// underlying writer.
+// underlying writer. If shouldBuffer is nil, the response will
+// never be buffered (it will always be streamed directly), and
+// buf can also safely be nil.
 //
 // Before calling this function in a middleware handler, make a
 // new buffer or obtain one from a pool (use the sync.Pool) type.
@@ -130,9 +133,10 @@ func (rr *responseRecorder) WriteHeader(statusCode int) {
 
 	// decide whether we should buffer the response
 	if rr.shouldBuffer == nil {
-		return
+		rr.stream = true
+	} else {
+		rr.stream = !rr.shouldBuffer(rr.statusCode)
 	}
-	rr.stream = !rr.shouldBuffer(rr.statusCode)
 	if rr.stream {
 		rr.ResponseWriterWrapper.WriteHeader(rr.statusCode)
 	}
@@ -140,10 +144,17 @@ func (rr *responseRecorder) WriteHeader(statusCode int) {
 
 func (rr *responseRecorder) Write(data []byte) (int, error) {
 	rr.WriteHeader(http.StatusOK)
+	var n int
+	var err error
 	if rr.stream {
-		return rr.ResponseWriterWrapper.Write(data)
+		n, err = rr.ResponseWriterWrapper.Write(data)
+	} else {
+		n, err = rr.buf.Write(data)
 	}
-	return rr.buf.Write(data)
+	if err == nil {
+		rr.size += n
+	}
+	return n, err
 }
 
 // Status returns the status code that was written, if any.
@@ -151,6 +162,12 @@ func (rr *responseRecorder) Status() int {
 	return rr.statusCode
 }
 
+// Size returns the number of bytes written,
+// not including the response headers.
+func (rr *responseRecorder) Size() int {
+	return rr.size
+}
+
 // Buffer returns the body buffer that rr was created with.
 // You should still have your original pointer, though.
 func (rr *responseRecorder) Buffer() *bytes.Buffer {
@@ -169,6 +186,7 @@ type ResponseRecorder interface {
 	Status() int
 	Buffer() *bytes.Buffer
 	Buffered() bool
+	Size() int
 }
 
 // Interface guards