reverseproxy: Experimental streaming timeouts (#5567)

* reverseproxy: WIP streaming timeouts

* More verbose logging by using the child logger

* reverseproxy: Implement streaming timeouts

* reverseproxy: Refactor cleanup

* reverseproxy: Avoid **time.Timer

---------

Co-authored-by: Francis Lavoie <lavofr@gmail.com>
This commit is contained in:
mmm444 2023-06-19 23:54:43 +02:00 committed by GitHub
parent 4548b7de8e
commit 424ae0f420
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 133 additions and 34 deletions

View file

@ -87,6 +87,8 @@ func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error)
// buffer_requests // buffer_requests
// buffer_responses // buffer_responses
// max_buffer_size <size> // max_buffer_size <size>
// stream_timeout <duration>
// stream_close_delay <duration>
// //
// # request manipulation // # request manipulation
// trusted_proxies [private_ranges] <ranges...> // trusted_proxies [private_ranges] <ranges...>
@ -571,6 +573,34 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
caddy.Log().Named("config.adapter.caddyfile").Warn("DEPRECATED: max_buffer_size: use request_buffers and/or response_buffers instead (with maximum buffer sizes)") caddy.Log().Named("config.adapter.caddyfile").Warn("DEPRECATED: max_buffer_size: use request_buffers and/or response_buffers instead (with maximum buffer sizes)")
h.DeprecatedMaxBufferSize = int64(size) h.DeprecatedMaxBufferSize = int64(size)
case "stream_timeout":
if !d.NextArg() {
return d.ArgErr()
}
if fi, err := strconv.Atoi(d.Val()); err == nil {
h.StreamTimeout = caddy.Duration(fi)
} else {
dur, err := caddy.ParseDuration(d.Val())
if err != nil {
return d.Errf("bad duration value '%s': %v", d.Val(), err)
}
h.StreamTimeout = caddy.Duration(dur)
}
case "stream_close_delay":
if !d.NextArg() {
return d.ArgErr()
}
if fi, err := strconv.Atoi(d.Val()); err == nil {
h.StreamCloseDelay = caddy.Duration(fi)
} else {
dur, err := caddy.ParseDuration(d.Val())
if err != nil {
return d.Errf("bad duration value '%s': %v", d.Val(), err)
}
h.StreamCloseDelay = caddy.Duration(dur)
}
case "trusted_proxies": case "trusted_proxies":
for d.NextArg() { for d.NextArg() {
if d.Val() == "private_ranges" { if d.Val() == "private_ranges" {

View file

@ -157,6 +157,19 @@ type Handler struct {
// could be useful if the backend has tighter memory constraints. // could be useful if the backend has tighter memory constraints.
ResponseBuffers int64 `json:"response_buffers,omitempty"` ResponseBuffers int64 `json:"response_buffers,omitempty"`
// If nonzero, streaming requests such as WebSockets will be
// forcibly closed at the end of the timeout. Default: no timeout.
StreamTimeout caddy.Duration `json:"stream_timeout,omitempty"`
// If nonzero, streaming requests such as WebSockets will not be
// closed when the proxy config is unloaded, and instead the stream
// will remain open until the delay is complete. In other words,
// enabling this prevents streams from closing when Caddy's config
// is reloaded. Enabling this may be a good idea to avoid a thundering
// herd of reconnecting clients which had their connections closed
// by the previous config closing. Default: no delay.
StreamCloseDelay caddy.Duration `json:"stream_close_delay,omitempty"`
// If configured, rewrites the copy of the upstream request. // If configured, rewrites the copy of the upstream request.
// Allows changing the request method and URI (path and query). // Allows changing the request method and URI (path and query).
// Since the rewrite is applied to the copy, it does not persist // Since the rewrite is applied to the copy, it does not persist
@ -199,6 +212,7 @@ type Handler struct {
// Stores upgraded requests (hijacked connections) for proper cleanup // Stores upgraded requests (hijacked connections) for proper cleanup
connections map[io.ReadWriteCloser]openConnection connections map[io.ReadWriteCloser]openConnection
connectionsCloseTimer *time.Timer
connectionsMu *sync.Mutex connectionsMu *sync.Mutex
ctx caddy.Context ctx caddy.Context
@ -382,25 +396,7 @@ func (h *Handler) Provision(ctx caddy.Context) error {
// Cleanup cleans up the resources made by h. // Cleanup cleans up the resources made by h.
func (h *Handler) Cleanup() error { func (h *Handler) Cleanup() error {
// close hijacked connections (both to client and backend) err := h.cleanupConnections()
var err error
h.connectionsMu.Lock()
for _, oc := range h.connections {
if oc.gracefulClose != nil {
// this is potentially blocking while we have the lock on the connections
// map, but that should be OK since the server has in theory shut down
// and we are no longer using the connections map
gracefulErr := oc.gracefulClose()
if gracefulErr != nil && err == nil {
err = gracefulErr
}
}
closeErr := oc.conn.Close()
if closeErr != nil && err == nil {
err = closeErr
}
}
h.connectionsMu.Unlock()
// remove hosts from our config from the pool // remove hosts from our config from the pool
for _, upstream := range h.Upstreams { for _, upstream := range h.Upstreams {
@ -872,7 +868,7 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe
repl.Set("http.reverse_proxy.status_code", res.StatusCode) repl.Set("http.reverse_proxy.status_code", res.StatusCode)
repl.Set("http.reverse_proxy.status_text", res.Status) repl.Set("http.reverse_proxy.status_text", res.Status)
h.logger.Debug("handling response", zap.Int("handler", i)) logger.Debug("handling response", zap.Int("handler", i))
// we make some data available via request context to child routes // we make some data available via request context to child routes
// so that they may inherit some options and functions from the // so that they may inherit some options and functions from the
@ -917,7 +913,7 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe
} }
// finalizeResponse prepares and copies the response. // finalizeResponse prepares and copies the response.
func (h Handler) finalizeResponse( func (h *Handler) finalizeResponse(
rw http.ResponseWriter, rw http.ResponseWriter,
req *http.Request, req *http.Request,
res *http.Response, res *http.Response,
@ -967,7 +963,7 @@ func (h Handler) finalizeResponse(
// there's nothing an error handler can do to recover at this point; // there's nothing an error handler can do to recover at this point;
// the standard lib's proxy panics at this point, but we'll just log // the standard lib's proxy panics at this point, but we'll just log
// the error and abort the stream here // the error and abort the stream here
h.logger.Error("aborting with incomplete response", zap.Error(err)) logger.Error("aborting with incomplete response", zap.Error(err))
return nil return nil
} }

View file

@ -33,19 +33,19 @@ import (
"golang.org/x/net/http/httpguts" "golang.org/x/net/http/httpguts"
) )
func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWriter, req *http.Request, res *http.Response) { func (h *Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWriter, req *http.Request, res *http.Response) {
reqUpType := upgradeType(req.Header) reqUpType := upgradeType(req.Header)
resUpType := upgradeType(res.Header) resUpType := upgradeType(res.Header)
// Taken from https://github.com/golang/go/commit/5c489514bc5e61ad9b5b07bd7d8ec65d66a0512a // Taken from https://github.com/golang/go/commit/5c489514bc5e61ad9b5b07bd7d8ec65d66a0512a
// We know reqUpType is ASCII, it's checked by the caller. // We know reqUpType is ASCII, it's checked by the caller.
if !asciiIsPrint(resUpType) { if !asciiIsPrint(resUpType) {
h.logger.Debug("backend tried to switch to invalid protocol", logger.Debug("backend tried to switch to invalid protocol",
zap.String("backend_upgrade", resUpType)) zap.String("backend_upgrade", resUpType))
return return
} }
if !asciiEqualFold(reqUpType, resUpType) { if !asciiEqualFold(reqUpType, resUpType) {
h.logger.Debug("backend tried to switch to unexpected protocol via Upgrade header", logger.Debug("backend tried to switch to unexpected protocol via Upgrade header",
zap.String("backend_upgrade", resUpType), zap.String("backend_upgrade", resUpType),
zap.String("requested_upgrade", reqUpType)) zap.String("requested_upgrade", reqUpType))
return return
@ -53,12 +53,12 @@ func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrite
hj, ok := rw.(http.Hijacker) hj, ok := rw.(http.Hijacker)
if !ok { if !ok {
h.logger.Sugar().Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw) logger.Error("can't switch protocols using non-Hijacker ResponseWriter", zap.String("type", fmt.Sprintf("%T", rw)))
return return
} }
backConn, ok := res.Body.(io.ReadWriteCloser) backConn, ok := res.Body.(io.ReadWriteCloser)
if !ok { if !ok {
h.logger.Error("internal error: 101 switching protocols response with non-writable body") logger.Error("internal error: 101 switching protocols response with non-writable body")
return return
} }
@ -83,7 +83,7 @@ func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrite
logger.Debug("upgrading connection") logger.Debug("upgrading connection")
conn, brw, err := hj.Hijack() conn, brw, err := hj.Hijack()
if err != nil { if err != nil {
h.logger.Error("hijack failed on protocol switch", zap.Error(err)) logger.Error("hijack failed on protocol switch", zap.Error(err))
return return
} }
@ -94,7 +94,7 @@ func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrite
}() }()
if err := brw.Flush(); err != nil { if err := brw.Flush(); err != nil {
h.logger.Debug("response flush", zap.Error(err)) logger.Debug("response flush", zap.Error(err))
return return
} }
@ -120,10 +120,23 @@ func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrite
spc := switchProtocolCopier{user: conn, backend: backConn} spc := switchProtocolCopier{user: conn, backend: backConn}
// setup the timeout if requested
var timeoutc <-chan time.Time
if h.StreamTimeout > 0 {
timer := time.NewTimer(time.Duration(h.StreamTimeout))
defer timer.Stop()
timeoutc = timer.C
}
errc := make(chan error, 1) errc := make(chan error, 1)
go spc.copyToBackend(errc) go spc.copyToBackend(errc)
go spc.copyFromBackend(errc) go spc.copyFromBackend(errc)
<-errc select {
case err := <-errc:
logger.Debug("streaming error", zap.Error(err))
case time := <-timeoutc:
logger.Debug("stream timed out", zap.Time("timeout", time))
}
} }
// flushInterval returns the p.FlushInterval value, conditionally // flushInterval returns the p.FlushInterval value, conditionally
@ -243,10 +256,70 @@ func (h *Handler) registerConnection(conn io.ReadWriteCloser, gracefulClose func
return func() { return func() {
h.connectionsMu.Lock() h.connectionsMu.Lock()
delete(h.connections, conn) delete(h.connections, conn)
// if there is no connection left before the connections close timer fires
if len(h.connections) == 0 && h.connectionsCloseTimer != nil {
// we release the timer that holds the reference to Handler
if (*h.connectionsCloseTimer).Stop() {
h.logger.Debug("stopped streaming connections close timer - all connections are already closed")
}
h.connectionsCloseTimer = nil
}
h.connectionsMu.Unlock() h.connectionsMu.Unlock()
} }
} }
// closeConnections immediately closes all hijacked connections (both to client and backend).
func (h *Handler) closeConnections() error {
var err error
h.connectionsMu.Lock()
defer h.connectionsMu.Unlock()
for _, oc := range h.connections {
if oc.gracefulClose != nil {
// this is potentially blocking while we have the lock on the connections
// map, but that should be OK since the server has in theory shut down
// and we are no longer using the connections map
gracefulErr := oc.gracefulClose()
if gracefulErr != nil && err == nil {
err = gracefulErr
}
}
closeErr := oc.conn.Close()
if closeErr != nil && err == nil {
err = closeErr
}
}
return err
}
// cleanupConnections closes hijacked connections.
// Depending on the value of StreamCloseDelay it does that either immediately
// or sets up a timer that will do that later.
func (h *Handler) cleanupConnections() error {
if h.StreamCloseDelay == 0 {
return h.closeConnections()
}
h.connectionsMu.Lock()
defer h.connectionsMu.Unlock()
// the handler is shut down, no new connection can appear,
// so we can skip setting up the timer when there are no connections
if len(h.connections) > 0 {
delay := time.Duration(h.StreamCloseDelay)
h.connectionsCloseTimer = time.AfterFunc(delay, func() {
h.logger.Debug("closing streaming connections after delay",
zap.Duration("delay", delay))
err := h.closeConnections()
if err != nil {
h.logger.Error("failed to closed connections after delay",
zap.Error(err),
zap.Duration("delay", delay))
}
})
}
return nil
}
// writeCloseControl sends a best-effort Close control message to the given // writeCloseControl sends a best-effort Close control message to the given
// WebSocket connection. Thanks to @pascaldekloe who provided inspiration // WebSocket connection. Thanks to @pascaldekloe who provided inspiration
// from his simple implementation of this I was able to learn from at: // from his simple implementation of this I was able to learn from at: