diff --git a/middleware/gzip/gzip.go b/middleware/gzip/gzip.go index 39a92266..26610daa 100644 --- a/middleware/gzip/gzip.go +++ b/middleware/gzip/gzip.go @@ -57,7 +57,7 @@ outer: return http.StatusInternalServerError, err } defer gzipWriter.Close() - gz := gzipResponseWriter{Writer: gzipWriter, ResponseWriter: w} + gz := &gzipResponseWriter{Writer: gzipWriter, ResponseWriter: w} var rw http.ResponseWriter // if no response filter is used @@ -104,21 +104,26 @@ func newWriter(c Config, w io.Writer) (*gzip.Writer, error) { type gzipResponseWriter struct { io.Writer http.ResponseWriter + statusCodeWritten bool } // WriteHeader wraps the underlying WriteHeader method to prevent // problems with conflicting headers from proxied backends. For // example, a backend system that calculates Content-Length would // be wrong because it doesn't know it's being gzipped. -func (w gzipResponseWriter) WriteHeader(code int) { +func (w *gzipResponseWriter) WriteHeader(code int) { w.Header().Del("Content-Length") w.Header().Set("Content-Encoding", "gzip") w.Header().Add("Vary", "Accept-Encoding") w.ResponseWriter.WriteHeader(code) + w.statusCodeWritten = true } // Write wraps the underlying Write method to do compression. -func (w gzipResponseWriter) Write(b []byte) (int, error) { +func (w *gzipResponseWriter) Write(b []byte) (int, error) { + if !w.statusCodeWritten { + w.WriteHeader(http.StatusOK) + } if w.Header().Get("Content-Type") == "" { w.Header().Set("Content-Type", http.DetectContentType(b)) } diff --git a/middleware/gzip/gzip_test.go b/middleware/gzip/gzip_test.go index c35c99c6..0e4df338 100644 --- a/middleware/gzip/gzip_test.go +++ b/middleware/gzip/gzip_test.go @@ -80,7 +80,6 @@ func TestGzipHandler(t *testing.T) { func nextFunc(shouldGzip bool) middleware.Handler { return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { - w.WriteHeader(200) w.Write([]byte("test")) if shouldGzip { if r.Header.Get("Accept-Encoding") != "" { @@ -92,7 +91,7 @@ func nextFunc(shouldGzip bool) middleware.Handler { if w.Header().Get("Vary") != "Accept-Encoding" { return 0, fmt.Errorf("Vary must be Accept-Encoding, found %v", r.Header.Get("Vary")) } - if _, ok := w.(gzipResponseWriter); !ok { + if _, ok := w.(*gzipResponseWriter); !ok { return 0, fmt.Errorf("ResponseWriter should be gzipResponseWriter, found %T", w) } return 0, nil @@ -103,7 +102,7 @@ func nextFunc(shouldGzip bool) middleware.Handler { if w.Header().Get("Content-Encoding") == "gzip" { return 0, fmt.Errorf("Content-Encoding must not be gzip, found gzip") } - if _, ok := w.(gzipResponseWriter); ok { + if _, ok := w.(*gzipResponseWriter); ok { return 0, fmt.Errorf("ResponseWriter should not be gzipResponseWriter") } return 0, nil diff --git a/middleware/gzip/response_filter.go b/middleware/gzip/response_filter.go index c599b3e1..b561649e 100644 --- a/middleware/gzip/response_filter.go +++ b/middleware/gzip/response_filter.go @@ -29,17 +29,18 @@ func (l LengthFilter) ShouldCompress(w http.ResponseWriter) bool { // gzip compressed data if ResponseFilters are satisfied or // uncompressed data otherwise. type ResponseFilterWriter struct { - filters []ResponseFilter - shouldCompress bool - gzipResponseWriter + filters []ResponseFilter + shouldCompress bool + statusCodeWritten bool + *gzipResponseWriter } // NewResponseFilterWriter creates and initializes a new ResponseFilterWriter. -func NewResponseFilterWriter(filters []ResponseFilter, gz gzipResponseWriter) *ResponseFilterWriter { +func NewResponseFilterWriter(filters []ResponseFilter, gz *gzipResponseWriter) *ResponseFilterWriter { return &ResponseFilterWriter{filters: filters, gzipResponseWriter: gz} } -// Write wraps underlying Write method and compresses if filters +// Write wraps underlying WriteHeader method and compresses if filters // are satisfied. func (r *ResponseFilterWriter) WriteHeader(code int) { // Determine if compression should be used or not. @@ -62,11 +63,15 @@ func (r *ResponseFilterWriter) WriteHeader(code int) { } else { r.ResponseWriter.WriteHeader(code) } + r.statusCodeWritten = true } // Write wraps underlying Write method and compresses if filters // are satisfied func (r *ResponseFilterWriter) Write(b []byte) (int, error) { + if !r.statusCodeWritten { + r.WriteHeader(http.StatusOK) + } if r.shouldCompress { return r.gzipResponseWriter.Write(b) } diff --git a/middleware/gzip/response_filter_test.go b/middleware/gzip/response_filter_test.go index cd7c7191..75f72692 100644 --- a/middleware/gzip/response_filter_test.go +++ b/middleware/gzip/response_filter_test.go @@ -33,7 +33,7 @@ func TestLengthFilter(t *testing.T) { for j, filter := range filters { r := httptest.NewRecorder() r.Header().Set("Content-Length", fmt.Sprint(ts.length)) - wWriter := NewResponseFilterWriter([]ResponseFilter{filter}, gzipResponseWriter{gzip.NewWriter(r), r}) + wWriter := NewResponseFilterWriter([]ResponseFilter{filter}, &gzipResponseWriter{gzip.NewWriter(r), r, false}) if filter.ShouldCompress(wWriter) != ts.shouldCompress[j] { t.Errorf("Test %v: Expected %v found %v", i, ts.shouldCompress[j], filter.ShouldCompress(r)) } @@ -63,7 +63,6 @@ func TestResponseFilterWriter(t *testing.T) { for i, ts := range tests { server.Next = middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { w.Header().Set("Content-Length", fmt.Sprint(len(ts.body))) - w.WriteHeader(200) w.Write([]byte(ts.body)) return 200, nil })