diff --git a/modules/caddyhttp/reverseproxy/caddyfile.go b/modules/caddyhttp/reverseproxy/caddyfile.go index a5321d1b..a2c85f90 100644 --- a/modules/caddyhttp/reverseproxy/caddyfile.go +++ b/modules/caddyhttp/reverseproxy/caddyfile.go @@ -875,6 +875,26 @@ func (h *HTTPTransport) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { } h.WriteBufferSize = int(size) + case "read_timeout": + if !d.NextArg() { + return d.ArgErr() + } + timeout, err := caddy.ParseDuration(d.Val()) + if err != nil { + return d.Errf("invalid read timeout duration '%s': %v", d.Val(), err) + } + h.ReadTimeout = caddy.Duration(timeout) + + case "write_timeout": + if !d.NextArg() { + return d.ArgErr() + } + timeout, err := caddy.ParseDuration(d.Val()) + if err != nil { + return d.Errf("invalid write timeout duration '%s': %v", d.Val(), err) + } + h.WriteTimeout = caddy.Duration(timeout) + case "max_response_header": if !d.NextArg() { return d.ArgErr() diff --git a/modules/caddyhttp/reverseproxy/httptransport.go b/modules/caddyhttp/reverseproxy/httptransport.go index 7b573f81..ef72b886 100644 --- a/modules/caddyhttp/reverseproxy/httptransport.go +++ b/modules/caddyhttp/reverseproxy/httptransport.go @@ -30,6 +30,7 @@ import ( "github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2/modules/caddytls" + "go.uber.org/zap" "golang.org/x/net/http2" ) @@ -88,6 +89,12 @@ type HTTPTransport struct { // The size of the read buffer in bytes. Default: `4KiB`. ReadBufferSize int `json:"read_buffer_size,omitempty"` + // The maximum time to wait for next read from backend. Default: no timeout. + ReadTimeout caddy.Duration `json:"read_timeout,omitempty"` + + // The maximum time to wait for next write to backend. Default: no timeout. + WriteTimeout caddy.Duration `json:"write_timeout,omitempty"` + // The versions of HTTP to support. As a special case, "h2c" // can be specified to use H2C (HTTP/2 over Cleartext) to the // upstream (this feature is experimental and subject to @@ -147,7 +154,7 @@ func (h *HTTPTransport) Provision(ctx caddy.Context) error { } // NewTransport builds a standard-lib-compatible http.Transport value from h. -func (h *HTTPTransport) NewTransport(ctx caddy.Context) (*http.Transport, error) { +func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, error) { // Set keep-alive defaults if it wasn't otherwise configured if h.KeepAlive == nil { h.KeepAlive = &KeepAlive{ @@ -194,6 +201,7 @@ func (h *HTTPTransport) NewTransport(ctx caddy.Context) (*http.Transport, error) network = dialInfo.Network address = dialInfo.Address } + conn, err := dialer.DialContext(ctx, network, address) if err != nil { // identify this error as one that occurred during @@ -201,6 +209,18 @@ func (h *HTTPTransport) NewTransport(ctx caddy.Context) (*http.Transport, error) // decide whether to retry a request return nil, DialError{err} } + + // if read/write timeouts are configured and this is a TCP connection, enforce the timeouts + // by wrapping the connection with our own type + if tcpConn, ok := conn.(*net.TCPConn); ok && (h.ReadTimeout > 0 || h.WriteTimeout > 0) { + conn = &tcpRWTimeoutConn{ + TCPConn: tcpConn, + readTimeout: time.Duration(h.ReadTimeout), + writeTimeout: time.Duration(h.WriteTimeout), + logger: caddyCtx.Logger(h), + } + } + return conn, nil }, MaxConnsPerHost: h.MaxConnsPerHost, @@ -214,7 +234,7 @@ func (h *HTTPTransport) NewTransport(ctx caddy.Context) (*http.Transport, error) if h.TLS != nil { rt.TLSHandshakeTimeout = time.Duration(h.TLS.HandshakeTimeout) var err error - rt.TLSClientConfig, err = h.TLS.MakeTLSClientConfig(ctx) + rt.TLSClientConfig, err = h.TLS.MakeTLSClientConfig(caddyCtx) if err != nil { return nil, fmt.Errorf("making TLS client config: %v", err) } @@ -510,6 +530,36 @@ type KeepAlive struct { IdleConnTimeout caddy.Duration `json:"idle_timeout,omitempty"` } +// tcpRWTimeoutConn enforces read/write timeouts for a TCP connection. +// If it fails to set deadlines, the error is logged but does not abort +// the read/write attempt (ignoring the error is consistent with what +// the standard library does: https://github.com/golang/go/blob/c5da4fb7ac5cb7434b41fc9a1df3bee66c7f1a4d/src/net/http/server.go#L981-L986) +type tcpRWTimeoutConn struct { + *net.TCPConn + readTimeout, writeTimeout time.Duration + logger *zap.Logger +} + +func (c *tcpRWTimeoutConn) Read(b []byte) (int, error) { + if c.readTimeout > 0 { + err := c.TCPConn.SetReadDeadline(time.Now().Add(c.readTimeout)) + if err != nil { + c.logger.Error("failed to set read deadline", zap.Error(err)) + } + } + return c.TCPConn.Read(b) +} + +func (c *tcpRWTimeoutConn) Write(b []byte) (int, error) { + if c.writeTimeout > 0 { + err := c.TCPConn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) + if err != nil { + c.logger.Error("failed to set write deadline", zap.Error(err)) + } + } + return c.TCPConn.Write(b) +} + // decodeBase64DERCert base64-decodes, then DER-decodes, certStr. func decodeBase64DERCert(certStr string) (*x509.Certificate, error) { // decode base64