Implement rewrite middleware; fix middleware stack bugs

This commit is contained in:
Matthew Holt 2019-05-20 23:48:43 -06:00
parent b84cb05848
commit 65195a726d
8 changed files with 133 additions and 34 deletions

View file

@ -9,6 +9,7 @@ import (
_ "bitbucket.org/lightcodelabs/caddy2/modules/caddyhttp/fileserver" _ "bitbucket.org/lightcodelabs/caddy2/modules/caddyhttp/fileserver"
_ "bitbucket.org/lightcodelabs/caddy2/modules/caddyhttp/headers" _ "bitbucket.org/lightcodelabs/caddy2/modules/caddyhttp/headers"
_ "bitbucket.org/lightcodelabs/caddy2/modules/caddyhttp/reverseproxy" _ "bitbucket.org/lightcodelabs/caddy2/modules/caddyhttp/reverseproxy"
_ "bitbucket.org/lightcodelabs/caddy2/modules/caddyhttp/rewrite"
_ "bitbucket.org/lightcodelabs/caddy2/modules/caddytls" _ "bitbucket.org/lightcodelabs/caddy2/modules/caddytls"
) )

View file

@ -104,7 +104,7 @@ func (fsrv *FileServer) ServeHTTP(w http.ResponseWriter, r *http.Request) error
if filename == "" { if filename == "" {
// no files worked, so resort to fallback // no files worked, so resort to fallback
if fsrv.Fallback != nil { if fsrv.Fallback != nil {
fallback := fsrv.Fallback.BuildCompositeRoute(w, r) fallback, w := fsrv.Fallback.BuildCompositeRoute(w, r)
return fallback.ServeHTTP(w, r) return fallback.ServeHTTP(w, r)
} }
return caddyhttp.Error(http.StatusNotFound, nil) return caddyhttp.Error(http.StatusNotFound, nil)

View file

@ -37,29 +37,36 @@ type RespHeaderOps struct {
} }
func (h Headers) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { func (h Headers) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
apply(h.Request, r.Header) repl := r.Context().Value(caddy2.ReplacerCtxKey).(caddy2.Replacer)
apply(h.Request, r.Header, repl)
if h.Response.Deferred { if h.Response.Deferred {
w = &responseWriterWrapper{ w = &responseWriterWrapper{
ResponseWriterWrapper: &caddyhttp.ResponseWriterWrapper{ResponseWriter: w}, ResponseWriterWrapper: &caddyhttp.ResponseWriterWrapper{ResponseWriter: w},
replacer: repl,
headerOps: h.Response.HeaderOps, headerOps: h.Response.HeaderOps,
} }
} else { } else {
apply(h.Response.HeaderOps, w.Header()) apply(h.Response.HeaderOps, w.Header(), repl)
} }
return next.ServeHTTP(w, r) return next.ServeHTTP(w, r)
} }
func apply(ops HeaderOps, hdr http.Header) { func apply(ops HeaderOps, hdr http.Header, repl caddy2.Replacer) {
for fieldName, vals := range ops.Add { for fieldName, vals := range ops.Add {
fieldName = repl.ReplaceAll(fieldName, "")
for _, v := range vals { for _, v := range vals {
hdr.Add(fieldName, v) hdr.Add(fieldName, repl.ReplaceAll(v, ""))
} }
} }
for fieldName, vals := range ops.Set { for fieldName, vals := range ops.Set {
fieldName = repl.ReplaceAll(fieldName, "")
for i := range vals {
vals[i] = repl.ReplaceAll(vals[i], "")
}
hdr.Set(fieldName, strings.Join(vals, ",")) hdr.Set(fieldName, strings.Join(vals, ","))
} }
for _, fieldName := range ops.Delete { for _, fieldName := range ops.Delete {
hdr.Del(fieldName) hdr.Del(repl.ReplaceAll(fieldName, ""))
} }
} }
@ -67,6 +74,7 @@ func apply(ops HeaderOps, hdr http.Header) {
// operations until WriteHeader is called. // operations until WriteHeader is called.
type responseWriterWrapper struct { type responseWriterWrapper struct {
*caddyhttp.ResponseWriterWrapper *caddyhttp.ResponseWriterWrapper
replacer caddy2.Replacer
headerOps HeaderOps headerOps HeaderOps
wroteHeader bool wroteHeader bool
} }
@ -83,7 +91,7 @@ func (rww *responseWriterWrapper) WriteHeader(status int) {
return return
} }
rww.wroteHeader = true rww.wroteHeader = true
apply(rww.headerOps, rww.ResponseWriterWrapper.Header()) apply(rww.headerOps, rww.ResponseWriterWrapper.Header(), rww.replacer)
rww.ResponseWriterWrapper.WriteHeader(status) rww.ResponseWriterWrapper.WriteHeader(status)
} }

View file

@ -227,9 +227,10 @@ func TestPathREMatcher(t *testing.T) {
// set up the fake request and its Replacer // set up the fake request and its Replacer
req := &http.Request{URL: &url.URL{Path: tc.input}} req := &http.Request{URL: &url.URL{Path: tc.input}}
repl := newReplacer(req, httptest.NewRecorder()) repl := caddy2.NewReplacer()
ctx := context.WithValue(req.Context(), caddy2.ReplacerCtxKey, repl) ctx := context.WithValue(req.Context(), caddy2.ReplacerCtxKey, repl)
req = req.WithContext(ctx) req = req.WithContext(ctx)
addHTTPVarsToReplacer(repl, req, httptest.NewRecorder())
actual := tc.match.Match(req) actual := tc.match.Match(req)
if actual != tc.expect { if actual != tc.expect {
@ -344,9 +345,10 @@ func TestHeaderREMatcher(t *testing.T) {
// set up the fake request and its Replacer // set up the fake request and its Replacer
req := &http.Request{Header: tc.input, URL: new(url.URL)} req := &http.Request{Header: tc.input, URL: new(url.URL)}
repl := newReplacer(req, httptest.NewRecorder()) repl := caddy2.NewReplacer()
ctx := context.WithValue(req.Context(), caddy2.ReplacerCtxKey, repl) ctx := context.WithValue(req.Context(), caddy2.ReplacerCtxKey, repl)
req = req.WithContext(ctx) req = req.WithContext(ctx)
addHTTPVarsToReplacer(repl, req, httptest.NewRecorder())
actual := tc.match.Match(req) actual := tc.match.Match(req)
if actual != tc.expect { if actual != tc.expect {

View file

@ -13,9 +13,7 @@ import (
// TODO: A simple way to format or escape or encode each value would be nice // TODO: A simple way to format or escape or encode each value would be nice
// ... TODO: Should we just use templates? :-/ yeesh... // ... TODO: Should we just use templates? :-/ yeesh...
func newReplacer(req *http.Request, w http.ResponseWriter) caddy2.Replacer { func addHTTPVarsToReplacer(repl caddy2.Replacer, req *http.Request, w http.ResponseWriter) {
repl := caddy2.NewReplacer()
httpVars := func() map[string]string { httpVars := func() map[string]string {
m := make(map[string]string) m := make(map[string]string)
if req != nil { if req != nil {
@ -78,6 +76,4 @@ func newReplacer(req *http.Request, w http.ResponseWriter) caddy2.Replacer {
} }
repl.Map(httpVars) repl.Map(httpVars)
return repl
} }

View file

@ -0,0 +1,71 @@
package headers
import (
"net/http"
"net/url"
"strings"
"bitbucket.org/lightcodelabs/caddy2"
"bitbucket.org/lightcodelabs/caddy2/modules/caddyhttp"
)
func init() {
caddy2.RegisterModule(caddy2.Module{
Name: "http.middleware.rewrite",
New: func() (interface{}, error) { return new(Rewrite), nil },
})
}
// Rewrite is a middleware which can rewrite HTTP requests.
type Rewrite struct {
Method string `json:"method"`
URI string `json:"uri"`
Rehandle bool `json:"rehandle"`
}
func (rewr Rewrite) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
repl := r.Context().Value(caddy2.ReplacerCtxKey).(caddy2.Replacer)
var rehandleNeeded bool
if rewr.Method != "" {
method := r.Method
r.Method = strings.ToUpper(repl.ReplaceAll(rewr.Method, ""))
if r.Method != method {
rehandleNeeded = true
}
}
if rewr.URI != "" {
// TODO: clean this all up, I don't think it's right
oldURI := r.RequestURI
newURI := repl.ReplaceAll(rewr.URI, "")
u, err := url.Parse(newURI)
if err != nil {
return caddyhttp.Error(http.StatusInternalServerError, err)
}
r.RequestURI = newURI
r.URL.Path = u.Path
if u.RawQuery != "" {
r.URL.RawQuery = u.RawQuery
}
if u.Fragment != "" {
r.URL.Fragment = u.Fragment
}
if newURI != oldURI {
rehandleNeeded = true
}
}
if rehandleNeeded && rewr.Rehandle {
return caddyhttp.ErrRehandle
}
return next.ServeHTTP(w, r)
}
// Interface guard
var _ caddyhttp.MiddlewareHandler = (*Rewrite)(nil)

View file

@ -65,23 +65,24 @@ func (routes RouteList) Provision(ctx caddy2.Context) error {
return nil return nil
} }
// BuildCompositeRoute creates a chain of handlers by // BuildCompositeRoute creates a chain of handlers by applying all the matching
// applying all the matching routes. // routes. The returned ResponseWriter should be used instead of rw.
func (routes RouteList) BuildCompositeRoute(w http.ResponseWriter, r *http.Request) Handler { func (routes RouteList) BuildCompositeRoute(rw http.ResponseWriter, req *http.Request) (Handler, http.ResponseWriter) {
mrw := &middlewareResponseWriter{ResponseWriterWrapper: &ResponseWriterWrapper{rw}}
if len(routes) == 0 { if len(routes) == 0 {
return emptyHandler return emptyHandler, mrw
} }
var mid []Middleware var mid []Middleware
var responder Handler var responder Handler
mrw := &middlewareResponseWriter{ResponseWriterWrapper: &ResponseWriterWrapper{w}}
groups := make(map[string]struct{}) groups := make(map[string]struct{})
routeLoop: routeLoop:
for _, route := range routes { for _, route := range routes {
// see if route matches // see if route matches
for _, m := range route.matchers { for _, m := range route.matchers {
if !m.Match(r) { if !m.Match(req) {
continue routeLoop continue routeLoop
} }
} }
@ -102,15 +103,13 @@ routeLoop:
// apply the rest of the route // apply the rest of the route
for _, m := range route.middleware { for _, m := range route.middleware {
mid = append(mid, func(next HandlerFunc) HandlerFunc { // we have to be sure to wrap m outside
return func(w http.ResponseWriter, r *http.Request) error { // of our current scope so that the
// TODO: This is where request tracing could be implemented; also // reference to this m isn't overwritten
// see below to trace the responder as well // on the next iteration, leaving only
// TODO: Trace a diff of the request, would be cool too! see what changed since the last middleware (host, headers, URI...) // the last middleware in the chain as
// TODO: see what the std lib gives us in terms of stack trracing too // the ONLY middleware in the chain!
return m.ServeHTTP(mrw, r, next) mid = append(mid, wrapMiddleware(m))
}
})
} }
if responder == nil { if responder == nil {
responder = route.responder responder = route.responder
@ -132,7 +131,25 @@ routeLoop:
stack = mid[i](stack) stack = mid[i](stack)
} }
return stack return stack, mrw
}
// wrapMiddleware wraps m such that it can be correctly
// appended to a list of middleware. This is necessary
// so that only the last middleware in a loop does not
// become the only middleware of the stack, repeatedly
// executed (i.e. it is necessary to keep a reference
// to this m outside of the scope of a loop)!
func wrapMiddleware(m MiddlewareHandler) Middleware {
return func(next HandlerFunc) HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) error {
// TODO: This is where request tracing could be implemented; also
// see below to trace the responder as well
// TODO: Trace a diff of the request, would be cool too! see what changed since the last middleware (host, headers, URI...)
// TODO: see what the std lib gives us in terms of stack tracing too
return m.ServeHTTP(w, r, next)
}
}
} }
type middlewareResponseWriter struct { type middlewareResponseWriter struct {

View file

@ -32,14 +32,18 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
// set up the replacer // set up the context for the request
repl := newReplacer(r, w) repl := caddy2.NewReplacer()
ctx := context.WithValue(r.Context(), caddy2.ReplacerCtxKey, repl) ctx := context.WithValue(r.Context(), caddy2.ReplacerCtxKey, repl)
ctx = context.WithValue(ctx, TableCtxKey, make(map[string]interface{})) // TODO: Implement this ctx = context.WithValue(ctx, TableCtxKey, make(map[string]interface{})) // TODO: Implement this
r = r.WithContext(ctx) r = r.WithContext(ctx)
// once the pointer to the request won't change
// anymore, finish setting up the replacer
addHTTPVarsToReplacer(repl, r, w)
// build and execute the main handler chain // build and execute the main handler chain
stack := s.Routes.BuildCompositeRoute(w, r) stack, w := s.Routes.BuildCompositeRoute(w, r)
err := s.executeCompositeRoute(w, r, stack) err := s.executeCompositeRoute(w, r, stack)
if err != nil { if err != nil {
// add the raw error value to the request context // add the raw error value to the request context
@ -64,7 +68,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(handlerErr.StatusCode) w.WriteHeader(handlerErr.StatusCode)
} }
} else { } else {
errStack := s.Errors.Routes.BuildCompositeRoute(w, r) errStack, w := s.Errors.Routes.BuildCompositeRoute(w, r)
err := s.executeCompositeRoute(w, r, errStack) err := s.executeCompositeRoute(w, r, errStack)
if err != nil { if err != nil {
// TODO: what should we do if the error handler has an error? // TODO: what should we do if the error handler has an error?