Fix deleted Content-Length header bug.

This commit is contained in:
Abiola Ibrahim 2015-12-08 12:01:24 +01:00
parent 8631f33940
commit 23631cfaca
4 changed files with 65 additions and 22 deletions

View file

@ -3,6 +3,7 @@
package gzip
import (
"bytes"
"compress/gzip"
"fmt"
"io"
@ -47,9 +48,13 @@ outer:
// Delete this header so gzipping is not repeated later in the chain
r.Header.Del("Accept-Encoding")
w.Header().Set("Content-Encoding", "gzip")
w.Header().Set("Vary", "Accept-Encoding")
gzipWriter, err := newWriter(c, w)
// gzipWriter modifies underlying writer at init,
// use a buffer instead to leave ResponseWriter in
// original form.
var buf = &bytes.Buffer{}
defer buf.Reset()
gzipWriter, err := newWriter(c, buf)
if err != nil {
// should not happen
return http.StatusInternalServerError, err
@ -60,6 +65,8 @@ outer:
var rw http.ResponseWriter
// if no response filter is used
if len(c.ResponseFilters) == 0 {
// replace buffer with ResponseWriter
gzipWriter.Reset(w)
rw = gz
} else {
// wrap gzip writer with ResponseFilterWriter
@ -88,7 +95,7 @@ outer:
// newWriter create a new Gzip Writer based on the compression level.
// If the level is valid (i.e. between 1 and 9), it uses the level.
// Otherwise, it uses default compression level.
func newWriter(c Config, w http.ResponseWriter) (*gzip.Writer, error) {
func newWriter(c Config, w io.Writer) (*gzip.Writer, error) {
if c.Level >= gzip.BestSpeed && c.Level <= gzip.BestCompression {
return gzip.NewWriterLevel(w, c.Level)
}
@ -108,6 +115,8 @@ type gzipResponseWriter struct {
// be wrong because it doesn't know it's being gzipped.
func (w gzipResponseWriter) WriteHeader(code int) {
w.Header().Del("Content-Length")
w.Header().Set("Content-Encoding", "gzip")
w.Header().Set("Vary", "Accept-Encoding")
w.ResponseWriter.WriteHeader(code)
}

View file

@ -80,6 +80,8 @@ func TestGzipHandler(t *testing.T) {
func nextFunc(shouldGzip bool) middleware.Handler {
return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
w.WriteHeader(200)
w.Write([]byte("test"))
if shouldGzip {
if r.Header.Get("Accept-Encoding") != "" {
return 0, fmt.Errorf("Accept-Encoding header not expected")

View file

@ -1,6 +1,7 @@
package gzip
import (
"compress/gzip"
"net/http"
"strconv"
)
@ -29,7 +30,6 @@ func (l LengthFilter) ShouldCompress(w http.ResponseWriter) bool {
// uncompressed data otherwise.
type ResponseFilterWriter struct {
filters []ResponseFilter
validated bool
shouldCompress bool
gzipResponseWriter
}
@ -40,11 +40,9 @@ func NewResponseFilterWriter(filters []ResponseFilter, gz gzipResponseWriter) *R
}
// Write wraps underlying Write method and compresses if filters
// are satisfied
func (r *ResponseFilterWriter) Write(b []byte) (int, error) {
// One time validation to determine if compression should
// be used or not.
if !r.validated {
// are satisfied.
func (r *ResponseFilterWriter) WriteHeader(code int) {
// Determine if compression should be used or not.
r.shouldCompress = true
for _, filter := range r.filters {
if !filter.ShouldCompress(r) {
@ -52,9 +50,23 @@ func (r *ResponseFilterWriter) Write(b []byte) (int, error) {
break
}
}
r.validated = true
}
if r.shouldCompress {
// replace buffer with ResponseWriter
if gzWriter, ok := r.gzipResponseWriter.Writer.(*gzip.Writer); ok {
gzWriter.Reset(r.ResponseWriter)
}
// use gzip WriteHeader to include and delete
// necessary headers
r.gzipResponseWriter.WriteHeader(code)
} else {
r.ResponseWriter.WriteHeader(code)
}
}
// Write wraps underlying Write method and compresses if filters
// are satisfied
func (r *ResponseFilterWriter) Write(b []byte) (int, error) {
if r.shouldCompress {
return r.gzipResponseWriter.Write(b)
}

View file

@ -3,8 +3,11 @@ package gzip
import (
"compress/gzip"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/mholt/caddy/middleware"
)
func TestLengthFilter(t *testing.T) {
@ -30,7 +33,8 @@ func TestLengthFilter(t *testing.T) {
for j, filter := range filters {
r := httptest.NewRecorder()
r.Header().Set("Content-Length", fmt.Sprint(ts.length))
if filter.ShouldCompress(r) != ts.shouldCompress[j] {
wWriter := NewResponseFilterWriter([]ResponseFilter{filter}, gzipResponseWriter{gzip.NewWriter(r), r})
if filter.ShouldCompress(wWriter) != ts.shouldCompress[j] {
t.Errorf("Test %v: Expected %v found %v", i, ts.shouldCompress[j], filter.ShouldCompress(r))
}
}
@ -47,16 +51,32 @@ func TestResponseFilterWriter(t *testing.T) {
{"Hello \t\t\nfrom gzip", true},
{"Hello gzip\n", false},
}
filters := []ResponseFilter{
LengthFilter(15),
}
server := Gzip{Configs: []Config{
{ResponseFilters: filters},
}}
for i, ts := range tests {
w := httptest.NewRecorder()
server.Next = middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
w.Header().Set("Content-Length", fmt.Sprint(len(ts.body)))
gz := gzipResponseWriter{gzip.NewWriter(w), w}
rw := NewResponseFilterWriter(filters, gz)
rw.Write([]byte(ts.body))
w.WriteHeader(200)
w.Write([]byte(ts.body))
return 200, nil
})
r := urlRequest("/")
r.Header.Set("Accept-Encoding", "gzip")
w := httptest.NewRecorder()
server.ServeHTTP(w, r)
resp := w.Body.String()
if !ts.shouldCompress {
if resp != ts.body {
t.Errorf("Test %v: No compression expected, found %v", i, resp)