mirror of
https://github.com/caddyserver/caddy.git
synced 2024-12-26 13:43:47 +03:00
Implement rewrite middleware; fix middleware stack bugs
This commit is contained in:
parent
b84cb05848
commit
65195a726d
8 changed files with 133 additions and 34 deletions
|
@ -9,6 +9,7 @@ import (
|
|||
_ "bitbucket.org/lightcodelabs/caddy2/modules/caddyhttp/fileserver"
|
||||
_ "bitbucket.org/lightcodelabs/caddy2/modules/caddyhttp/headers"
|
||||
_ "bitbucket.org/lightcodelabs/caddy2/modules/caddyhttp/reverseproxy"
|
||||
_ "bitbucket.org/lightcodelabs/caddy2/modules/caddyhttp/rewrite"
|
||||
_ "bitbucket.org/lightcodelabs/caddy2/modules/caddytls"
|
||||
)
|
||||
|
||||
|
|
|
@ -104,7 +104,7 @@ func (fsrv *FileServer) ServeHTTP(w http.ResponseWriter, r *http.Request) error
|
|||
if filename == "" {
|
||||
// no files worked, so resort to fallback
|
||||
if fsrv.Fallback != nil {
|
||||
fallback := fsrv.Fallback.BuildCompositeRoute(w, r)
|
||||
fallback, w := fsrv.Fallback.BuildCompositeRoute(w, r)
|
||||
return fallback.ServeHTTP(w, r)
|
||||
}
|
||||
return caddyhttp.Error(http.StatusNotFound, nil)
|
||||
|
|
|
@ -37,29 +37,36 @@ type RespHeaderOps struct {
|
|||
}
|
||||
|
||||
func (h Headers) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
|
||||
apply(h.Request, r.Header)
|
||||
repl := r.Context().Value(caddy2.ReplacerCtxKey).(caddy2.Replacer)
|
||||
apply(h.Request, r.Header, repl)
|
||||
if h.Response.Deferred {
|
||||
w = &responseWriterWrapper{
|
||||
ResponseWriterWrapper: &caddyhttp.ResponseWriterWrapper{ResponseWriter: w},
|
||||
replacer: repl,
|
||||
headerOps: h.Response.HeaderOps,
|
||||
}
|
||||
} else {
|
||||
apply(h.Response.HeaderOps, w.Header())
|
||||
apply(h.Response.HeaderOps, w.Header(), repl)
|
||||
}
|
||||
return next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func apply(ops HeaderOps, hdr http.Header) {
|
||||
func apply(ops HeaderOps, hdr http.Header, repl caddy2.Replacer) {
|
||||
for fieldName, vals := range ops.Add {
|
||||
fieldName = repl.ReplaceAll(fieldName, "")
|
||||
for _, v := range vals {
|
||||
hdr.Add(fieldName, v)
|
||||
hdr.Add(fieldName, repl.ReplaceAll(v, ""))
|
||||
}
|
||||
}
|
||||
for fieldName, vals := range ops.Set {
|
||||
fieldName = repl.ReplaceAll(fieldName, "")
|
||||
for i := range vals {
|
||||
vals[i] = repl.ReplaceAll(vals[i], "")
|
||||
}
|
||||
hdr.Set(fieldName, strings.Join(vals, ","))
|
||||
}
|
||||
for _, fieldName := range ops.Delete {
|
||||
hdr.Del(fieldName)
|
||||
hdr.Del(repl.ReplaceAll(fieldName, ""))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -67,6 +74,7 @@ func apply(ops HeaderOps, hdr http.Header) {
|
|||
// operations until WriteHeader is called.
|
||||
type responseWriterWrapper struct {
|
||||
*caddyhttp.ResponseWriterWrapper
|
||||
replacer caddy2.Replacer
|
||||
headerOps HeaderOps
|
||||
wroteHeader bool
|
||||
}
|
||||
|
@ -83,7 +91,7 @@ func (rww *responseWriterWrapper) WriteHeader(status int) {
|
|||
return
|
||||
}
|
||||
rww.wroteHeader = true
|
||||
apply(rww.headerOps, rww.ResponseWriterWrapper.Header())
|
||||
apply(rww.headerOps, rww.ResponseWriterWrapper.Header(), rww.replacer)
|
||||
rww.ResponseWriterWrapper.WriteHeader(status)
|
||||
}
|
||||
|
||||
|
|
|
@ -227,9 +227,10 @@ func TestPathREMatcher(t *testing.T) {
|
|||
|
||||
// set up the fake request and its Replacer
|
||||
req := &http.Request{URL: &url.URL{Path: tc.input}}
|
||||
repl := newReplacer(req, httptest.NewRecorder())
|
||||
repl := caddy2.NewReplacer()
|
||||
ctx := context.WithValue(req.Context(), caddy2.ReplacerCtxKey, repl)
|
||||
req = req.WithContext(ctx)
|
||||
addHTTPVarsToReplacer(repl, req, httptest.NewRecorder())
|
||||
|
||||
actual := tc.match.Match(req)
|
||||
if actual != tc.expect {
|
||||
|
@ -344,9 +345,10 @@ func TestHeaderREMatcher(t *testing.T) {
|
|||
|
||||
// set up the fake request and its Replacer
|
||||
req := &http.Request{Header: tc.input, URL: new(url.URL)}
|
||||
repl := newReplacer(req, httptest.NewRecorder())
|
||||
repl := caddy2.NewReplacer()
|
||||
ctx := context.WithValue(req.Context(), caddy2.ReplacerCtxKey, repl)
|
||||
req = req.WithContext(ctx)
|
||||
addHTTPVarsToReplacer(repl, req, httptest.NewRecorder())
|
||||
|
||||
actual := tc.match.Match(req)
|
||||
if actual != tc.expect {
|
||||
|
|
|
@ -13,9 +13,7 @@ import (
|
|||
// TODO: A simple way to format or escape or encode each value would be nice
|
||||
// ... TODO: Should we just use templates? :-/ yeesh...
|
||||
|
||||
func newReplacer(req *http.Request, w http.ResponseWriter) caddy2.Replacer {
|
||||
repl := caddy2.NewReplacer()
|
||||
|
||||
func addHTTPVarsToReplacer(repl caddy2.Replacer, req *http.Request, w http.ResponseWriter) {
|
||||
httpVars := func() map[string]string {
|
||||
m := make(map[string]string)
|
||||
if req != nil {
|
||||
|
@ -78,6 +76,4 @@ func newReplacer(req *http.Request, w http.ResponseWriter) caddy2.Replacer {
|
|||
}
|
||||
|
||||
repl.Map(httpVars)
|
||||
|
||||
return repl
|
||||
}
|
||||
|
|
71
modules/caddyhttp/rewrite/rewrite.go
Normal file
71
modules/caddyhttp/rewrite/rewrite.go
Normal file
|
@ -0,0 +1,71 @@
|
|||
package headers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"bitbucket.org/lightcodelabs/caddy2"
|
||||
"bitbucket.org/lightcodelabs/caddy2/modules/caddyhttp"
|
||||
)
|
||||
|
||||
func init() {
|
||||
caddy2.RegisterModule(caddy2.Module{
|
||||
Name: "http.middleware.rewrite",
|
||||
New: func() (interface{}, error) { return new(Rewrite), nil },
|
||||
})
|
||||
}
|
||||
|
||||
// Rewrite is a middleware which can rewrite HTTP requests.
|
||||
type Rewrite struct {
|
||||
Method string `json:"method"`
|
||||
URI string `json:"uri"`
|
||||
Rehandle bool `json:"rehandle"`
|
||||
}
|
||||
|
||||
func (rewr Rewrite) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
|
||||
repl := r.Context().Value(caddy2.ReplacerCtxKey).(caddy2.Replacer)
|
||||
var rehandleNeeded bool
|
||||
|
||||
if rewr.Method != "" {
|
||||
method := r.Method
|
||||
r.Method = strings.ToUpper(repl.ReplaceAll(rewr.Method, ""))
|
||||
if r.Method != method {
|
||||
rehandleNeeded = true
|
||||
}
|
||||
}
|
||||
|
||||
if rewr.URI != "" {
|
||||
// TODO: clean this all up, I don't think it's right
|
||||
|
||||
oldURI := r.RequestURI
|
||||
newURI := repl.ReplaceAll(rewr.URI, "")
|
||||
u, err := url.Parse(newURI)
|
||||
if err != nil {
|
||||
return caddyhttp.Error(http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
r.RequestURI = newURI
|
||||
|
||||
r.URL.Path = u.Path
|
||||
if u.RawQuery != "" {
|
||||
r.URL.RawQuery = u.RawQuery
|
||||
}
|
||||
if u.Fragment != "" {
|
||||
r.URL.Fragment = u.Fragment
|
||||
}
|
||||
|
||||
if newURI != oldURI {
|
||||
rehandleNeeded = true
|
||||
}
|
||||
}
|
||||
|
||||
if rehandleNeeded && rewr.Rehandle {
|
||||
return caddyhttp.ErrRehandle
|
||||
}
|
||||
|
||||
return next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// Interface guard
|
||||
var _ caddyhttp.MiddlewareHandler = (*Rewrite)(nil)
|
|
@ -65,23 +65,24 @@ func (routes RouteList) Provision(ctx caddy2.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// BuildCompositeRoute creates a chain of handlers by
|
||||
// applying all the matching routes.
|
||||
func (routes RouteList) BuildCompositeRoute(w http.ResponseWriter, r *http.Request) Handler {
|
||||
// BuildCompositeRoute creates a chain of handlers by applying all the matching
|
||||
// routes. The returned ResponseWriter should be used instead of rw.
|
||||
func (routes RouteList) BuildCompositeRoute(rw http.ResponseWriter, req *http.Request) (Handler, http.ResponseWriter) {
|
||||
mrw := &middlewareResponseWriter{ResponseWriterWrapper: &ResponseWriterWrapper{rw}}
|
||||
|
||||
if len(routes) == 0 {
|
||||
return emptyHandler
|
||||
return emptyHandler, mrw
|
||||
}
|
||||
|
||||
var mid []Middleware
|
||||
var responder Handler
|
||||
mrw := &middlewareResponseWriter{ResponseWriterWrapper: &ResponseWriterWrapper{w}}
|
||||
groups := make(map[string]struct{})
|
||||
|
||||
routeLoop:
|
||||
for _, route := range routes {
|
||||
// see if route matches
|
||||
for _, m := range route.matchers {
|
||||
if !m.Match(r) {
|
||||
if !m.Match(req) {
|
||||
continue routeLoop
|
||||
}
|
||||
}
|
||||
|
@ -102,15 +103,13 @@ routeLoop:
|
|||
|
||||
// apply the rest of the route
|
||||
for _, m := range route.middleware {
|
||||
mid = append(mid, func(next HandlerFunc) HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) error {
|
||||
// TODO: This is where request tracing could be implemented; also
|
||||
// see below to trace the responder as well
|
||||
// TODO: Trace a diff of the request, would be cool too! see what changed since the last middleware (host, headers, URI...)
|
||||
// TODO: see what the std lib gives us in terms of stack trracing too
|
||||
return m.ServeHTTP(mrw, r, next)
|
||||
}
|
||||
})
|
||||
// we have to be sure to wrap m outside
|
||||
// of our current scope so that the
|
||||
// reference to this m isn't overwritten
|
||||
// on the next iteration, leaving only
|
||||
// the last middleware in the chain as
|
||||
// the ONLY middleware in the chain!
|
||||
mid = append(mid, wrapMiddleware(m))
|
||||
}
|
||||
if responder == nil {
|
||||
responder = route.responder
|
||||
|
@ -132,7 +131,25 @@ routeLoop:
|
|||
stack = mid[i](stack)
|
||||
}
|
||||
|
||||
return stack
|
||||
return stack, mrw
|
||||
}
|
||||
|
||||
// wrapMiddleware wraps m such that it can be correctly
|
||||
// appended to a list of middleware. This is necessary
|
||||
// so that only the last middleware in a loop does not
|
||||
// become the only middleware of the stack, repeatedly
|
||||
// executed (i.e. it is necessary to keep a reference
|
||||
// to this m outside of the scope of a loop)!
|
||||
func wrapMiddleware(m MiddlewareHandler) Middleware {
|
||||
return func(next HandlerFunc) HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) error {
|
||||
// TODO: This is where request tracing could be implemented; also
|
||||
// see below to trace the responder as well
|
||||
// TODO: Trace a diff of the request, would be cool too! see what changed since the last middleware (host, headers, URI...)
|
||||
// TODO: see what the std lib gives us in terms of stack tracing too
|
||||
return m.ServeHTTP(w, r, next)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type middlewareResponseWriter struct {
|
||||
|
|
|
@ -32,14 +32,18 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
// set up the replacer
|
||||
repl := newReplacer(r, w)
|
||||
// set up the context for the request
|
||||
repl := caddy2.NewReplacer()
|
||||
ctx := context.WithValue(r.Context(), caddy2.ReplacerCtxKey, repl)
|
||||
ctx = context.WithValue(ctx, TableCtxKey, make(map[string]interface{})) // TODO: Implement this
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
// once the pointer to the request won't change
|
||||
// anymore, finish setting up the replacer
|
||||
addHTTPVarsToReplacer(repl, r, w)
|
||||
|
||||
// build and execute the main handler chain
|
||||
stack := s.Routes.BuildCompositeRoute(w, r)
|
||||
stack, w := s.Routes.BuildCompositeRoute(w, r)
|
||||
err := s.executeCompositeRoute(w, r, stack)
|
||||
if err != nil {
|
||||
// add the raw error value to the request context
|
||||
|
@ -64,7 +68,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
w.WriteHeader(handlerErr.StatusCode)
|
||||
}
|
||||
} else {
|
||||
errStack := s.Errors.Routes.BuildCompositeRoute(w, r)
|
||||
errStack, w := s.Errors.Routes.BuildCompositeRoute(w, r)
|
||||
err := s.executeCompositeRoute(w, r, errStack)
|
||||
if err != nil {
|
||||
// TODO: what should we do if the error handler has an error?
|
||||
|
|
Loading…
Reference in a new issue