mirror of
https://github.com/caddyserver/caddy.git
synced 2025-02-01 14:48:22 +03:00
proxy: Fixed support for TLS verification of WebSocket connections
This commit is contained in:
parent
153d4a5ac6
commit
b857265f9c
1 changed files with 45 additions and 24 deletions
|
@ -349,9 +349,14 @@ func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
|
||||||
MaxIdleConnsPerHost: -1,
|
MaxIdleConnsPerHost: -1,
|
||||||
}
|
}
|
||||||
if b, _ := base.(*http.Transport); b != nil {
|
if b, _ := base.(*http.Transport); b != nil {
|
||||||
|
tlsClientConfig := b.TLSClientConfig
|
||||||
|
if tlsClientConfig.NextProtos != nil {
|
||||||
|
tlsClientConfig = cloneTLSClientConfig(tlsClientConfig)
|
||||||
|
tlsClientConfig.NextProtos = nil
|
||||||
|
}
|
||||||
|
|
||||||
t.Proxy = b.Proxy
|
t.Proxy = b.Proxy
|
||||||
t.TLSClientConfig = cloneTLSClientConfig(b.TLSClientConfig)
|
t.TLSClientConfig = tlsClientConfig
|
||||||
t.TLSClientConfig.NextProtos = nil
|
|
||||||
t.TLSHandshakeTimeout = b.TLSHandshakeTimeout
|
t.TLSHandshakeTimeout = b.TLSHandshakeTimeout
|
||||||
t.Dial = b.Dial
|
t.Dial = b.Dial
|
||||||
t.DialTLS = b.DialTLS
|
t.DialTLS = b.DialTLS
|
||||||
|
@ -363,19 +368,15 @@ func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
|
||||||
|
|
||||||
dial := getTransportDial(t)
|
dial := getTransportDial(t)
|
||||||
dialTLS := getTransportDialTLS(t)
|
dialTLS := getTransportDialTLS(t)
|
||||||
|
|
||||||
t.Dial = func(network, addr string) (net.Conn, error) {
|
t.Dial = func(network, addr string) (net.Conn, error) {
|
||||||
c, err := dial(network, addr)
|
c, err := dial(network, addr)
|
||||||
hj.Conn = c
|
hj.Conn = c
|
||||||
return &hijackedConn{c, hj}, err
|
return &hijackedConn{c, hj}, err
|
||||||
}
|
}
|
||||||
|
t.DialTLS = func(network, addr string) (net.Conn, error) {
|
||||||
if dialTLS != nil {
|
c, err := dialTLS(network, addr)
|
||||||
t.DialTLS = func(network, addr string) (net.Conn, error) {
|
hj.Conn = c
|
||||||
c, err := dialTLS(network, addr)
|
return &hijackedConn{c, hj}, err
|
||||||
hj.Conn = c
|
|
||||||
return &hijackedConn{c, hj}, err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return hj
|
return hj
|
||||||
|
@ -390,27 +391,35 @@ func getTransportDial(t *http.Transport) func(network, addr string) (net.Conn, e
|
||||||
return defaultDialer.Dial
|
return defaultDialer.Dial
|
||||||
}
|
}
|
||||||
|
|
||||||
// getTransportDial returns a TLS Dialer if TLSClientConfig is non-nil
|
// getTransportDial always returns a TLS Dialer
|
||||||
// and defaults to the existing t.DialTLS.
|
// and defaults to the existing t.DialTLS.
|
||||||
func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn, error) {
|
func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn, error) {
|
||||||
if t.DialTLS != nil {
|
if t.DialTLS != nil {
|
||||||
return t.DialTLS
|
return t.DialTLS
|
||||||
}
|
}
|
||||||
if t.TLSClientConfig == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// newConnHijackerTransport will modify t.Dial after calling this method
|
// newConnHijackerTransport will modify t.Dial after calling this method
|
||||||
// => Create a backup reference.
|
// => Create a backup reference.
|
||||||
plainDial := getTransportDial(t)
|
plainDial := getTransportDial(t)
|
||||||
|
|
||||||
|
// The following DialTLS implementation stems from the Go stdlib and
|
||||||
|
// is identical to what happens if DialTLS is not provided.
|
||||||
|
// Source: https://github.com/golang/go/blob/230a376b5a67f0e9341e1fa47e670ff762213c83/src/net/http/transport.go#L1018-L1051
|
||||||
return func(network, addr string) (net.Conn, error) {
|
return func(network, addr string) (net.Conn, error) {
|
||||||
plainConn, err := plainDial(network, addr)
|
plainConn, err := plainDial(network, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
tlsConn := tls.Client(plainConn, t.TLSClientConfig)
|
tlsClientConfig := t.TLSClientConfig
|
||||||
|
if tlsClientConfig == nil {
|
||||||
|
tlsClientConfig = &tls.Config{}
|
||||||
|
}
|
||||||
|
if !tlsClientConfig.InsecureSkipVerify && tlsClientConfig.ServerName == "" {
|
||||||
|
tlsClientConfig.ServerName = stripPort(addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConn := tls.Client(plainConn, tlsClientConfig)
|
||||||
errc := make(chan error, 2)
|
errc := make(chan error, 2)
|
||||||
var timer *time.Timer
|
var timer *time.Timer
|
||||||
if d := t.TLSHandshakeTimeout; d != 0 {
|
if d := t.TLSHandshakeTimeout; d != 0 {
|
||||||
|
@ -429,16 +438,12 @@ func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn
|
||||||
plainConn.Close()
|
plainConn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if !t.TLSClientConfig.InsecureSkipVerify {
|
if !tlsClientConfig.InsecureSkipVerify {
|
||||||
serverName := t.TLSClientConfig.ServerName
|
hostname := tlsClientConfig.ServerName
|
||||||
if serverName == "" {
|
if hostname == "" {
|
||||||
serverName = addr
|
hostname = stripPort(addr)
|
||||||
idx := strings.LastIndex(serverName, ":")
|
|
||||||
if idx != -1 {
|
|
||||||
serverName = serverName[:idx]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if err := tlsConn.VerifyHostname(serverName); err != nil {
|
if err := tlsConn.VerifyHostname(hostname); err != nil {
|
||||||
plainConn.Close()
|
plainConn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -448,6 +453,22 @@ func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// stripPort returns address without its port if it has one and
|
||||||
|
// works with IP addresses as well as hostnames formatted as host:port.
|
||||||
|
//
|
||||||
|
// IPv6 addresses (excluding the port) must be enclosed in
|
||||||
|
// square brackets similar to the requirements of Go's stdlib.
|
||||||
|
func stripPort(address string) string {
|
||||||
|
// Keep in mind that the address might be a IPv6 address
|
||||||
|
// and thus contain a colon, but not have a port.
|
||||||
|
portIdx := strings.LastIndex(address, ":")
|
||||||
|
ipv6Idx := strings.LastIndex(address, "]")
|
||||||
|
if portIdx > ipv6Idx {
|
||||||
|
address = address[:portIdx]
|
||||||
|
}
|
||||||
|
return address
|
||||||
|
}
|
||||||
|
|
||||||
type tlsHandshakeTimeoutError struct{}
|
type tlsHandshakeTimeoutError struct{}
|
||||||
|
|
||||||
func (tlsHandshakeTimeoutError) Timeout() bool { return true }
|
func (tlsHandshakeTimeoutError) Timeout() bool { return true }
|
||||||
|
|
Loading…
Reference in a new issue