diff --git a/middleware/gzip/gzip.go b/middleware/gzip/gzip.go index b5866f682..65256b2ff 100644 --- a/middleware/gzip/gzip.go +++ b/middleware/gzip/gzip.go @@ -3,6 +3,7 @@ package gzip import ( + "bytes" "compress/gzip" "fmt" "io" @@ -47,9 +48,13 @@ outer: // Delete this header so gzipping is not repeated later in the chain r.Header.Del("Accept-Encoding") - w.Header().Set("Content-Encoding", "gzip") - w.Header().Set("Vary", "Accept-Encoding") - gzipWriter, err := newWriter(c, w) + // gzipWriter modifies underlying writer at init, + // use a buffer instead to leave ResponseWriter in + // original form. + var buf = &bytes.Buffer{} + defer buf.Reset() + + gzipWriter, err := newWriter(c, buf) if err != nil { // should not happen return http.StatusInternalServerError, err @@ -60,6 +65,8 @@ outer: var rw http.ResponseWriter // if no response filter is used if len(c.ResponseFilters) == 0 { + // replace buffer with ResponseWriter + gzipWriter.Reset(w) rw = gz } else { // wrap gzip writer with ResponseFilterWriter @@ -88,7 +95,7 @@ outer: // newWriter create a new Gzip Writer based on the compression level. // If the level is valid (i.e. between 1 and 9), it uses the level. // Otherwise, it uses default compression level. -func newWriter(c Config, w http.ResponseWriter) (*gzip.Writer, error) { +func newWriter(c Config, w io.Writer) (*gzip.Writer, error) { if c.Level >= gzip.BestSpeed && c.Level <= gzip.BestCompression { return gzip.NewWriterLevel(w, c.Level) } @@ -108,6 +115,8 @@ type gzipResponseWriter struct { // be wrong because it doesn't know it's being gzipped. func (w gzipResponseWriter) WriteHeader(code int) { w.Header().Del("Content-Length") + w.Header().Set("Content-Encoding", "gzip") + w.Header().Set("Vary", "Accept-Encoding") w.ResponseWriter.WriteHeader(code) } diff --git a/middleware/gzip/gzip_test.go b/middleware/gzip/gzip_test.go index 11ce6b209..c35c99c63 100644 --- a/middleware/gzip/gzip_test.go +++ b/middleware/gzip/gzip_test.go @@ -80,6 +80,8 @@ func TestGzipHandler(t *testing.T) { func nextFunc(shouldGzip bool) middleware.Handler { return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { + w.WriteHeader(200) + w.Write([]byte("test")) if shouldGzip { if r.Header.Get("Accept-Encoding") != "" { return 0, fmt.Errorf("Accept-Encoding header not expected") diff --git a/middleware/gzip/response_filter.go b/middleware/gzip/response_filter.go index fcfc8be51..034fd6123 100644 --- a/middleware/gzip/response_filter.go +++ b/middleware/gzip/response_filter.go @@ -1,6 +1,7 @@ package gzip import ( + "compress/gzip" "net/http" "strconv" ) @@ -29,7 +30,6 @@ func (l LengthFilter) ShouldCompress(w http.ResponseWriter) bool { // uncompressed data otherwise. type ResponseFilterWriter struct { filters []ResponseFilter - validated bool shouldCompress bool gzipResponseWriter } @@ -40,21 +40,33 @@ func NewResponseFilterWriter(filters []ResponseFilter, gz gzipResponseWriter) *R } // Write wraps underlying Write method and compresses if filters -// are satisfied -func (r *ResponseFilterWriter) Write(b []byte) (int, error) { - // One time validation to determine if compression should - // be used or not. - if !r.validated { - r.shouldCompress = true - for _, filter := range r.filters { - if !filter.ShouldCompress(r) { - r.shouldCompress = false - break - } +// are satisfied. +func (r *ResponseFilterWriter) WriteHeader(code int) { + // Determine if compression should be used or not. + r.shouldCompress = true + for _, filter := range r.filters { + if !filter.ShouldCompress(r) { + r.shouldCompress = false + break } - r.validated = true } + if r.shouldCompress { + // replace buffer with ResponseWriter + if gzWriter, ok := r.gzipResponseWriter.Writer.(*gzip.Writer); ok { + gzWriter.Reset(r.ResponseWriter) + } + // use gzip WriteHeader to include and delete + // necessary headers + r.gzipResponseWriter.WriteHeader(code) + } else { + r.ResponseWriter.WriteHeader(code) + } +} + +// Write wraps underlying Write method and compresses if filters +// are satisfied +func (r *ResponseFilterWriter) Write(b []byte) (int, error) { if r.shouldCompress { return r.gzipResponseWriter.Write(b) } diff --git a/middleware/gzip/response_filter_test.go b/middleware/gzip/response_filter_test.go index 1a5a1b4f3..cd7c71917 100644 --- a/middleware/gzip/response_filter_test.go +++ b/middleware/gzip/response_filter_test.go @@ -3,8 +3,11 @@ package gzip import ( "compress/gzip" "fmt" + "net/http" "net/http/httptest" "testing" + + "github.com/mholt/caddy/middleware" ) func TestLengthFilter(t *testing.T) { @@ -30,7 +33,8 @@ func TestLengthFilter(t *testing.T) { for j, filter := range filters { r := httptest.NewRecorder() r.Header().Set("Content-Length", fmt.Sprint(ts.length)) - if filter.ShouldCompress(r) != ts.shouldCompress[j] { + wWriter := NewResponseFilterWriter([]ResponseFilter{filter}, gzipResponseWriter{gzip.NewWriter(r), r}) + if filter.ShouldCompress(wWriter) != ts.shouldCompress[j] { t.Errorf("Test %v: Expected %v found %v", i, ts.shouldCompress[j], filter.ShouldCompress(r)) } } @@ -47,16 +51,32 @@ func TestResponseFilterWriter(t *testing.T) { {"Hello \t\t\nfrom gzip", true}, {"Hello gzip\n", false}, } + filters := []ResponseFilter{ LengthFilter(15), } + + server := Gzip{Configs: []Config{ + {ResponseFilters: filters}, + }} + for i, ts := range tests { + server.Next = middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { + w.Header().Set("Content-Length", fmt.Sprint(len(ts.body))) + w.WriteHeader(200) + w.Write([]byte(ts.body)) + return 200, nil + }) + + r := urlRequest("/") + r.Header.Set("Accept-Encoding", "gzip") + w := httptest.NewRecorder() - w.Header().Set("Content-Length", fmt.Sprint(len(ts.body))) - gz := gzipResponseWriter{gzip.NewWriter(w), w} - rw := NewResponseFilterWriter(filters, gz) - rw.Write([]byte(ts.body)) + + server.ServeHTTP(w, r) + resp := w.Body.String() + if !ts.shouldCompress { if resp != ts.body { t.Errorf("Test %v: No compression expected, found %v", i, resp)