diff --git a/middleware/gzip/gzip_test.go b/middleware/gzip/gzip_test.go index b9bafc6d..0e4df338 100644 --- a/middleware/gzip/gzip_test.go +++ b/middleware/gzip/gzip_test.go @@ -80,7 +80,6 @@ func TestGzipHandler(t *testing.T) { func nextFunc(shouldGzip bool) middleware.Handler { return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { - w.WriteHeader(200) w.Write([]byte("test")) if shouldGzip { if r.Header.Get("Accept-Encoding") != "" { diff --git a/middleware/gzip/response_filter.go b/middleware/gzip/response_filter.go index 87a34e60..b561649e 100644 --- a/middleware/gzip/response_filter.go +++ b/middleware/gzip/response_filter.go @@ -29,8 +29,9 @@ func (l LengthFilter) ShouldCompress(w http.ResponseWriter) bool { // gzip compressed data if ResponseFilters are satisfied or // uncompressed data otherwise. type ResponseFilterWriter struct { - filters []ResponseFilter - shouldCompress bool + filters []ResponseFilter + shouldCompress bool + statusCodeWritten bool *gzipResponseWriter } @@ -39,7 +40,7 @@ func NewResponseFilterWriter(filters []ResponseFilter, gz *gzipResponseWriter) * return &ResponseFilterWriter{filters: filters, gzipResponseWriter: gz} } -// Write wraps underlying Write method and compresses if filters +// Write wraps underlying WriteHeader method and compresses if filters // are satisfied. func (r *ResponseFilterWriter) WriteHeader(code int) { // Determine if compression should be used or not. @@ -62,11 +63,15 @@ func (r *ResponseFilterWriter) WriteHeader(code int) { } else { r.ResponseWriter.WriteHeader(code) } + r.statusCodeWritten = true } // Write wraps underlying Write method and compresses if filters // are satisfied func (r *ResponseFilterWriter) Write(b []byte) (int, error) { + if !r.statusCodeWritten { + r.WriteHeader(http.StatusOK) + } if r.shouldCompress { return r.gzipResponseWriter.Write(b) } diff --git a/middleware/gzip/response_filter_test.go b/middleware/gzip/response_filter_test.go index 73867e7f..75f72692 100644 --- a/middleware/gzip/response_filter_test.go +++ b/middleware/gzip/response_filter_test.go @@ -63,7 +63,6 @@ func TestResponseFilterWriter(t *testing.T) { for i, ts := range tests { server.Next = middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { w.Header().Set("Content-Length", fmt.Sprint(len(ts.body))) - w.WriteHeader(200) w.Write([]byte(ts.body)) return 200, nil })