diff --git a/caddy/setup/rewrite.go b/caddy/setup/rewrite.go index b510a237..c70a2f94 100644 --- a/caddy/setup/rewrite.go +++ b/caddy/setup/rewrite.go @@ -1,6 +1,8 @@ package setup import ( + "net/http" + "github.com/mholt/caddy/middleware" "github.com/mholt/caddy/middleware/rewrite" ) @@ -13,7 +15,11 @@ func Rewrite(c *Controller) (middleware.Middleware, error) { } return func(next middleware.Handler) middleware.Handler { - return rewrite.Rewrite{Next: next, Rules: rewrites} + return rewrite.Rewrite{ + Next: next, + FileSys: http.Dir(c.Root), + Rules: rewrites, + } }, nil } @@ -30,6 +36,8 @@ func rewriteParse(c *Controller) ([]rewrite.Rule, error) { args := c.RemainingArgs() + var ifs []rewrite.If + switch len(args) { case 2: rule = rewrite.NewSimpleRule(args[0], args[1]) @@ -56,6 +64,16 @@ func rewriteParse(c *Controller) ([]rewrite.Rule, error) { return nil, c.ArgErr() } ext = args1 + case "if": + args1 := c.RemainingArgs() + if len(args1) != 3 { + return nil, c.ArgErr() + } + ifCond, err := rewrite.NewIf(args1[0], args1[1], args1[2]) + if err != nil { + return nil, err + } + ifs = append(ifs, ifCond) default: return nil, c.ArgErr() } @@ -64,7 +82,7 @@ func rewriteParse(c *Controller) ([]rewrite.Rule, error) { if pattern == "" || to == "" { return nil, c.ArgErr() } - if rule, err = rewrite.NewRegexpRule(base, pattern, to, ext); err != nil { + if rule, err = rewrite.NewComplexRule(base, pattern, to, ext, ifs); err != nil { return nil, err } regexpRules = append(regexpRules, rule) diff --git a/caddy/setup/rewrite_test.go b/caddy/setup/rewrite_test.go index f5426678..ddef7cd4 100644 --- a/caddy/setup/rewrite_test.go +++ b/caddy/setup/rewrite_test.go @@ -98,14 +98,14 @@ func TestRewriteParse(t *testing.T) { r .* to /to }`, false, []rewrite.Rule{ - &rewrite.RegexpRule{Base: "/", To: "/to", Regexp: regexp.MustCompile(".*")}, + &rewrite.ComplexRule{Base: "/", To: "/to", Regexp: regexp.MustCompile(".*")}, }}, {`rewrite { regexp .* to /to ext / html txt }`, false, []rewrite.Rule{ - &rewrite.RegexpRule{Base: "/", To: "/to", Exts: []string{"/", "html", "txt"}, Regexp: regexp.MustCompile(".*")}, + &rewrite.ComplexRule{Base: "/", To: "/to", Exts: []string{"/", "html", "txt"}, Regexp: regexp.MustCompile(".*")}, }}, {`rewrite /path { r rr @@ -116,26 +116,26 @@ func TestRewriteParse(t *testing.T) { to /to } `, false, []rewrite.Rule{ - &rewrite.RegexpRule{Base: "/path", To: "/dest", Regexp: regexp.MustCompile("rr")}, - &rewrite.RegexpRule{Base: "/", To: "/to", Regexp: regexp.MustCompile("[a-z]+")}, + &rewrite.ComplexRule{Base: "/path", To: "/dest", Regexp: regexp.MustCompile("rr")}, + &rewrite.ComplexRule{Base: "/", To: "/to", Regexp: regexp.MustCompile("[a-z]+")}, }}, {`rewrite { to /to }`, true, []rewrite.Rule{ - &rewrite.RegexpRule{}, + &rewrite.ComplexRule{}, }}, {`rewrite { r .* }`, true, []rewrite.Rule{ - &rewrite.RegexpRule{}, + &rewrite.ComplexRule{}, }}, {`rewrite { }`, true, []rewrite.Rule{ - &rewrite.RegexpRule{}, + &rewrite.ComplexRule{}, }}, {`rewrite /`, true, []rewrite.Rule{ - &rewrite.RegexpRule{}, + &rewrite.ComplexRule{}, }}, } @@ -157,8 +157,8 @@ func TestRewriteParse(t *testing.T) { } for j, e := range test.expected { - actualRule := actual[j].(*rewrite.RegexpRule) - expectedRule := e.(*rewrite.RegexpRule) + actualRule := actual[j].(*rewrite.ComplexRule) + expectedRule := e.(*rewrite.ComplexRule) if actualRule.Base != expectedRule.Base { t.Errorf("Test %d, rule %d: Expected Base=%s, got %s", diff --git a/middleware/rewrite/condition.go b/middleware/rewrite/condition.go new file mode 100644 index 00000000..ab69ef4a --- /dev/null +++ b/middleware/rewrite/condition.go @@ -0,0 +1,110 @@ +package rewrite + +import ( + "fmt" + "github.com/mholt/caddy/middleware" + "net/http" + "regexp" + "strings" +) + +const ( + // Operators + Is = "is" + Not = "not" + Has = "has" + StartsWith = "starts_with" + EndsWith = "ends_with" + Match = "match" +) + +func operatorError(operator string) error { + return fmt.Errorf("Invalid operator", operator) +} + +func newReplacer(r *http.Request) middleware.Replacer { + return middleware.NewReplacer(r, nil, "") +} + +// condition is a rewrite condition. +type condition func(string, string) bool + +var conditions = map[string]condition{ + Is: isFunc, + Not: notFunc, + Has: hasFunc, + StartsWith: startsWithFunc, + EndsWith: endsWithFunc, + Match: matchFunc, +} + +// isFunc is condition for Is operator. +// It checks for equality. +func isFunc(a, b string) bool { + return a == b +} + +// notFunc is condition for Not operator. +// It checks for inequality. +func notFunc(a, b string) bool { + return a != b +} + +// hasFunc is condition for Has operator. +// It checks if b is a substring of a. +func hasFunc(a, b string) bool { + return strings.Contains(a, b) +} + +// startsWithFunc is condition for StartsWith operator. +// It checks if b is a prefix of a. +func startsWithFunc(a, b string) bool { + return strings.HasPrefix(a, b) +} + +// endsWithFunc is condition for EndsWith operator. +// It checks if b is a suffix of a. +func endsWithFunc(a, b string) bool { + return strings.HasSuffix(a, b) +} + +// matchFunc is condition for Match operator. +// It does regexp matching of +func matchFunc(a, b string) bool { + matched, _ := regexp.MatchString(b, a) + return matched +} + +// If is statement for a rewrite condition. +type If struct { + A string + Operator string + B string +} + +// True returns true if the condition is true and false otherwise. +// If r is not nil, it replaces placeholders before comparison. +func (i If) True(r *http.Request) bool { + if c, ok := conditions[i.Operator]; ok { + a, b := i.A, i.B + if r != nil { + replacer := newReplacer(r) + a = replacer.Replace(i.A) + b = replacer.Replace(i.B) + } + return c(a, b) + } + return false +} + +// NewIf creates a new If condition. +func NewIf(a, operator, b string) (If, error) { + if _, ok := conditions[operator]; !ok { + return If{}, operatorError(operator) + } + return If{ + A: a, + Operator: operator, + B: b, + }, nil +} diff --git a/middleware/rewrite/rewrite.go b/middleware/rewrite/rewrite.go index 88944f73..cc60e82c 100644 --- a/middleware/rewrite/rewrite.go +++ b/middleware/rewrite/rewrite.go @@ -5,7 +5,6 @@ package rewrite import ( "fmt" "net/http" - "net/url" "path" "path/filepath" "regexp" @@ -16,14 +15,15 @@ import ( // Rewrite is middleware to rewrite request locations internally before being handled. type Rewrite struct { - Next middleware.Handler - Rules []Rule + Next middleware.Handler + FileSys http.FileSystem + Rules []Rule } // ServeHTTP implements the middleware.Handler interface. func (rw Rewrite) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { for _, rule := range rw.Rules { - if ok := rule.Rewrite(r); ok { + if ok := rule.Rewrite(rw.FileSys, r); ok { break } } @@ -33,7 +33,7 @@ func (rw Rewrite) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) // Rule describes an internal location rewrite rule. type Rule interface { // Rewrite rewrites the internal location of the current request. - Rewrite(*http.Request) bool + Rewrite(http.FileSystem, *http.Request) bool } // SimpleRule is a simple rewrite rule. @@ -47,23 +47,20 @@ func NewSimpleRule(from, to string) SimpleRule { } // Rewrite rewrites the internal location of the current request. -func (s SimpleRule) Rewrite(r *http.Request) bool { +func (s SimpleRule) Rewrite(fs http.FileSystem, r *http.Request) bool { if s.From == r.URL.Path { // take note of this rewrite for internal use by fastcgi // all we need is the URI, not full URL r.Header.Set(headerFieldName, r.URL.RequestURI()) - // replace variables - to := path.Clean(middleware.NewReplacer(r, nil, "").Replace(s.To)) - - r.URL.Path = to - return true + // attempt rewrite + return To(fs, r, s.To) } return false } -// RegexpRule is a rewrite rule based on a regular expression -type RegexpRule struct { +// ComplexRule is a rewrite rule based on a regular expression +type ComplexRule struct { // Path base. Request to this path and subpaths will be rewritten Base string @@ -73,18 +70,26 @@ type RegexpRule struct { // Extensions to filter by Exts []string + // Rewrite conditions + Ifs []If + *regexp.Regexp } // NewRegexpRule creates a new RegexpRule. It returns an error if regexp // pattern (pattern) or extensions (ext) are invalid. -func NewRegexpRule(base, pattern, to string, ext []string) (*RegexpRule, error) { - r, err := regexp.Compile(pattern) - if err != nil { - return nil, err +func NewComplexRule(base, pattern, to string, ext []string, ifs []If) (*ComplexRule, error) { + // validate regexp if present + var r *regexp.Regexp + if pattern != "" { + var err error + r, err = regexp.Compile(pattern) + if err != nil { + return nil, err + } } - // validate extensions + // validate extensions if present for _, v := range ext { if len(v) < 2 || (len(v) < 3 && v[0] == '!') { // check if no extension is specified @@ -94,16 +99,17 @@ func NewRegexpRule(base, pattern, to string, ext []string) (*RegexpRule, error) } } - return &RegexpRule{ - base, - to, - ext, - r, + return &ComplexRule{ + Base: base, + To: to, + Exts: ext, + Ifs: ifs, + Regexp: r, }, nil } // Rewrite rewrites the internal location of the current request. -func (r *RegexpRule) Rewrite(req *http.Request) bool { +func (r *ComplexRule) Rewrite(fs http.FileSystem, req *http.Request) bool { rPath := req.URL.Path // validate base @@ -127,31 +133,13 @@ func (r *RegexpRule) Rewrite(req *http.Request) bool { return false } - // replace variables - to := path.Clean(middleware.NewReplacer(req, nil, "").Replace(r.To)) - - // validate resulting path - url, err := url.Parse(to) - if err != nil { - return false - } - - // take note of this rewrite for internal use by fastcgi - // all we need is the URI, not full URL - req.Header.Set(headerFieldName, req.URL.RequestURI()) - - // perform rewrite - req.URL.Path = url.Path - if url.RawQuery != "" { - // overwrite query string if present - req.URL.RawQuery = url.RawQuery - } - return true + // attempt rewrite + return To(fs, req, r.To) } // matchExt matches rPath against registered file extensions. // Returns true if a match is found and false otherwise. -func (r *RegexpRule) matchExt(rPath string) bool { +func (r *ComplexRule) matchExt(rPath string) bool { f := filepath.Base(rPath) ext := path.Ext(f) if ext == "" { diff --git a/middleware/rewrite/rewrite_test.go b/middleware/rewrite/rewrite_test.go index fb047026..ca4cc512 100644 --- a/middleware/rewrite/rewrite_test.go +++ b/middleware/rewrite/rewrite_test.go @@ -4,9 +4,8 @@ import ( "fmt" "net/http" "net/http/httptest" - "testing" - "strings" + "testing" "github.com/mholt/caddy/middleware" ) @@ -38,7 +37,7 @@ func TestRewrite(t *testing.T) { if s := strings.Split(regexpRule[3], "|"); len(s) > 1 { ext = s[:len(s)-1] } - rule, err := NewRegexpRule(regexpRule[0], regexpRule[1], regexpRule[2], ext) + rule, err := NewComplexRule(regexpRule[0], regexpRule[1], regexpRule[2], ext) if err != nil { t.Fatal(err) } diff --git a/middleware/rewrite/to.go b/middleware/rewrite/to.go new file mode 100644 index 00000000..0ee8a159 --- /dev/null +++ b/middleware/rewrite/to.go @@ -0,0 +1,68 @@ +package rewrite + +import ( + "log" + "net/http" + "net/url" + "strings" +) + +// To attempts rewrite. It attempts to rewrite to first valid path +// or the last path if none of the paths are valid. +// Returns true if rewrite is successful and false otherwise. +func To(fs http.FileSystem, r *http.Request, to string) bool { + tos := strings.Fields(to) + replacer := newReplacer(r) + + // try each rewrite paths + t := "" + for _, v := range tos { + t = replacer.Replace(v) + if isValidFile(fs, t) { + break + } + } + + // validate resulting path + u, err := url.Parse(t) + if err != nil { + // Let the user know we got here. Rewrite is expected but + // the resulting url is invalid. + log.Printf("[ERROR] rewrite: resulting path '%v' is invalid. error: %v", t, err) + return false + } + + // take note of this rewrite for internal use by fastcgi + // all we need is the URI, not full URL + r.Header.Set(headerFieldName, r.URL.RequestURI()) + + // perform rewrite + r.URL.Path = u.Path + if u.RawQuery != "" { + // overwrite query string if present + r.URL.RawQuery = u.RawQuery + } + if u.Fragment != "" { + // overwrite fragment if present + r.URL.Fragment = u.Fragment + } + + return true +} + +// isValidFile checks if file exists on the filesystem. +// if file ends with `/`, it is validated as a directory. +func isValidFile(fs http.FileSystem, file string) bool { + f, err := fs.Open(file) + if err != nil { + return false + } + defer f.Close() + + stat, err := f.Stat() + if err != nil { + return false + } + + return strings.HasSuffix(file, "/") && stat.IsDir() +}