diff --git a/listeners.go b/listeners.go index ae1873f1..0bbee8a3 100644 --- a/listeners.go +++ b/listeners.go @@ -386,6 +386,19 @@ func JoinNetworkAddress(network, host, port string) string { return a } +// ListenerWrapper is a type that wraps a listener +// so it can modify the input listener's methods. +// Modules that implement this interface are found +// in the caddy.listeners namespace. Usually, to +// wrap a listener, you will define your own struct +// type that embeds the input listener, then +// implement your own methods that you want to wrap, +// calling the underlying listener's methods where +// appropriate. +type ListenerWrapper interface { + WrapListener(net.Listener) net.Listener +} + var ( listeners = make(map[string]*globalListener) listenersMu sync.Mutex diff --git a/modules/caddyhttp/caddyhttp.go b/modules/caddyhttp/caddyhttp.go index 99f215f1..718025eb 100644 --- a/modules/caddyhttp/caddyhttp.go +++ b/modules/caddyhttp/caddyhttp.go @@ -40,6 +40,11 @@ func init() { if err != nil { caddy.Log().Fatal(err.Error()) } + + err = caddy.RegisterModule(tlsPlaceholderWrapper{}) + if err != nil { + caddy.Log().Fatal(err.Error()) + } } // App is a robust, production-ready HTTP server. @@ -181,6 +186,7 @@ func (app *App) Provision(ctx caddy.Context) error { srv.StrictSNIHost = &trueBool } + // process each listener address for i := range srv.Listen { lnOut, err := repl.ReplaceOrErr(srv.Listen[i], true, true) if err != nil { @@ -190,6 +196,37 @@ func (app *App) Provision(ctx caddy.Context) error { srv.Listen[i] = lnOut } + // set up each listener modifier + if srv.ListenerWrappersRaw != nil { + vals, err := ctx.LoadModule(srv, "ListenerWrappersRaw") + if err != nil { + return fmt.Errorf("loading listener wrapper modules: %v", err) + } + var hasTLSPlaceholder bool + for i, val := range vals.([]interface{}) { + if _, ok := val.(*tlsPlaceholderWrapper); ok { + if i == 0 { + // putting the tls placeholder wrapper first is nonsensical because + // that is the default, implicit setting: without it, all wrappers + // will go after the TLS listener anyway + return fmt.Errorf("it is unnecessary to specify the TLS listener wrapper in the first position because that is the default") + } + if hasTLSPlaceholder { + return fmt.Errorf("TLS listener wrapper can only be specified once") + } + hasTLSPlaceholder = true + } + srv.listenerWrappers = append(srv.listenerWrappers, val.(caddy.ListenerWrapper)) + } + // if any wrappers were configured but the TLS placeholder wrapper is + // absent, prepend it so all defined wrappers come after the TLS + // handshake; this simplifies logic when starting the server, since we + // can simply assume the TLS placeholder will always be there + if !hasTLSPlaceholder && len(srv.listenerWrappers) > 0 { + srv.listenerWrappers = append([]caddy.ListenerWrapper{new(tlsPlaceholderWrapper)}, srv.listenerWrappers...) + } + } + // pre-compile the primary handler chain, and be sure to wrap it in our // route handler so that important security checks are done, etc. primaryRoute := emptyHandler @@ -265,12 +302,23 @@ func (app *App) Start() error { return fmt.Errorf("%s: parsing listen address '%s': %v", srvName, lnAddr, err) } for portOffset := uint(0); portOffset < listenAddr.PortRangeSize(); portOffset++ { + // create the listener for this socket hostport := listenAddr.JoinHostPort(portOffset) ln, err := caddy.Listen(listenAddr.Network, hostport) if err != nil { return fmt.Errorf("%s: listening on %s: %v", listenAddr.Network, hostport, err) } + // wrap listener before TLS (up to the TLS placeholder wrapper) + var lnWrapperIdx int + for i, lnWrapper := range srv.listenerWrappers { + if _, ok := lnWrapper.(*tlsPlaceholderWrapper); ok { + lnWrapperIdx = i + 1 // mark the next wrapper's spot + break + } + ln = lnWrapper.WrapListener(ln) + } + // enable TLS if there is a policy and if this is not the HTTP port useTLS := len(srv.TLSConnPolicies) > 0 && int(listenAddr.StartPort+portOffset) != app.httpPort() if useTLS { @@ -303,6 +351,11 @@ func (app *App) Start() error { ///////// } + // finish wrapping listener where we left off before TLS + for i := lnWrapperIdx; i < len(srv.listenerWrappers); i++ { + ln = srv.listenerWrappers[i].WrapListener(ln) + } + app.logger.Debug("starting server loop", zap.String("address", lnAddr), zap.Bool("http3", srv.ExperimentalHTTP3), @@ -544,6 +597,19 @@ func StatusCodeMatches(actual, configured int) bool { return false } +// tlsPlaceholderWrapper is a no-op listener wrapper that marks +// where the TLS listener should be in a chain of listener wrappers. +type tlsPlaceholderWrapper struct{} + +func (tlsPlaceholderWrapper) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + ID: "caddy.listeners.tls", + New: func() caddy.Module { return new(tlsPlaceholderWrapper) }, + } +} + +func (tlsPlaceholderWrapper) WrapListener(ln net.Listener) net.Listener { return ln } + const ( // DefaultHTTPPort is the default port for HTTP. DefaultHTTPPort = 80 @@ -557,4 +623,6 @@ var ( _ caddy.App = (*App)(nil) _ caddy.Provisioner = (*App)(nil) _ caddy.Validator = (*App)(nil) + + _ caddy.ListenerWrapper = (*tlsPlaceholderWrapper)(nil) ) diff --git a/modules/caddyhttp/server.go b/modules/caddyhttp/server.go index 1e22790d..461865c5 100644 --- a/modules/caddyhttp/server.go +++ b/modules/caddyhttp/server.go @@ -16,6 +16,7 @@ package caddyhttp import ( "context" + "encoding/json" "fmt" "net" "net/http" @@ -38,6 +39,10 @@ type Server struct { // that may include port ranges. Listen []string `json:"listen,omitempty"` + // A list of listener wrapper modules, which can modify the behavior + // of the base listener. They are applied in the given order. + ListenerWrappersRaw []json.RawMessage `json:"listener_wrappers,omitempty" caddy:"namespace=caddy.listeners inline_key=wrapper"` + // How long to allow a read from a client's upload. Setting this // to a short, non-zero value can mitigate slowloris attacks, but // may also affect legitimately slow clients. @@ -106,6 +111,7 @@ type Server struct { primaryHandlerChain Handler errorHandlerChain Handler + listenerWrappers []caddy.ListenerWrapper tlsApp *caddytls.TLS logger *zap.Logger