Merge branch 'master' into replacer-patch

This commit is contained in:
Abiola Ibrahim 2015-12-30 20:26:11 +01:00
commit 73327e784d
16 changed files with 502 additions and 89 deletions

View file

@ -0,0 +1,6 @@
glob0.host0 {
dir2 arg1
}
glob0.host1 {
}

View file

@ -0,0 +1,4 @@
glob1.host0 {
dir1
dir2 arg1
}

View file

@ -0,0 +1,3 @@
glob2.host0 {
dir2 arg1
}

View file

@ -176,19 +176,52 @@ func (p *parser) directives() error {
} }
// doImport swaps out the import directive and its argument // doImport swaps out the import directive and its argument
// (a total of 2 tokens) with the tokens in the file specified. // (a total of 2 tokens) with the tokens in the specified file
// When the function returns, the cursor is on the token before // or globbing pattern. When the function returns, the cursor
// where the import directive was. In other words, call Next() // is on the token before where the import directive was. In
// to access the first token that was imported. // other words, call Next() to access the first token that was
// imported.
func (p *parser) doImport() error { func (p *parser) doImport() error {
if !p.NextArg() { if !p.NextArg() {
return p.ArgErr() return p.ArgErr()
} }
importFile := p.Val() importPattern := p.Val()
if p.NextArg() { if p.NextArg() {
return p.Err("Import allows only one file to import") return p.Err("Import allows only one expression, either file or glob pattern")
} }
matches, err := filepath.Glob(importPattern)
if err != nil {
return p.Errf("Failed to use import pattern %s - %s", importPattern, err.Error())
}
if len(matches) == 0 {
return p.Errf("No files matching the import pattern %s", importPattern)
}
// Splice out the import directive and its argument (2 tokens total)
// and insert the imported tokens in their place.
tokensBefore := p.tokens[:p.cursor-1]
tokensAfter := p.tokens[p.cursor+1:]
// cursor was advanced one position to read filename; rewind it
p.cursor--
p.tokens = tokensBefore
for _, importFile := range matches {
if err := p.doSingleImport(importFile); err != nil {
return err
}
}
p.tokens = append(p.tokens, append(tokensAfter)...)
return nil
}
// doSingleImport lexes the individual files matching the
// globbing pattern from of the import directive.
func (p *parser) doSingleImport(importFile string) error {
file, err := os.Open(importFile) file, err := os.Open(importFile)
if err != nil { if err != nil {
return p.Errf("Could not import %s - %v", importFile, err) return p.Errf("Could not import %s - %v", importFile, err)
@ -203,10 +236,7 @@ func (p *parser) doImport() error {
// Splice out the import directive and its argument (2 tokens total) // Splice out the import directive and its argument (2 tokens total)
// and insert the imported tokens in their place. // and insert the imported tokens in their place.
tokensBefore := p.tokens[:p.cursor-1] p.tokens = append(p.tokens, append(importedTokens)...)
tokensAfter := p.tokens[p.cursor+1:]
p.tokens = append(tokensBefore, append(importedTokens, tokensAfter...)...)
p.cursor-- // cursor was advanced one position to read the filename; rewind it
return nil return nil
} }

View file

@ -329,6 +329,13 @@ func TestParseAll(t *testing.T) {
[]address{{"host1.com", "http"}, {"host2.com", "http"}}, []address{{"host1.com", "http"}, {"host2.com", "http"}},
[]address{{"host3.com", "https"}, {"host4.com", "https"}}, []address{{"host3.com", "https"}, {"host4.com", "https"}},
}}, }},
{`import import_glob*.txt`, false, [][]address{
[]address{{"glob0.host0", ""}},
[]address{{"glob0.host1", ""}},
[]address{{"glob1.host0", ""}},
[]address{{"glob2.host0", ""}},
}},
} { } {
p := testParser(test.input) p := testParser(test.input)
blocks, err := p.parseAll() blocks, err := p.parseAll()

View file

@ -1,6 +1,9 @@
package setup package setup
import ( import (
"net/http"
"strings"
"github.com/mholt/caddy/middleware" "github.com/mholt/caddy/middleware"
"github.com/mholt/caddy/middleware/rewrite" "github.com/mholt/caddy/middleware/rewrite"
) )
@ -13,7 +16,11 @@ func Rewrite(c *Controller) (middleware.Middleware, error) {
} }
return func(next middleware.Handler) middleware.Handler { return func(next middleware.Handler) middleware.Handler {
return rewrite.Rewrite{Next: next, Rules: rewrites} return rewrite.Rewrite{
Next: next,
FileSys: http.Dir(c.Root),
Rules: rewrites,
}
}, nil }, nil
} }
@ -30,6 +37,8 @@ func rewriteParse(c *Controller) ([]rewrite.Rule, error) {
args := c.RemainingArgs() args := c.RemainingArgs()
var ifs []rewrite.If
switch len(args) { switch len(args) {
case 2: case 2:
rule = rewrite.NewSimpleRule(args[0], args[1]) rule = rewrite.NewSimpleRule(args[0], args[1])
@ -46,25 +55,36 @@ func rewriteParse(c *Controller) ([]rewrite.Rule, error) {
} }
pattern = c.Val() pattern = c.Val()
case "to": case "to":
if !c.NextArg() { args1 := c.RemainingArgs()
if len(args1) == 0 {
return nil, c.ArgErr() return nil, c.ArgErr()
} }
to = c.Val() to = strings.Join(args1, " ")
case "ext": case "ext":
args1 := c.RemainingArgs() args1 := c.RemainingArgs()
if len(args1) == 0 { if len(args1) == 0 {
return nil, c.ArgErr() return nil, c.ArgErr()
} }
ext = args1 ext = args1
case "if":
args1 := c.RemainingArgs()
if len(args1) != 3 {
return nil, c.ArgErr()
}
ifCond, err := rewrite.NewIf(args1[0], args1[1], args1[2])
if err != nil {
return nil, err
}
ifs = append(ifs, ifCond)
default: default:
return nil, c.ArgErr() return nil, c.ArgErr()
} }
} }
// ensure pattern and to are specified // ensure to is specified
if pattern == "" || to == "" { if to == "" {
return nil, c.ArgErr() return nil, c.ArgErr()
} }
if rule, err = rewrite.NewRegexpRule(base, pattern, to, ext); err != nil { if rule, err = rewrite.NewComplexRule(base, pattern, to, ext, ifs); err != nil {
return nil, err return nil, err
} }
regexpRules = append(regexpRules, rule) regexpRules = append(regexpRules, rule)

View file

@ -1,10 +1,9 @@
package setup package setup
import ( import (
"testing"
"fmt" "fmt"
"regexp" "regexp"
"testing"
"github.com/mholt/caddy/middleware/rewrite" "github.com/mholt/caddy/middleware/rewrite"
) )
@ -96,16 +95,16 @@ func TestRewriteParse(t *testing.T) {
}{ }{
{`rewrite { {`rewrite {
r .* r .*
to /to to /to /index.php?
}`, false, []rewrite.Rule{ }`, false, []rewrite.Rule{
&rewrite.RegexpRule{Base: "/", To: "/to", Regexp: regexp.MustCompile(".*")}, &rewrite.ComplexRule{Base: "/", To: "/to /index.php?", Regexp: regexp.MustCompile(".*")},
}}, }},
{`rewrite { {`rewrite {
regexp .* regexp .*
to /to to /to
ext / html txt ext / html txt
}`, false, []rewrite.Rule{ }`, false, []rewrite.Rule{
&rewrite.RegexpRule{Base: "/", To: "/to", Exts: []string{"/", "html", "txt"}, Regexp: regexp.MustCompile(".*")}, &rewrite.ComplexRule{Base: "/", To: "/to", Exts: []string{"/", "html", "txt"}, Regexp: regexp.MustCompile(".*")},
}}, }},
{`rewrite /path { {`rewrite /path {
r rr r rr
@ -113,29 +112,30 @@ func TestRewriteParse(t *testing.T) {
} }
rewrite / { rewrite / {
regexp [a-z]+ regexp [a-z]+
to /to to /to /to2
} }
`, false, []rewrite.Rule{ `, false, []rewrite.Rule{
&rewrite.RegexpRule{Base: "/path", To: "/dest", Regexp: regexp.MustCompile("rr")}, &rewrite.ComplexRule{Base: "/path", To: "/dest", Regexp: regexp.MustCompile("rr")},
&rewrite.RegexpRule{Base: "/", To: "/to", Regexp: regexp.MustCompile("[a-z]+")}, &rewrite.ComplexRule{Base: "/", To: "/to /to2", Regexp: regexp.MustCompile("[a-z]+")},
}},
{`rewrite {
to /to
}`, true, []rewrite.Rule{
&rewrite.RegexpRule{},
}}, }},
{`rewrite { {`rewrite {
r .* r .*
}`, true, []rewrite.Rule{ }`, true, []rewrite.Rule{
&rewrite.RegexpRule{}, &rewrite.ComplexRule{},
}}, }},
{`rewrite { {`rewrite {
}`, true, []rewrite.Rule{ }`, true, []rewrite.Rule{
&rewrite.RegexpRule{}, &rewrite.ComplexRule{},
}}, }},
{`rewrite /`, true, []rewrite.Rule{ {`rewrite /`, true, []rewrite.Rule{
&rewrite.RegexpRule{}, &rewrite.ComplexRule{},
}},
{`rewrite {
to /to
if {path} is a
}`, false, []rewrite.Rule{
&rewrite.ComplexRule{Base: "/", To: "/to", Ifs: []rewrite.If{rewrite.If{A: "{path}", Operator: "is", B: "a"}}},
}}, }},
} }
@ -157,8 +157,8 @@ func TestRewriteParse(t *testing.T) {
} }
for j, e := range test.expected { for j, e := range test.expected {
actualRule := actual[j].(*rewrite.RegexpRule) actualRule := actual[j].(*rewrite.ComplexRule)
expectedRule := e.(*rewrite.RegexpRule) expectedRule := e.(*rewrite.ComplexRule)
if actualRule.Base != expectedRule.Base { if actualRule.Base != expectedRule.Base {
t.Errorf("Test %d, rule %d: Expected Base=%s, got %s", t.Errorf("Test %d, rule %d: Expected Base=%s, got %s",
@ -175,10 +175,18 @@ func TestRewriteParse(t *testing.T) {
i, j, expectedRule.To, actualRule.To) i, j, expectedRule.To, actualRule.To)
} }
if actualRule.String() != expectedRule.String() { if actualRule.Regexp != nil {
t.Errorf("Test %d, rule %d: Expected Pattern=%s, got %s", if actualRule.String() != expectedRule.String() {
i, j, expectedRule.String(), actualRule.String()) t.Errorf("Test %d, rule %d: Expected Pattern=%s, got %s",
i, j, expectedRule.String(), actualRule.String())
}
} }
if fmt.Sprint(actualRule.Ifs) != fmt.Sprint(expectedRule.Ifs) {
t.Errorf("Test %d, rule %d: Expected Pattern=%s, got %s",
i, j, fmt.Sprint(expectedRule.Ifs), fmt.Sprint(actualRule.Ifs))
}
} }
} }

View file

@ -5,6 +5,8 @@ import (
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
@ -103,8 +105,12 @@ func generateStaticHTML(md Markdown, cfg *Config) error {
reqPath = filepath.ToSlash(reqPath) reqPath = filepath.ToSlash(reqPath)
reqPath = "/" + reqPath reqPath = "/" + reqPath
// Create empty requests and url to cater for template values.
req, _ := http.NewRequest("", "/", nil)
urlVar, _ := url.Parse("/")
// Generate the static file // Generate the static file
ctx := middleware.Context{Root: md.FileSys} ctx := middleware.Context{Root: md.FileSys, Req: req, URL: urlVar}
_, err = md.Process(cfg, reqPath, body, ctx) _, err = md.Process(cfg, reqPath, body, ctx)
if err != nil { if err != nil {
return err return err

View file

@ -0,0 +1,111 @@
package rewrite
import (
"fmt"
"net/http"
"regexp"
"strings"
"github.com/mholt/caddy/middleware"
)
const (
// Operators
Is = "is"
Not = "not"
Has = "has"
StartsWith = "starts_with"
EndsWith = "ends_with"
Match = "match"
)
func operatorError(operator string) error {
return fmt.Errorf("Invalid operator %v", operator)
}
func newReplacer(r *http.Request) middleware.Replacer {
return middleware.NewReplacer(r, nil, "")
}
// condition is a rewrite condition.
type condition func(string, string) bool
var conditions = map[string]condition{
Is: isFunc,
Not: notFunc,
Has: hasFunc,
StartsWith: startsWithFunc,
EndsWith: endsWithFunc,
Match: matchFunc,
}
// isFunc is condition for Is operator.
// It checks for equality.
func isFunc(a, b string) bool {
return a == b
}
// notFunc is condition for Not operator.
// It checks for inequality.
func notFunc(a, b string) bool {
return a != b
}
// hasFunc is condition for Has operator.
// It checks if b is a substring of a.
func hasFunc(a, b string) bool {
return strings.Contains(a, b)
}
// startsWithFunc is condition for StartsWith operator.
// It checks if b is a prefix of a.
func startsWithFunc(a, b string) bool {
return strings.HasPrefix(a, b)
}
// endsWithFunc is condition for EndsWith operator.
// It checks if b is a suffix of a.
func endsWithFunc(a, b string) bool {
return strings.HasSuffix(a, b)
}
// matchFunc is condition for Match operator.
// It does regexp matching of a against pattern in b
func matchFunc(a, b string) bool {
matched, _ := regexp.MatchString(b, a)
return matched
}
// If is statement for a rewrite condition.
type If struct {
A string
Operator string
B string
}
// True returns true if the condition is true and false otherwise.
// If r is not nil, it replaces placeholders before comparison.
func (i If) True(r *http.Request) bool {
if c, ok := conditions[i.Operator]; ok {
a, b := i.A, i.B
if r != nil {
replacer := newReplacer(r)
a = replacer.Replace(i.A)
b = replacer.Replace(i.B)
}
return c(a, b)
}
return false
}
// NewIf creates a new If condition.
func NewIf(a, operator, b string) (If, error) {
if _, ok := conditions[operator]; !ok {
return If{}, operatorError(operator)
}
return If{
A: a,
Operator: operator,
B: b,
}, nil
}

View file

@ -0,0 +1,90 @@
package rewrite
import (
"net/http"
"strings"
"testing"
)
func TestConditions(t *testing.T) {
tests := []struct {
condition string
isTrue bool
}{
{"a is b", false},
{"a is a", true},
{"a not b", true},
{"a not a", false},
{"a has a", true},
{"a has b", false},
{"ba has b", true},
{"bab has b", true},
{"bab has bb", false},
{"bab starts_with bb", false},
{"bab starts_with ba", true},
{"bab starts_with bab", true},
{"bab ends_with bb", false},
{"bab ends_with bab", true},
{"bab ends_with ab", true},
{"a match *", false},
{"a match a", true},
{"a match .*", true},
{"a match a.*", true},
{"a match b.*", false},
{"ba match b.*", true},
{"ba match b[a-z]", true},
{"b0 match b[a-z]", false},
{"b0a match b[a-z]", false},
{"b0a match b[a-z]+", false},
{"b0a match b[a-z0-9]+", true},
}
for i, test := range tests {
str := strings.Fields(test.condition)
ifCond, err := NewIf(str[0], str[1], str[2])
if err != nil {
t.Error(err)
}
isTrue := ifCond.True(nil)
if isTrue != test.isTrue {
t.Errorf("Test %v: expected %v found %v", i, test.isTrue, isTrue)
}
}
invalidOperators := []string{"ss", "and", "if"}
for _, op := range invalidOperators {
_, err := NewIf("a", op, "b")
if err == nil {
t.Errorf("Invalid operator %v used, expected error.", op)
}
}
replaceTests := []struct {
url string
condition string
isTrue bool
}{
{"/home", "{uri} match /home", true},
{"/hom", "{uri} match /home", false},
{"/hom", "{uri} starts_with /home", false},
{"/hom", "{uri} starts_with /h", true},
{"/home/.hiddenfile", `{uri} match \/\.(.*)`, true},
{"/home/.hiddendir/afile", `{uri} match \/\.(.*)`, true},
}
for i, test := range replaceTests {
r, err := http.NewRequest("GET", test.url, nil)
if err != nil {
t.Error(err)
}
str := strings.Fields(test.condition)
ifCond, err := NewIf(str[0], str[1], str[2])
if err != nil {
t.Error(err)
}
isTrue := ifCond.True(r)
if isTrue != test.isTrue {
t.Errorf("Test %v: expected %v found %v", i, test.isTrue, isTrue)
}
}
}

View file

@ -5,7 +5,6 @@ package rewrite
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"net/url"
"path" "path"
"path/filepath" "path/filepath"
"regexp" "regexp"
@ -16,14 +15,15 @@ import (
// Rewrite is middleware to rewrite request locations internally before being handled. // Rewrite is middleware to rewrite request locations internally before being handled.
type Rewrite struct { type Rewrite struct {
Next middleware.Handler Next middleware.Handler
Rules []Rule FileSys http.FileSystem
Rules []Rule
} }
// ServeHTTP implements the middleware.Handler interface. // ServeHTTP implements the middleware.Handler interface.
func (rw Rewrite) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { func (rw Rewrite) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
for _, rule := range rw.Rules { for _, rule := range rw.Rules {
if ok := rule.Rewrite(r); ok { if ok := rule.Rewrite(rw.FileSys, r); ok {
break break
} }
} }
@ -33,7 +33,7 @@ func (rw Rewrite) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
// Rule describes an internal location rewrite rule. // Rule describes an internal location rewrite rule.
type Rule interface { type Rule interface {
// Rewrite rewrites the internal location of the current request. // Rewrite rewrites the internal location of the current request.
Rewrite(*http.Request) bool Rewrite(http.FileSystem, *http.Request) bool
} }
// SimpleRule is a simple rewrite rule. // SimpleRule is a simple rewrite rule.
@ -47,23 +47,20 @@ func NewSimpleRule(from, to string) SimpleRule {
} }
// Rewrite rewrites the internal location of the current request. // Rewrite rewrites the internal location of the current request.
func (s SimpleRule) Rewrite(r *http.Request) bool { func (s SimpleRule) Rewrite(fs http.FileSystem, r *http.Request) bool {
if s.From == r.URL.Path { if s.From == r.URL.Path {
// take note of this rewrite for internal use by fastcgi // take note of this rewrite for internal use by fastcgi
// all we need is the URI, not full URL // all we need is the URI, not full URL
r.Header.Set(headerFieldName, r.URL.RequestURI()) r.Header.Set(headerFieldName, r.URL.RequestURI())
// replace variables // attempt rewrite
to := path.Clean(middleware.NewReplacer(r, nil, "").Replace(s.To)) return To(fs, r, s.To)
r.URL.Path = to
return true
} }
return false return false
} }
// RegexpRule is a rewrite rule based on a regular expression // ComplexRule is a rewrite rule based on a regular expression
type RegexpRule struct { type ComplexRule struct {
// Path base. Request to this path and subpaths will be rewritten // Path base. Request to this path and subpaths will be rewritten
Base string Base string
@ -73,18 +70,26 @@ type RegexpRule struct {
// Extensions to filter by // Extensions to filter by
Exts []string Exts []string
// Rewrite conditions
Ifs []If
*regexp.Regexp *regexp.Regexp
} }
// NewRegexpRule creates a new RegexpRule. It returns an error if regexp // NewRegexpRule creates a new RegexpRule. It returns an error if regexp
// pattern (pattern) or extensions (ext) are invalid. // pattern (pattern) or extensions (ext) are invalid.
func NewRegexpRule(base, pattern, to string, ext []string) (*RegexpRule, error) { func NewComplexRule(base, pattern, to string, ext []string, ifs []If) (*ComplexRule, error) {
r, err := regexp.Compile(pattern) // validate regexp if present
if err != nil { var r *regexp.Regexp
return nil, err if pattern != "" {
var err error
r, err = regexp.Compile(pattern)
if err != nil {
return nil, err
}
} }
// validate extensions // validate extensions if present
for _, v := range ext { for _, v := range ext {
if len(v) < 2 || (len(v) < 3 && v[0] == '!') { if len(v) < 2 || (len(v) < 3 && v[0] == '!') {
// check if no extension is specified // check if no extension is specified
@ -94,16 +99,17 @@ func NewRegexpRule(base, pattern, to string, ext []string) (*RegexpRule, error)
} }
} }
return &RegexpRule{ return &ComplexRule{
base, Base: base,
to, To: to,
ext, Exts: ext,
r, Ifs: ifs,
Regexp: r,
}, nil }, nil
} }
// Rewrite rewrites the internal location of the current request. // Rewrite rewrites the internal location of the current request.
func (r *RegexpRule) Rewrite(req *http.Request) bool { func (r *ComplexRule) Rewrite(fs http.FileSystem, req *http.Request) bool {
rPath := req.URL.Path rPath := req.URL.Path
// validate base // validate base
@ -122,36 +128,27 @@ func (r *RegexpRule) Rewrite(req *http.Request) bool {
start-- start--
} }
// validate regexp // validate regexp if present
if !r.MatchString(rPath[start:]) { if r.Regexp != nil {
return false if !r.MatchString(rPath[start:]) {
return false
}
} }
// replace variables // validate rewrite conditions
to := path.Clean(middleware.NewReplacer(req, nil, "").Replace(r.To)) for _, i := range r.Ifs {
if !i.True(req) {
// validate resulting path return false
url, err := url.Parse(to) }
if err != nil {
return false
} }
// take note of this rewrite for internal use by fastcgi // attempt rewrite
// all we need is the URI, not full URL return To(fs, req, r.To)
req.Header.Set(headerFieldName, req.URL.RequestURI())
// perform rewrite
req.URL.Path = url.Path
if url.RawQuery != "" {
// overwrite query string if present
req.URL.RawQuery = url.RawQuery
}
return true
} }
// matchExt matches rPath against registered file extensions. // matchExt matches rPath against registered file extensions.
// Returns true if a match is found and false otherwise. // Returns true if a match is found and false otherwise.
func (r *RegexpRule) matchExt(rPath string) bool { func (r *ComplexRule) matchExt(rPath string) bool {
f := filepath.Base(rPath) f := filepath.Base(rPath)
ext := path.Ext(f) ext := path.Ext(f)
if ext == "" { if ext == "" {

View file

@ -4,9 +4,8 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing"
"strings" "strings"
"testing"
"github.com/mholt/caddy/middleware" "github.com/mholt/caddy/middleware"
) )
@ -19,9 +18,10 @@ func TestRewrite(t *testing.T) {
NewSimpleRule("/a", "/b"), NewSimpleRule("/a", "/b"),
NewSimpleRule("/b", "/b{uri}"), NewSimpleRule("/b", "/b{uri}"),
}, },
FileSys: http.Dir("."),
} }
regexpRules := [][]string{ regexps := [][]string{
{"/reg/", ".*", "/to", ""}, {"/reg/", ".*", "/to", ""},
{"/r/", "[a-z]+", "/toaz", "!.html|"}, {"/r/", "[a-z]+", "/toaz", "!.html|"},
{"/url/", "a([a-z0-9]*)s([A-Z]{2})", "/to/{path}", ""}, {"/url/", "a([a-z0-9]*)s([A-Z]{2})", "/to/{path}", ""},
@ -33,12 +33,12 @@ func TestRewrite(t *testing.T) {
{"/ab/", `.*\.jpg`, "/ajpg", ""}, {"/ab/", `.*\.jpg`, "/ajpg", ""},
} }
for _, regexpRule := range regexpRules { for _, regexpRule := range regexps {
var ext []string var ext []string
if s := strings.Split(regexpRule[3], "|"); len(s) > 1 { if s := strings.Split(regexpRule[3], "|"); len(s) > 1 {
ext = s[:len(s)-1] ext = s[:len(s)-1]
} }
rule, err := NewRegexpRule(regexpRule[0], regexpRule[1], regexpRule[2], ext) rule, err := NewComplexRule(regexpRule[0], regexpRule[1], regexpRule[2], ext, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

1
middleware/rewrite/testdata/testfile vendored Normal file
View file

@ -0,0 +1 @@
empty

86
middleware/rewrite/to.go Normal file
View file

@ -0,0 +1,86 @@
package rewrite
import (
"log"
"net/http"
"net/url"
"path"
"strings"
)
// To attempts rewrite. It attempts to rewrite to first valid path
// or the last path if none of the paths are valid.
// Returns true if rewrite is successful and false otherwise.
func To(fs http.FileSystem, r *http.Request, to string) bool {
tos := strings.Fields(to)
replacer := newReplacer(r)
// try each rewrite paths
t := ""
for _, v := range tos {
t = path.Clean(replacer.Replace(v))
// add trailing slash for directories, if present
if strings.HasSuffix(v, "/") && !strings.HasSuffix(t, "/") {
t += "/"
}
// validate file
if isValidFile(fs, t) {
break
}
}
// validate resulting path
u, err := url.Parse(t)
if err != nil {
// Let the user know we got here. Rewrite is expected but
// the resulting url is invalid.
log.Printf("[ERROR] rewrite: resulting path '%v' is invalid. error: %v", t, err)
return false
}
// take note of this rewrite for internal use by fastcgi
// all we need is the URI, not full URL
r.Header.Set(headerFieldName, r.URL.RequestURI())
// perform rewrite
r.URL.Path = u.Path
if u.RawQuery != "" {
// overwrite query string if present
r.URL.RawQuery = u.RawQuery
}
if u.Fragment != "" {
// overwrite fragment if present
r.URL.Fragment = u.Fragment
}
return true
}
// isValidFile checks if file exists on the filesystem.
// if file ends with `/`, it is validated as a directory.
func isValidFile(fs http.FileSystem, file string) bool {
if fs == nil {
return false
}
f, err := fs.Open(file)
if err != nil {
return false
}
defer f.Close()
stat, err := f.Stat()
if err != nil {
return false
}
// directory
if strings.HasSuffix(file, "/") {
return stat.IsDir()
}
// file
return !stat.IsDir()
}

View file

@ -0,0 +1,44 @@
package rewrite
import (
"net/http"
"net/url"
"testing"
)
func TestTo(t *testing.T) {
fs := http.Dir("testdata")
tests := []struct {
url string
to string
expected string
}{
{"/", "/somefiles", "/somefiles"},
{"/somefiles", "/somefiles /index.php{uri}", "/index.php/somefiles"},
{"/somefiles", "/testfile /index.php{uri}", "/testfile"},
{"/somefiles", "/testfile/ /index.php{uri}", "/index.php/somefiles"},
{"/somefiles", "/somefiles /index.php{uri}", "/index.php/somefiles"},
{"/?a=b", "/somefiles /index.php?{query}", "/index.php?a=b"},
{"/?a=b", "/testfile /index.php?{query}", "/testfile?a=b"},
{"/?a=b", "/testdir /index.php?{query}", "/index.php?a=b"},
{"/?a=b", "/testdir/ /index.php?{query}", "/testdir/?a=b"},
}
uri := func(r *url.URL) string {
uri := r.Path
if r.RawQuery != "" {
uri += "?" + r.RawQuery
}
return uri
}
for i, test := range tests {
r, err := http.NewRequest("GET", test.url, nil)
if err != nil {
t.Error(err)
}
To(fs, r, test.to)
if uri(r.URL) != test.expected {
t.Errorf("Test %v: expected %v found %v", i, test.expected, uri(r.URL))
}
}
}