mirror of
https://github.com/caddyserver/caddy.git
synced 2025-01-14 23:06:27 +03:00
reverseproxy: Close hijacked conns on reload/quit (#4895)
* reverseproxy: Close hijacked conns on reload/quit We also send a Close control message to both ends of WebSocket connections. I have tested this many times in my dev environment with consistent success, although the variety of scenarios was limited. * Oops... actually call Close() this time * CloseMessage --> closeMessage Co-authored-by: Francis Lavoie <lavofr@gmail.com> * Use httpguts, duh * Use map instead of sync.Map Co-authored-by: Francis Lavoie <lavofr@gmail.com>
This commit is contained in:
parent
d3c3fa10bd
commit
66476d8c8f
2 changed files with 103 additions and 5 deletions
|
@ -192,6 +192,10 @@ type Handler struct {
|
||||||
// Holds the handle_response Caddyfile tokens while adapting
|
// Holds the handle_response Caddyfile tokens while adapting
|
||||||
handleResponseSegments []*caddyfile.Dispenser
|
handleResponseSegments []*caddyfile.Dispenser
|
||||||
|
|
||||||
|
// Stores upgraded requests (hijacked connections) for proper cleanup
|
||||||
|
connections map[io.ReadWriteCloser]openConnection
|
||||||
|
connectionsMu *sync.Mutex
|
||||||
|
|
||||||
ctx caddy.Context
|
ctx caddy.Context
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
events *caddyevents.App
|
events *caddyevents.App
|
||||||
|
@ -214,6 +218,8 @@ func (h *Handler) Provision(ctx caddy.Context) error {
|
||||||
h.events = eventAppIface.(*caddyevents.App)
|
h.events = eventAppIface.(*caddyevents.App)
|
||||||
h.ctx = ctx
|
h.ctx = ctx
|
||||||
h.logger = ctx.Logger(h)
|
h.logger = ctx.Logger(h)
|
||||||
|
h.connections = make(map[io.ReadWriteCloser]openConnection)
|
||||||
|
h.connectionsMu = new(sync.Mutex)
|
||||||
|
|
||||||
// verify SRV compatibility - TODO: LookupSRV deprecated; will be removed
|
// verify SRV compatibility - TODO: LookupSRV deprecated; will be removed
|
||||||
for i, v := range h.Upstreams {
|
for i, v := range h.Upstreams {
|
||||||
|
@ -407,16 +413,34 @@ func (h *Handler) Provision(ctx caddy.Context) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cleanup cleans up the resources made by h during provisioning.
|
// Cleanup cleans up the resources made by h.
|
||||||
func (h *Handler) Cleanup() error {
|
func (h *Handler) Cleanup() error {
|
||||||
// TODO: Close keepalive connections on reload? https://github.com/caddyserver/caddy/pull/2507/files#diff-70219fd88fe3f36834f474ce6537ed26R762
|
// close hijacked connections (both to client and backend)
|
||||||
|
var err error
|
||||||
|
h.connectionsMu.Lock()
|
||||||
|
for _, oc := range h.connections {
|
||||||
|
if oc.gracefulClose != nil {
|
||||||
|
// this is potentially blocking while we have the lock on the connections
|
||||||
|
// map, but that should be OK since the server has in theory shut down
|
||||||
|
// and we are no longer using the connections map
|
||||||
|
gracefulErr := oc.gracefulClose()
|
||||||
|
if gracefulErr != nil && err == nil {
|
||||||
|
err = gracefulErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
closeErr := oc.conn.Close()
|
||||||
|
if closeErr != nil && err == nil {
|
||||||
|
err = closeErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.connectionsMu.Unlock()
|
||||||
|
|
||||||
// remove hosts from our config from the pool
|
// remove hosts from our config from the pool
|
||||||
for _, upstream := range h.Upstreams {
|
for _, upstream := range h.Upstreams {
|
||||||
_, _ = hosts.Delete(upstream.String())
|
_, _ = hosts.Delete(upstream.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
|
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
|
||||||
|
|
|
@ -20,6 +20,7 @@ package reverseproxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
"io"
|
"io"
|
||||||
"mime"
|
"mime"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -27,6 +28,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
"golang.org/x/net/http/httpguts"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWriter, req *http.Request, res *http.Response) {
|
func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWriter, req *http.Request, res *http.Response) {
|
||||||
|
@ -97,8 +99,26 @@ func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrite
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
errc := make(chan error, 1)
|
// Ensure the hijacked client connection, and the new connection established
|
||||||
|
// with the backend, are both closed in the event of a server shutdown. This
|
||||||
|
// is done by registering them. We also try to gracefully close connections
|
||||||
|
// we recognize as websockets.
|
||||||
|
gracefulClose := func(conn io.ReadWriteCloser) func() error {
|
||||||
|
if isWebsocket(req) {
|
||||||
|
return func() error {
|
||||||
|
return writeCloseControl(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
deleteFrontConn := h.registerConnection(conn, gracefulClose(conn))
|
||||||
|
deleteBackConn := h.registerConnection(backConn, gracefulClose(backConn))
|
||||||
|
defer deleteFrontConn()
|
||||||
|
defer deleteBackConn()
|
||||||
|
|
||||||
spc := switchProtocolCopier{user: conn, backend: backConn}
|
spc := switchProtocolCopier{user: conn, backend: backConn}
|
||||||
|
|
||||||
|
errc := make(chan error, 1)
|
||||||
go spc.copyToBackend(errc)
|
go spc.copyToBackend(errc)
|
||||||
go spc.copyFromBackend(errc)
|
go spc.copyFromBackend(errc)
|
||||||
<-errc
|
<-errc
|
||||||
|
@ -209,6 +229,60 @@ func (h Handler) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, er
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// registerConnection holds onto conn so it can be closed in the event
|
||||||
|
// of a server shutdown. This is useful because hijacked connections or
|
||||||
|
// connections dialed to backends don't close when server is shut down.
|
||||||
|
// The caller should call the returned delete() function when the
|
||||||
|
// connection is done to remove it from memory.
|
||||||
|
func (h *Handler) registerConnection(conn io.ReadWriteCloser, gracefulClose func() error) (del func()) {
|
||||||
|
h.connectionsMu.Lock()
|
||||||
|
h.connections[conn] = openConnection{conn, gracefulClose}
|
||||||
|
h.connectionsMu.Unlock()
|
||||||
|
return func() {
|
||||||
|
h.connectionsMu.Lock()
|
||||||
|
delete(h.connections, conn)
|
||||||
|
h.connectionsMu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeCloseControl sends a best-effort Close control message to the given
|
||||||
|
// WebSocket connection. Thanks to @pascaldekloe who provided inspiration
|
||||||
|
// from his simple implementation of this I was able to learn from at:
|
||||||
|
// github.com/pascaldekloe/websocket.
|
||||||
|
func writeCloseControl(conn io.Writer) error {
|
||||||
|
// https://github.com/pascaldekloe/websocket/blob/32050af67a5d/websocket.go#L119
|
||||||
|
|
||||||
|
var reason string // max 123 bytes (control frame payload limit is 125; status code takes 2)
|
||||||
|
const goingAway uint16 = 1001
|
||||||
|
|
||||||
|
// TODO: we might need to ensure we are the exclusive writer by this point (io.Copy is stopped)?
|
||||||
|
var writeBuf [127]byte
|
||||||
|
const closeMessage = 8
|
||||||
|
const finalBit = 1 << 7
|
||||||
|
writeBuf[0] = closeMessage | finalBit
|
||||||
|
writeBuf[1] = byte(len(reason) + 2)
|
||||||
|
binary.BigEndian.PutUint16(writeBuf[2:4], goingAway)
|
||||||
|
copy(writeBuf[4:], reason)
|
||||||
|
|
||||||
|
// simply best-effort, but return error for logging purposes
|
||||||
|
_, err := conn.Write(writeBuf[:4+len(reason)])
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// isWebsocket returns true if r looks to be an upgrade request for WebSockets.
|
||||||
|
// It is a fairly naive check.
|
||||||
|
func isWebsocket(r *http.Request) bool {
|
||||||
|
return httpguts.HeaderValuesContainsToken(r.Header["Connection"], "upgrade") &&
|
||||||
|
httpguts.HeaderValuesContainsToken(r.Header["Upgrade"], "websocket")
|
||||||
|
}
|
||||||
|
|
||||||
|
// openConnection maps an open connection to
|
||||||
|
// an optional function for graceful close.
|
||||||
|
type openConnection struct {
|
||||||
|
conn io.ReadWriteCloser
|
||||||
|
gracefulClose func() error
|
||||||
|
}
|
||||||
|
|
||||||
type writeFlusher interface {
|
type writeFlusher interface {
|
||||||
io.Writer
|
io.Writer
|
||||||
http.Flusher
|
http.Flusher
|
||||||
|
@ -265,7 +339,7 @@ func (m *maxLatencyWriter) stop() {
|
||||||
// switchProtocolCopier exists so goroutines proxying data back and
|
// switchProtocolCopier exists so goroutines proxying data back and
|
||||||
// forth have nice names in stacks.
|
// forth have nice names in stacks.
|
||||||
type switchProtocolCopier struct {
|
type switchProtocolCopier struct {
|
||||||
user, backend io.ReadWriter
|
user, backend io.ReadWriteCloser
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
|
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
|
||||||
|
|
Loading…
Reference in a new issue