diff --git a/caddy/setup/rewrite.go b/caddy/setup/rewrite.go index badbaaee..ab997d27 100644 --- a/caddy/setup/rewrite.go +++ b/caddy/setup/rewrite.go @@ -2,6 +2,7 @@ package setup import ( "net/http" + "strconv" "strings" "github.com/mholt/caddy/middleware" @@ -33,6 +34,7 @@ func rewriteParse(c *Controller) ([]rewrite.Rule, error) { var err error var base = "/" var pattern, to string + var status int var ext []string args := c.RemainingArgs() @@ -73,15 +75,23 @@ func rewriteParse(c *Controller) ([]rewrite.Rule, error) { return nil, err } ifs = append(ifs, ifCond) + case "status": + if !c.NextArg() { + return nil, c.ArgErr() + } + status, _ = strconv.Atoi(c.Val()) + if status < 400 || status > 499 { + return nil, c.Err("status must be 4xx") + } default: return nil, c.ArgErr() } } - // ensure to is specified - if to == "" { + // ensure to or status is specified + if to == "" && status == 0 { return nil, c.ArgErr() } - if rule, err = rewrite.NewComplexRule(base, pattern, to, ext, ifs); err != nil { + if rule, err = rewrite.NewComplexRule(base, pattern, to, status, ext, ifs); err != nil { return nil, err } regexpRules = append(regexpRules, rule) diff --git a/caddy/setup/rewrite_test.go b/caddy/setup/rewrite_test.go index c0dd2fb9..224ab643 100644 --- a/caddy/setup/rewrite_test.go +++ b/caddy/setup/rewrite_test.go @@ -137,6 +137,33 @@ func TestRewriteParse(t *testing.T) { }`, false, []rewrite.Rule{ &rewrite.ComplexRule{Base: "/", To: "/to", Ifs: []rewrite.If{rewrite.If{A: "{path}", Operator: "is", B: "a"}}}, }}, + {`rewrite { + status 400 + }`, false, []rewrite.Rule{ + &rewrite.ComplexRule{Base: "/", Regexp: regexp.MustCompile(".*"), Status: 400}, + }}, + {`rewrite { + to /to + status 400 + }`, false, []rewrite.Rule{ + &rewrite.ComplexRule{Base: "/", To: "/to", Regexp: regexp.MustCompile(".*"), Status: 400}, + }}, + {`rewrite { + status 399 + }`, true, []rewrite.Rule{ + &rewrite.ComplexRule{}, + }}, + {`rewrite { + status 0 + }`, true, []rewrite.Rule{ + &rewrite.ComplexRule{}, + }}, + {`rewrite { + to /to + status 0 + }`, true, []rewrite.Rule{ + &rewrite.ComplexRule{}, + }}, } for i, test := range regexpTests { diff --git a/middleware/rewrite/rewrite.go b/middleware/rewrite/rewrite.go index 60cc0b9d..30f5fc49 100644 --- a/middleware/rewrite/rewrite.go +++ b/middleware/rewrite/rewrite.go @@ -24,6 +24,13 @@ type Rewrite struct { func (rw Rewrite) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { for _, rule := range rw.Rules { if ok := rule.Rewrite(rw.FileSys, r); ok { + + // if rule is complex rule and status code is set + if cRule, ok := rule.(*ComplexRule); ok && cRule.Status != 0 { + return cRule.Status, nil + } + + // rewrite done break } } @@ -67,6 +74,10 @@ type ComplexRule struct { // Path to rewrite to To string + // If set, neither performs rewrite nor proceeds + // with request. Only returns code. + Status int + // Extensions to filter by Exts []string @@ -78,7 +89,7 @@ type ComplexRule struct { // NewRegexpRule creates a new RegexpRule. It returns an error if regexp // pattern (pattern) or extensions (ext) are invalid. -func NewComplexRule(base, pattern, to string, ext []string, ifs []If) (*ComplexRule, error) { +func NewComplexRule(base, pattern, to string, status int, ext []string, ifs []If) (*ComplexRule, error) { // validate regexp if present var r *regexp.Regexp if pattern != "" { @@ -102,6 +113,7 @@ func NewComplexRule(base, pattern, to string, ext []string, ifs []If) (*ComplexR return &ComplexRule{ Base: base, To: to, + Status: status, Exts: ext, Ifs: ifs, Regexp: r, diff --git a/middleware/rewrite/rewrite_test.go b/middleware/rewrite/rewrite_test.go index 5b891606..7f39a56b 100644 --- a/middleware/rewrite/rewrite_test.go +++ b/middleware/rewrite/rewrite_test.go @@ -41,7 +41,7 @@ func TestRewrite(t *testing.T) { if s := strings.Split(regexpRule[3], "|"); len(s) > 1 { ext = s[:len(s)-1] } - rule, err := NewComplexRule(regexpRule[0], regexpRule[1], regexpRule[2], ext, nil) + rule, err := NewComplexRule(regexpRule[0], regexpRule[1], regexpRule[2], 0, ext, nil) if err != nil { t.Fatal(err) } @@ -106,6 +106,35 @@ func TestRewrite(t *testing.T) { i, test.expectedTo, rec.Body.String()) } } + + statusTests := []int{ + 401, 405, 403, 400, + } + + for i, s := range statusTests { + urlPath := fmt.Sprintf("/status%d", i) + rule, err := NewComplexRule(urlPath, "", "", s, nil, nil) + if err != nil { + t.Fatalf("Test %d: No error expected for rule but found %v", i, err) + } + rw.Rules = append(rw.Rules, rule) + req, err := http.NewRequest("GET", urlPath, nil) + if err != nil { + t.Fatalf("Test %d: Could not create HTTP request: %v", i, err) + } + + rec := httptest.NewRecorder() + code, err := rw.ServeHTTP(rec, req) + if err != nil { + t.Fatalf("Test %d: No error expected for handler but found %v", i, err) + } + if rec.Body.String() != "" { + t.Errorf("Test %d: Expected empty body but found %s", i, rec.Body.String()) + } + if code != s { + t.Errorf("Text %d: Expected status code %d found %d", i, s, code) + } + } } func urlPrinter(w http.ResponseWriter, r *http.Request) (int, error) {