diff --git a/config/setup/gzip.go b/config/setup/gzip.go index aa294de2..714f8198 100644 --- a/config/setup/gzip.go +++ b/config/setup/gzip.go @@ -1,13 +1,84 @@ package setup import ( + "fmt" + "strconv" + "strings" + "github.com/mholt/caddy/middleware" "github.com/mholt/caddy/middleware/gzip" ) // Gzip configures a new gzip middleware instance. func Gzip(c *Controller) (middleware.Middleware, error) { + configs, err := gzipParse(c) + if err != nil { + return nil, err + } + return func(next middleware.Handler) middleware.Handler { - return gzip.Gzip{Next: next} + return gzip.Gzip{Next: next, Configs: configs} }, nil } + +func gzipParse(c *Controller) ([]gzip.Config, error) { + var configs []gzip.Config + + for c.Next() { + config := gzip.Config{} + + pathFilter := gzip.PathFilter{make(gzip.Set)} + extFilter := gzip.DefaultExtFilter() + + // no extra args expected + if len(c.RemainingArgs()) > 0 { + return configs, c.ArgErr() + } + + for c.NextBlock() { + switch c.Val() { + case "ext": + exts := c.RemainingArgs() + if len(exts) == 0 { + return configs, c.ArgErr() + } + for _, e := range exts { + if !strings.HasPrefix(e, ".") { + return configs, fmt.Errorf(`Invalid extension %v. Should start with "."`, e) + } + extFilter.Exts.Add(e) + } + case "not": + paths := c.RemainingArgs() + if len(paths) == 0 { + return configs, c.ArgErr() + } + for _, p := range paths { + if !strings.HasPrefix(p, "/") { + return configs, fmt.Errorf(`Invalid path %v. Should start with "/"`, p) + } + pathFilter.IgnoredPaths.Add(p) + // Warn user if / is used + if p == "/" { + fmt.Println("Warning: Paths ignored by gzip includes wildcard(/). No request will be gzipped.\nRemoving gzip directive from Caddyfile is preferred if this is intended.") + } + } + case "level": + if !c.NextArg() { + return configs, c.ArgErr() + } + level, _ := strconv.Atoi(c.Val()) + config.Level = level + default: + return configs, c.ArgErr() + } + } + + // put pathFilter in front to filter with path first + config.Filters = []gzip.Filter{pathFilter, extFilter} + + configs = append(configs, config) + } + + return configs, nil +} diff --git a/config/setup/gzip_test.go b/config/setup/gzip_test.go index d9978184..accede6a 100644 --- a/config/setup/gzip_test.go +++ b/config/setup/gzip_test.go @@ -26,4 +26,47 @@ func TestGzip(t *testing.T) { if !sameNext(myHandler.Next, emptyNext) { t.Error("'Next' field of handler was not set properly") } + + tests := []struct { + input string + shouldErr bool + }{ + {`gzip {`, true}, + {`gzip {}`, true}, + {`gzip a b`, true}, + {`gzip a {`, true}, + {`gzip { not f } `, true}, + {`gzip { not } `, true}, + {`gzip { not /file + ext .html + level 1 + } `, false}, + {`gzip { level 9 } `, false}, + {`gzip { ext } `, true}, + {`gzip { ext /f + } `, true}, + {`gzip { not /file + ext .html + level 1 + } + gzip`, false}, + {`gzip { not /file + ext .html + level 1 + } + gzip { not /file1 + ext .htm + level 3 + } + `, false}, + } + for i, test := range tests { + c := newTestController(test.input) + _, err := gzipParse(c) + if test.shouldErr && err == nil { + t.Errorf("Text %v: Expected error but found nil", i) + } else if !test.shouldErr && err != nil { + t.Errorf("Text %v: Expected no error but found error: ", i, err) + } + } } diff --git a/middleware/gzip/filter.go b/middleware/gzip/filter.go new file mode 100644 index 00000000..517f8858 --- /dev/null +++ b/middleware/gzip/filter.go @@ -0,0 +1,90 @@ +package gzip + +import ( + "net/http" + "path" + + "github.com/mholt/caddy/middleware" +) + +// Filter determines if a request should be gzipped. +type Filter interface { + // ShouldCompress tells if compression gzip compression + // should be done on the request. + ShouldCompress(*http.Request) bool +} + +// ExtFilter is Filter for file name extensions. +type ExtFilter struct { + // Exts is the file name extensions to accept + Exts Set +} + +// textExts is a list of extensions for text related files. +var textExts = []string{ + ".html", ".htm", ".css", ".json", ".php", ".js", ".txt", ".md", ".xml", +} + +// extWildCard is the wildcard for extensions. +const extWildCard = "*" + +// DefaultExtFilter creates a default ExtFilter with +// file extensions for text types. +func DefaultExtFilter() ExtFilter { + e := ExtFilter{make(Set)} + for _, ext := range textExts { + e.Exts.Add(ext) + } + return e +} + +func (e ExtFilter) ShouldCompress(r *http.Request) bool { + ext := path.Ext(r.URL.Path) + return e.Exts.Contains(extWildCard) || e.Exts.Contains(ext) +} + +// PathFilter is Filter for request path. +type PathFilter struct { + // IgnoredPaths is the paths to ignore + IgnoredPaths Set +} + +// ShouldCompress checks if the request path matches any of the +// registered paths to ignore. If returns false if an ignored path +// is found and true otherwise. +func (p PathFilter) ShouldCompress(r *http.Request) bool { + return !p.IgnoredPaths.ContainsFunc(func(value string) bool { + return middleware.Path(r.URL.Path).Matches(value) + }) +} + +// Set stores distinct strings. +type Set map[string]struct{} + +// Add adds an element to the set. +func (s Set) Add(value string) { + s[value] = struct{}{} +} + +// Remove removes an element from the set. +func (s Set) Remove(value string) { + delete(s, value) +} + +// Contains check if the set contains value. +func (s Set) Contains(value string) bool { + _, ok := s[value] + return ok +} + +// ContainsFunc is similar to Contains. It iterates all the +// elements in the set and passes each to f. It returns true +// on the first call to f that returns true and false otherwise. +func (s Set) ContainsFunc(f func(string) bool) bool { + for k, _ := range s { + if f(k) { + return true + } + } + return false +} diff --git a/middleware/gzip/filter_test.go b/middleware/gzip/filter_test.go new file mode 100644 index 00000000..56d054cf --- /dev/null +++ b/middleware/gzip/filter_test.go @@ -0,0 +1,106 @@ +package gzip + +import ( + "net/http" + "testing" +) + +func TestSet(t *testing.T) { + set := make(Set) + set.Add("a") + if len(set) != 1 { + t.Errorf("Expected 1 found %v", len(set)) + } + set.Add("a") + if len(set) != 1 { + t.Errorf("Expected 1 found %v", len(set)) + } + set.Add("b") + if len(set) != 2 { + t.Errorf("Expected 2 found %v", len(set)) + } + if !set.Contains("a") { + t.Errorf("Set should contain a") + } + if !set.Contains("b") { + t.Errorf("Set should contain a") + } + set.Add("c") + if len(set) != 3 { + t.Errorf("Expected 3 found %v", len(set)) + } + if !set.Contains("c") { + t.Errorf("Set should contain c") + } + set.Remove("a") + if len(set) != 2 { + t.Errorf("Expected 2 found %v", len(set)) + } + if set.Contains("a") { + t.Errorf("Set should not contain a") + } + if !set.ContainsFunc(func(v string) bool { + return v == "c" + }) { + t.Errorf("ContainsFunc should return true") + } +} + +func TestExtFilter(t *testing.T) { + var filter Filter = DefaultExtFilter() + _ = filter.(ExtFilter) + for i, e := range textExts { + r := urlRequest("file" + e) + if !filter.ShouldCompress(r) { + t.Errorf("Test %v: Should be valid filter", i) + } + } + var exts = []string{ + ".html", ".css", ".md", + } + for i, e := range exts { + r := urlRequest("file" + e) + if !filter.ShouldCompress(r) { + t.Errorf("Test %v: Should be valid filter", i) + } + } + exts = []string{ + ".htm1", ".abc", ".mdx", + } + for i, e := range exts { + r := urlRequest("file" + e) + if filter.ShouldCompress(r) { + t.Errorf("Test %v: Should not be valid filter", i) + } + } +} + +func TestPathFilter(t *testing.T) { + paths := []string{ + "/a", "/b", "/c", "/de", + } + var filter Filter = PathFilter{make(Set)} + for _, p := range paths { + filter.(PathFilter).IgnoredPaths.Add(p) + } + for i, p := range paths { + r := urlRequest(p) + if filter.ShouldCompress(r) { + t.Errorf("Test %v: Should not be valid filter", i) + } + } + paths = []string{ + "/f", "/g", "/h", "/ed", + } + for i, p := range paths { + r := urlRequest(p) + if !filter.ShouldCompress(r) { + t.Errorf("Test %v: Should be valid filter", i) + } + } +} + +func urlRequest(url string) *http.Request { + r, _ := http.NewRequest("GET", url, nil) + return r +} diff --git a/middleware/gzip/gzip.go b/middleware/gzip/gzip.go index 803612ef..b9663144 100644 --- a/middleware/gzip/gzip.go +++ b/middleware/gzip/gzip.go @@ -17,7 +17,14 @@ import ( // specifies the Content-Type, otherwise some clients will assume // application/x-gzip and try to download a file. type Gzip struct { - Next middleware.Handler + Next middleware.Handler + Configs []Config +} + +// Config holds the configuration for Gzip middleware +type Config struct { + Filters []Filter // Filters to use + Level int // Compression level } // ServeHTTP serves a gzipped response if the client supports it. @@ -26,27 +33,56 @@ func (g Gzip) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { return g.Next.ServeHTTP(w, r) } - // Delete this header so gzipping isn't repeated later in the chain - r.Header.Del("Accept-Encoding") +outer: + for _, c := range g.Configs { - w.Header().Set("Content-Encoding", "gzip") - gzipWriter := gzip.NewWriter(w) - defer gzipWriter.Close() - gz := gzipResponseWriter{Writer: gzipWriter, ResponseWriter: w} + // Check filters to determine if gzipping is permitted for this + // request + for _, filter := range c.Filters { + if !filter.ShouldCompress(r) { + continue outer + } + } - // Any response in forward middleware will now be compressed - status, err := g.Next.ServeHTTP(gz, r) + // Delete this header so gzipping is not repeated later in the chain + r.Header.Del("Accept-Encoding") - // If there was an error that remained unhandled, we need - // to send something back before gzipWriter gets closed at - // the return of this method! - if status >= 400 { - gz.Header().Set("Content-Type", "text/plain") // very necessary - gz.WriteHeader(status) - fmt.Fprintf(gz, "%d %s", status, http.StatusText(status)) - return 0, err + w.Header().Set("Content-Encoding", "gzip") + gzipWriter, err := newWriter(c, w) + if err != nil { + // should not happen + return http.StatusInternalServerError, err + } + defer gzipWriter.Close() + gz := gzipResponseWriter{Writer: gzipWriter, ResponseWriter: w} + + // Any response in forward middleware will now be compressed + status, err := g.Next.ServeHTTP(gz, r) + + // If there was an error that remained unhandled, we need + // to send something back before gzipWriter gets closed at + // the return of this method! + if status >= 400 { + gz.Header().Set("Content-Type", "text/plain") // very necessary + gz.WriteHeader(status) + fmt.Fprintf(gz, "%d %s", status, http.StatusText(status)) + return 0, err + } + return status, err } - return status, err + + // no matching filter + return g.Next.ServeHTTP(w, r) +} + +// newWriter create a new Gzip Writer based on the compression level. +// If the level is valid (i.e. between 1 and 9), it uses the level. +// Otherwise, it uses default compression level. +func newWriter(c Config, w http.ResponseWriter) (*gzip.Writer, error) { + if c.Level >= gzip.BestSpeed && c.Level <= gzip.BestCompression { + return gzip.NewWriterLevel(w, c.Level) + } + return gzip.NewWriter(w), nil } // gzipResponeWriter wraps the underlying Write method diff --git a/middleware/gzip/gzip_test.go b/middleware/gzip/gzip_test.go new file mode 100644 index 00000000..4039d14d --- /dev/null +++ b/middleware/gzip/gzip_test.go @@ -0,0 +1,100 @@ +package gzip + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/mholt/caddy/middleware" +) + +func Test(t *testing.T) { + + pathFilter := PathFilter{make(Set)} + badPaths := []string{"/bad", "/nogzip", "/nongzip"} + for _, p := range badPaths { + pathFilter.IgnoredPaths.Add(p) + } + gz := Gzip{Configs: []Config{ + Config{Filters: []Filter{DefaultExtFilter(), pathFilter}}, + }} + + w := httptest.NewRecorder() + gz.Next = nextFunc(true) + for _, e := range textExts { + url := "/file" + e + r, err := http.NewRequest("GET", url, nil) + if err != nil { + t.Error(err) + } + r.Header.Set("Accept-Encoding", "gzip") + _, err = gz.ServeHTTP(w, r) + if err != nil { + t.Error(err) + } + } + + w = httptest.NewRecorder() + gz.Next = nextFunc(false) + for _, p := range badPaths { + for _, e := range textExts { + url := p + "/file" + e + r, err := http.NewRequest("GET", url, nil) + if err != nil { + t.Error(err) + } + r.Header.Set("Accept-Encoding", "gzip") + _, err = gz.ServeHTTP(w, r) + if err != nil { + t.Error(err) + } + } + } + + w = httptest.NewRecorder() + gz.Next = nextFunc(false) + exts := []string{ + ".htm1", ".abc", ".mdx", + } + for _, e := range exts { + url := "/file" + e + r, err := http.NewRequest("GET", url, nil) + if err != nil { + t.Error(err) + } + r.Header.Set("Accept-Encoding", "gzip") + _, err = gz.ServeHTTP(w, r) + if err != nil { + t.Error(err) + } + } + +} + +func nextFunc(shouldGzip bool) middleware.Handler { + return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { + if shouldGzip { + if r.Header.Get("Accept-Encoding") != "" { + return 0, fmt.Errorf("Accept-Encoding header not expected") + } + if w.Header().Get("Content-Encoding") != "gzip" { + return 0, fmt.Errorf("Content-Encoding must be gzip, found %v", r.Header.Get("Content-Encoding")) + } + if _, ok := w.(gzipResponseWriter); !ok { + return 0, fmt.Errorf("ResponseWriter should be gzipResponseWriter, found %T", w) + } + return 0, nil + } + if r.Header.Get("Accept-Encoding") == "" { + return 0, fmt.Errorf("Accept-Encoding header expected") + } + if w.Header().Get("Content-Encoding") == "gzip" { + return 0, fmt.Errorf("Content-Encoding must not be gzip, found gzip") + } + if _, ok := w.(gzipResponseWriter); ok { + return 0, fmt.Errorf("ResponseWriter should not be gzipResponseWriter") + } + return 0, nil + }) +}