Fix rehandling bug

This commit is contained in:
Matthew Holt 2019-07-11 22:02:47 -06:00
parent 4698352b20
commit 9722dbe18a
No known key found for this signature in database
GPG key ID: 2A349DD577D586A5
2 changed files with 33 additions and 15 deletions

View file

@ -85,6 +85,10 @@ func (app *App) Provision(ctx caddy.Context) error {
return fmt.Errorf("setting up server error handling routes: %v", err)
}
}
if srv.MaxRehandles == nil {
srv.MaxRehandles = &DefaultMaxRehandles
}
}
return nil
@ -111,8 +115,8 @@ func (app *App) Validate() error {
// each server's max rehandle value must be valid
for srvName, srv := range app.Servers {
if srv.MaxRehandles < 0 {
return fmt.Errorf("%s: invalid max_rehandles value: %d", srvName, srv.MaxRehandles)
if srv.MaxRehandles != nil && *srv.MaxRehandles < 0 {
return fmt.Errorf("%s: invalid max_rehandles value: %d", srvName, *srv.MaxRehandles)
}
}
@ -435,6 +439,10 @@ const (
DefaultHTTPSPort = 443
)
// DefaultMaxRehandles is the maximum number of rehandles to
// allow, if not specified explicitly.
var DefaultMaxRehandles = 3
// Interface guards
var (
_ caddy.App = (*App)(nil)

View file

@ -39,7 +39,7 @@ type Server struct {
Errors *HTTPErrorConfig `json:"errors,omitempty"`
TLSConnPolicies caddytls.ConnectionPolicies `json:"tls_connection_policies,omitempty"`
AutoHTTPS *AutoHTTPSConfig `json:"automatic_https,omitempty"`
MaxRehandles int `json:"max_rehandles,omitempty"`
MaxRehandles *int `json:"max_rehandles,omitempty"`
StrictSNIHost bool `json:"strict_sni_host,omitempty"` // TODO: see if we can turn this on by default when clientauth is configured
tlsApp *caddytls.TLS
@ -65,9 +65,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
addHTTPVarsToReplacer(repl, r, w)
// build and execute the main handler chain
stack := s.Routes.BuildCompositeRoute(r)
stack = s.wrapPrimaryRoute(stack)
err := s.executeCompositeRoute(w, r, stack)
err := s.executeCompositeRoute(w, r, s.Routes)
if err != nil {
// add the raw error value to the request context
// so it can be accessed by error handlers
@ -85,8 +83,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
if s.Errors != nil && len(s.Errors.Routes) > 0 {
errStack := s.Errors.Routes.BuildCompositeRoute(r)
err := s.executeCompositeRoute(w, r, errStack)
err := s.executeCompositeRoute(w, r, s.Errors.Routes)
if err != nil {
// TODO: what should we do if the error handler has an error?
log.Printf("[ERROR] [%s %s] handling error: %v", r.Method, r.RequestURI, err)
@ -103,20 +100,33 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
// executeCompositeRoute executes stack with w and r. This function handles
// the special ErrRehandle error value, which reprocesses requests through
// the stack again. Any error value returned from this function would be an
// actual error that needs to be handled.
func (s *Server) executeCompositeRoute(w http.ResponseWriter, r *http.Request, stack Handler) error {
// executeCompositeRoute compiles a composite route from routeList and executes
// it using w and r. This function handles the sentinel ErrRehandle error value,
// which reprocesses requests through the stack again. Any error value returned
// from this function would be an actual error that needs to be handled.
func (s *Server) executeCompositeRoute(w http.ResponseWriter, r *http.Request, routeList RouteList) error {
maxRehandles := 0
if s.MaxRehandles != nil {
maxRehandles = *s.MaxRehandles
}
var err error
for i := -1; i <= s.MaxRehandles; i++ {
for i := -1; i <= maxRehandles; i++ {
// we started the counter at -1 because we
// always want to run this at least once
// the purpose of rehandling is often to give
// matchers a chance to re-evaluate on the
// changed version of the request, so compile
// the handler stack anew in each iteration
stack := routeList.BuildCompositeRoute(r)
stack = s.wrapPrimaryRoute(stack)
// only loop if rehandling is required
err = stack.ServeHTTP(w, r)
if err != ErrRehandle {
break
}
if i >= s.MaxRehandles-1 {
if i >= maxRehandles-1 {
return fmt.Errorf("too many rehandles")
}
}