diff --git a/modules/caddyhttp/reverseproxy/httptransport.go b/modules/caddyhttp/reverseproxy/httptransport.go index bf11ab283..83c48b9b2 100644 --- a/modules/caddyhttp/reverseproxy/httptransport.go +++ b/modules/caddyhttp/reverseproxy/httptransport.go @@ -25,6 +25,7 @@ import ( "net/http" "os" "reflect" + "strings" "time" "github.com/caddyserver/caddy/v2" @@ -242,10 +243,45 @@ func (h *HTTPTransport) NewTransport(ctx caddy.Context) (*http.Transport, error) return rt, nil } +// replaceTLSServername checks TLS servername to see if it needs replacing +// if it does need replacing, it creates a new cloned HTTPTransport object to avoid any races +// and does the replacing of the TLS servername on that and returns the new object +// if no replacement is necessary it returns the original +func (h *HTTPTransport) replaceTLSServername(repl *caddy.Replacer) *HTTPTransport { + // check whether we have TLS and need to replace the servername in the TLSClientConfig + if h.TLSEnabled() && strings.Contains(h.TLS.ServerName, "{") { + // make a new h, "copy" the parts we don't need to touch, add a new *tls.Config and replace servername + newtransport := &HTTPTransport{ + Resolver: h.Resolver, + TLS: h.TLS, + KeepAlive: h.KeepAlive, + Compression: h.Compression, + MaxConnsPerHost: h.MaxConnsPerHost, + DialTimeout: h.DialTimeout, + FallbackDelay: h.FallbackDelay, + ResponseHeaderTimeout: h.ResponseHeaderTimeout, + ExpectContinueTimeout: h.ExpectContinueTimeout, + MaxResponseHeaderSize: h.MaxResponseHeaderSize, + WriteBufferSize: h.WriteBufferSize, + ReadBufferSize: h.ReadBufferSize, + Versions: h.Versions, + Transport: h.Transport.Clone(), + h2cTransport: h.h2cTransport, + } + newtransport.Transport.TLSClientConfig.ServerName = repl.ReplaceAll(newtransport.Transport.TLSClientConfig.ServerName, "") + return newtransport + } + + return h +} + // RoundTrip implements http.RoundTripper. func (h *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // Try to replace TLS servername if needed + repl := req.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer) + transport := h.replaceTLSServername(repl) - h.SetScheme(req) + transport.SetScheme(req) // if H2C ("HTTP/2 over cleartext") is enabled and the upstream request is // HTTP without TLS, use the alternate H2C-capable transport instead @@ -253,7 +289,7 @@ func (h *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { return h.h2cTransport.RoundTrip(req) } - return h.Transport.RoundTrip(req) + return transport.Transport.RoundTrip(req) } // SetScheme ensures that the outbound request req diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index 0364d1a7d..1068c23da 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -716,38 +716,6 @@ func (h Handler) addForwardedHeaders(req *http.Request) error { return nil } -// replaceTLSServername checks TLS servername to see if it needs replacing -// if it does need replacing, it creates a new cloned HTTPTransport object to avoid any races -// and does the replacing of the TLS servername on that and returns the new object -// if no replacement is necessary it returns the original -func (h *Handler) replaceTLSServername(transport *HTTPTransport, repl *caddy.Replacer) *HTTPTransport { - // check whether we have TLS and need to replace the servername in the TLSClientConfig - if transport.TLSEnabled() && strings.Contains(transport.TLS.ServerName, "{") { - // make a new transport, "copy" the parts we don't need to touch, add a new *tls.Config and replace servername - newtransport := &HTTPTransport{ - Resolver: transport.Resolver, - TLS: transport.TLS, - KeepAlive: transport.KeepAlive, - Compression: transport.Compression, - MaxConnsPerHost: transport.MaxConnsPerHost, - DialTimeout: transport.DialTimeout, - FallbackDelay: transport.FallbackDelay, - ResponseHeaderTimeout: transport.ResponseHeaderTimeout, - ExpectContinueTimeout: transport.ExpectContinueTimeout, - MaxResponseHeaderSize: transport.MaxResponseHeaderSize, - WriteBufferSize: transport.WriteBufferSize, - ReadBufferSize: transport.ReadBufferSize, - Versions: transport.Versions, - Transport: transport.Transport.Clone(), - h2cTransport: transport.h2cTransport, - } - newtransport.Transport.TLSClientConfig.ServerName = repl.ReplaceAll(newtransport.Transport.TLSClientConfig.ServerName, "") - return newtransport - } - - return transport -} - // reverseProxy performs a round-trip to the given backend and processes the response with the client. // (This method is mostly the beginning of what was borrowed from the net/http/httputil package in the // Go standard library which was used as the foundation.) @@ -762,18 +730,10 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe server := req.Context().Value(caddyhttp.ServerCtxKey).(*caddyhttp.Server) shouldLogCredentials := server.Logs != nil && server.Logs.ShouldLogCredentials - // Default to using the transport configured during provisioning stage - transport := h.Transport - - // If we have a HTTP transport, try to replace the TLS servername - if tmpTransport, ok := transport.(*HTTPTransport); ok { - transport = h.replaceTLSServername(tmpTransport, repl) - } - // do the round-trip; emit debug log with values we know are // safe, or if there is no error, emit fuller log entry start := time.Now() - res, err := transport.RoundTrip(req) + res, err := h.Transport.RoundTrip(req) duration := time.Since(start) logger := h.logger.With( zap.String("upstream", di.Upstream.String()),