Implement rewrite middleware; fix middleware stack bugs

This commit is contained in:
Matthew Holt 2019-05-20 23:48:43 -06:00
parent b84cb05848
commit 65195a726d
8 changed files with 133 additions and 34 deletions

View file

@ -9,6 +9,7 @@ import (
_ "bitbucket.org/lightcodelabs/caddy2/modules/caddyhttp/fileserver"
_ "bitbucket.org/lightcodelabs/caddy2/modules/caddyhttp/headers"
_ "bitbucket.org/lightcodelabs/caddy2/modules/caddyhttp/reverseproxy"
_ "bitbucket.org/lightcodelabs/caddy2/modules/caddyhttp/rewrite"
_ "bitbucket.org/lightcodelabs/caddy2/modules/caddytls"
)

View file

@ -104,7 +104,7 @@ func (fsrv *FileServer) ServeHTTP(w http.ResponseWriter, r *http.Request) error
if filename == "" {
// no files worked, so resort to fallback
if fsrv.Fallback != nil {
fallback := fsrv.Fallback.BuildCompositeRoute(w, r)
fallback, w := fsrv.Fallback.BuildCompositeRoute(w, r)
return fallback.ServeHTTP(w, r)
}
return caddyhttp.Error(http.StatusNotFound, nil)

View file

@ -37,29 +37,36 @@ type RespHeaderOps struct {
}
func (h Headers) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
apply(h.Request, r.Header)
repl := r.Context().Value(caddy2.ReplacerCtxKey).(caddy2.Replacer)
apply(h.Request, r.Header, repl)
if h.Response.Deferred {
w = &responseWriterWrapper{
ResponseWriterWrapper: &caddyhttp.ResponseWriterWrapper{ResponseWriter: w},
replacer: repl,
headerOps: h.Response.HeaderOps,
}
} else {
apply(h.Response.HeaderOps, w.Header())
apply(h.Response.HeaderOps, w.Header(), repl)
}
return next.ServeHTTP(w, r)
}
func apply(ops HeaderOps, hdr http.Header) {
func apply(ops HeaderOps, hdr http.Header, repl caddy2.Replacer) {
for fieldName, vals := range ops.Add {
fieldName = repl.ReplaceAll(fieldName, "")
for _, v := range vals {
hdr.Add(fieldName, v)
hdr.Add(fieldName, repl.ReplaceAll(v, ""))
}
}
for fieldName, vals := range ops.Set {
fieldName = repl.ReplaceAll(fieldName, "")
for i := range vals {
vals[i] = repl.ReplaceAll(vals[i], "")
}
hdr.Set(fieldName, strings.Join(vals, ","))
}
for _, fieldName := range ops.Delete {
hdr.Del(fieldName)
hdr.Del(repl.ReplaceAll(fieldName, ""))
}
}
@ -67,6 +74,7 @@ func apply(ops HeaderOps, hdr http.Header) {
// operations until WriteHeader is called.
type responseWriterWrapper struct {
*caddyhttp.ResponseWriterWrapper
replacer caddy2.Replacer
headerOps HeaderOps
wroteHeader bool
}
@ -83,7 +91,7 @@ func (rww *responseWriterWrapper) WriteHeader(status int) {
return
}
rww.wroteHeader = true
apply(rww.headerOps, rww.ResponseWriterWrapper.Header())
apply(rww.headerOps, rww.ResponseWriterWrapper.Header(), rww.replacer)
rww.ResponseWriterWrapper.WriteHeader(status)
}

View file

@ -227,9 +227,10 @@ func TestPathREMatcher(t *testing.T) {
// set up the fake request and its Replacer
req := &http.Request{URL: &url.URL{Path: tc.input}}
repl := newReplacer(req, httptest.NewRecorder())
repl := caddy2.NewReplacer()
ctx := context.WithValue(req.Context(), caddy2.ReplacerCtxKey, repl)
req = req.WithContext(ctx)
addHTTPVarsToReplacer(repl, req, httptest.NewRecorder())
actual := tc.match.Match(req)
if actual != tc.expect {
@ -344,9 +345,10 @@ func TestHeaderREMatcher(t *testing.T) {
// set up the fake request and its Replacer
req := &http.Request{Header: tc.input, URL: new(url.URL)}
repl := newReplacer(req, httptest.NewRecorder())
repl := caddy2.NewReplacer()
ctx := context.WithValue(req.Context(), caddy2.ReplacerCtxKey, repl)
req = req.WithContext(ctx)
addHTTPVarsToReplacer(repl, req, httptest.NewRecorder())
actual := tc.match.Match(req)
if actual != tc.expect {

View file

@ -13,9 +13,7 @@ import (
// TODO: A simple way to format or escape or encode each value would be nice
// ... TODO: Should we just use templates? :-/ yeesh...
func newReplacer(req *http.Request, w http.ResponseWriter) caddy2.Replacer {
repl := caddy2.NewReplacer()
func addHTTPVarsToReplacer(repl caddy2.Replacer, req *http.Request, w http.ResponseWriter) {
httpVars := func() map[string]string {
m := make(map[string]string)
if req != nil {
@ -78,6 +76,4 @@ func newReplacer(req *http.Request, w http.ResponseWriter) caddy2.Replacer {
}
repl.Map(httpVars)
return repl
}

View file

@ -0,0 +1,71 @@
package headers
import (
"net/http"
"net/url"
"strings"
"bitbucket.org/lightcodelabs/caddy2"
"bitbucket.org/lightcodelabs/caddy2/modules/caddyhttp"
)
func init() {
caddy2.RegisterModule(caddy2.Module{
Name: "http.middleware.rewrite",
New: func() (interface{}, error) { return new(Rewrite), nil },
})
}
// Rewrite is a middleware which can rewrite HTTP requests.
type Rewrite struct {
Method string `json:"method"`
URI string `json:"uri"`
Rehandle bool `json:"rehandle"`
}
func (rewr Rewrite) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
repl := r.Context().Value(caddy2.ReplacerCtxKey).(caddy2.Replacer)
var rehandleNeeded bool
if rewr.Method != "" {
method := r.Method
r.Method = strings.ToUpper(repl.ReplaceAll(rewr.Method, ""))
if r.Method != method {
rehandleNeeded = true
}
}
if rewr.URI != "" {
// TODO: clean this all up, I don't think it's right
oldURI := r.RequestURI
newURI := repl.ReplaceAll(rewr.URI, "")
u, err := url.Parse(newURI)
if err != nil {
return caddyhttp.Error(http.StatusInternalServerError, err)
}
r.RequestURI = newURI
r.URL.Path = u.Path
if u.RawQuery != "" {
r.URL.RawQuery = u.RawQuery
}
if u.Fragment != "" {
r.URL.Fragment = u.Fragment
}
if newURI != oldURI {
rehandleNeeded = true
}
}
if rehandleNeeded && rewr.Rehandle {
return caddyhttp.ErrRehandle
}
return next.ServeHTTP(w, r)
}
// Interface guard
var _ caddyhttp.MiddlewareHandler = (*Rewrite)(nil)

View file

@ -65,23 +65,24 @@ func (routes RouteList) Provision(ctx caddy2.Context) error {
return nil
}
// BuildCompositeRoute creates a chain of handlers by
// applying all the matching routes.
func (routes RouteList) BuildCompositeRoute(w http.ResponseWriter, r *http.Request) Handler {
// BuildCompositeRoute creates a chain of handlers by applying all the matching
// routes. The returned ResponseWriter should be used instead of rw.
func (routes RouteList) BuildCompositeRoute(rw http.ResponseWriter, req *http.Request) (Handler, http.ResponseWriter) {
mrw := &middlewareResponseWriter{ResponseWriterWrapper: &ResponseWriterWrapper{rw}}
if len(routes) == 0 {
return emptyHandler
return emptyHandler, mrw
}
var mid []Middleware
var responder Handler
mrw := &middlewareResponseWriter{ResponseWriterWrapper: &ResponseWriterWrapper{w}}
groups := make(map[string]struct{})
routeLoop:
for _, route := range routes {
// see if route matches
for _, m := range route.matchers {
if !m.Match(r) {
if !m.Match(req) {
continue routeLoop
}
}
@ -102,15 +103,13 @@ routeLoop:
// apply the rest of the route
for _, m := range route.middleware {
mid = append(mid, func(next HandlerFunc) HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) error {
// TODO: This is where request tracing could be implemented; also
// see below to trace the responder as well
// TODO: Trace a diff of the request, would be cool too! see what changed since the last middleware (host, headers, URI...)
// TODO: see what the std lib gives us in terms of stack trracing too
return m.ServeHTTP(mrw, r, next)
}
})
// we have to be sure to wrap m outside
// of our current scope so that the
// reference to this m isn't overwritten
// on the next iteration, leaving only
// the last middleware in the chain as
// the ONLY middleware in the chain!
mid = append(mid, wrapMiddleware(m))
}
if responder == nil {
responder = route.responder
@ -132,7 +131,25 @@ routeLoop:
stack = mid[i](stack)
}
return stack
return stack, mrw
}
// wrapMiddleware wraps m such that it can be correctly
// appended to a list of middleware. This is necessary
// so that only the last middleware in a loop does not
// become the only middleware of the stack, repeatedly
// executed (i.e. it is necessary to keep a reference
// to this m outside of the scope of a loop)!
func wrapMiddleware(m MiddlewareHandler) Middleware {
return func(next HandlerFunc) HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) error {
// TODO: This is where request tracing could be implemented; also
// see below to trace the responder as well
// TODO: Trace a diff of the request, would be cool too! see what changed since the last middleware (host, headers, URI...)
// TODO: see what the std lib gives us in terms of stack tracing too
return m.ServeHTTP(w, r, next)
}
}
}
type middlewareResponseWriter struct {

View file

@ -32,14 +32,18 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
// set up the replacer
repl := newReplacer(r, w)
// set up the context for the request
repl := caddy2.NewReplacer()
ctx := context.WithValue(r.Context(), caddy2.ReplacerCtxKey, repl)
ctx = context.WithValue(ctx, TableCtxKey, make(map[string]interface{})) // TODO: Implement this
r = r.WithContext(ctx)
// once the pointer to the request won't change
// anymore, finish setting up the replacer
addHTTPVarsToReplacer(repl, r, w)
// build and execute the main handler chain
stack := s.Routes.BuildCompositeRoute(w, r)
stack, w := s.Routes.BuildCompositeRoute(w, r)
err := s.executeCompositeRoute(w, r, stack)
if err != nil {
// add the raw error value to the request context
@ -64,7 +68,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(handlerErr.StatusCode)
}
} else {
errStack := s.Errors.Routes.BuildCompositeRoute(w, r)
errStack, w := s.Errors.Routes.BuildCompositeRoute(w, r)
err := s.executeCompositeRoute(w, r, errStack)
if err != nil {
// TODO: what should we do if the error handler has an error?