Moved SNI servername replacement into httptransport.

This commit is contained in:
Kiss Károly 2022-06-14 13:42:44 +02:00
parent a661daff98
commit afdf87bc08
2 changed files with 39 additions and 43 deletions

View file

@ -25,6 +25,7 @@ import (
"net/http"
"os"
"reflect"
"strings"
"time"
"github.com/caddyserver/caddy/v2"
@ -242,10 +243,45 @@ func (h *HTTPTransport) NewTransport(ctx caddy.Context) (*http.Transport, error)
return rt, nil
}
// replaceTLSServername checks TLS servername to see if it needs replacing
// if it does need replacing, it creates a new cloned HTTPTransport object to avoid any races
// and does the replacing of the TLS servername on that and returns the new object
// if no replacement is necessary it returns the original
func (h *HTTPTransport) replaceTLSServername(repl *caddy.Replacer) *HTTPTransport {
// check whether we have TLS and need to replace the servername in the TLSClientConfig
if h.TLSEnabled() && strings.Contains(h.TLS.ServerName, "{") {
// make a new h, "copy" the parts we don't need to touch, add a new *tls.Config and replace servername
newtransport := &HTTPTransport{
Resolver: h.Resolver,
TLS: h.TLS,
KeepAlive: h.KeepAlive,
Compression: h.Compression,
MaxConnsPerHost: h.MaxConnsPerHost,
DialTimeout: h.DialTimeout,
FallbackDelay: h.FallbackDelay,
ResponseHeaderTimeout: h.ResponseHeaderTimeout,
ExpectContinueTimeout: h.ExpectContinueTimeout,
MaxResponseHeaderSize: h.MaxResponseHeaderSize,
WriteBufferSize: h.WriteBufferSize,
ReadBufferSize: h.ReadBufferSize,
Versions: h.Versions,
Transport: h.Transport.Clone(),
h2cTransport: h.h2cTransport,
}
newtransport.Transport.TLSClientConfig.ServerName = repl.ReplaceAll(newtransport.Transport.TLSClientConfig.ServerName, "")
return newtransport
}
return h
}
// RoundTrip implements http.RoundTripper.
func (h *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Try to replace TLS servername if needed
repl := req.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
transport := h.replaceTLSServername(repl)
h.SetScheme(req)
transport.SetScheme(req)
// if H2C ("HTTP/2 over cleartext") is enabled and the upstream request is
// HTTP without TLS, use the alternate H2C-capable transport instead
@ -253,7 +289,7 @@ func (h *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return h.h2cTransport.RoundTrip(req)
}
return h.Transport.RoundTrip(req)
return transport.Transport.RoundTrip(req)
}
// SetScheme ensures that the outbound request req

View file

@ -716,38 +716,6 @@ func (h Handler) addForwardedHeaders(req *http.Request) error {
return nil
}
// replaceTLSServername checks TLS servername to see if it needs replacing
// if it does need replacing, it creates a new cloned HTTPTransport object to avoid any races
// and does the replacing of the TLS servername on that and returns the new object
// if no replacement is necessary it returns the original
func (h *Handler) replaceTLSServername(transport *HTTPTransport, repl *caddy.Replacer) *HTTPTransport {
// check whether we have TLS and need to replace the servername in the TLSClientConfig
if transport.TLSEnabled() && strings.Contains(transport.TLS.ServerName, "{") {
// make a new transport, "copy" the parts we don't need to touch, add a new *tls.Config and replace servername
newtransport := &HTTPTransport{
Resolver: transport.Resolver,
TLS: transport.TLS,
KeepAlive: transport.KeepAlive,
Compression: transport.Compression,
MaxConnsPerHost: transport.MaxConnsPerHost,
DialTimeout: transport.DialTimeout,
FallbackDelay: transport.FallbackDelay,
ResponseHeaderTimeout: transport.ResponseHeaderTimeout,
ExpectContinueTimeout: transport.ExpectContinueTimeout,
MaxResponseHeaderSize: transport.MaxResponseHeaderSize,
WriteBufferSize: transport.WriteBufferSize,
ReadBufferSize: transport.ReadBufferSize,
Versions: transport.Versions,
Transport: transport.Transport.Clone(),
h2cTransport: transport.h2cTransport,
}
newtransport.Transport.TLSClientConfig.ServerName = repl.ReplaceAll(newtransport.Transport.TLSClientConfig.ServerName, "")
return newtransport
}
return transport
}
// reverseProxy performs a round-trip to the given backend and processes the response with the client.
// (This method is mostly the beginning of what was borrowed from the net/http/httputil package in the
// Go standard library which was used as the foundation.)
@ -762,18 +730,10 @@ func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, origRe
server := req.Context().Value(caddyhttp.ServerCtxKey).(*caddyhttp.Server)
shouldLogCredentials := server.Logs != nil && server.Logs.ShouldLogCredentials
// Default to using the transport configured during provisioning stage
transport := h.Transport
// If we have a HTTP transport, try to replace the TLS servername
if tmpTransport, ok := transport.(*HTTPTransport); ok {
transport = h.replaceTLSServername(tmpTransport, repl)
}
// do the round-trip; emit debug log with values we know are
// safe, or if there is no error, emit fuller log entry
start := time.Now()
res, err := transport.RoundTrip(req)
res, err := h.Transport.RoundTrip(req)
duration := time.Since(start)
logger := h.logger.With(
zap.String("upstream", di.Upstream.String()),