mirror of
https://github.com/caddyserver/caddy.git
synced 2025-01-07 19:38:49 +03:00
Merge pull request #1309 from lhecker/master
Fixed #1292 and resulting issues from #1300
This commit is contained in:
commit
7cbbb01f94
3 changed files with 348 additions and 83 deletions
|
@ -247,12 +247,28 @@ func createUpstreamRequest(r *http.Request) *http.Request {
|
|||
outreq.URL.Opaque = outreq.URL.RawPath
|
||||
}
|
||||
|
||||
// We are modifying the same underlying map from req (shallow
|
||||
// copied above) so we only copy it if necessary.
|
||||
copiedHeaders := false
|
||||
|
||||
// Remove hop-by-hop headers listed in the "Connection" header.
|
||||
// See RFC 2616, section 14.10.
|
||||
if c := outreq.Header.Get("Connection"); c != "" {
|
||||
for _, f := range strings.Split(c, ",") {
|
||||
if f = strings.TrimSpace(f); f != "" {
|
||||
if !copiedHeaders {
|
||||
outreq.Header = make(http.Header)
|
||||
copyHeader(outreq.Header, r.Header)
|
||||
copiedHeaders = true
|
||||
}
|
||||
outreq.Header.Del(f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove hop-by-hop headers to the backend. Especially
|
||||
// important is "Connection" because we want a persistent
|
||||
// connection, regardless of what the client sent to us. This
|
||||
// is modifying the same underlying map from r (shallow
|
||||
// copied above) so we only copy it if necessary.
|
||||
var copiedHeaders bool
|
||||
// connection, regardless of what the client sent to us.
|
||||
for _, h := range hopHeaders {
|
||||
if outreq.Header.Get(h) != "" {
|
||||
if !copiedHeaders {
|
||||
|
|
|
@ -42,10 +42,32 @@ func TestReverseProxy(t *testing.T) {
|
|||
log.SetOutput(ioutil.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
verifyHeaders := func(headers http.Header, trailers http.Header) {
|
||||
if headers.Get("X-Header") != "header-value" {
|
||||
t.Error("Expected header 'X-Header' to be proxied properly")
|
||||
}
|
||||
|
||||
if trailers == nil {
|
||||
t.Error("Expected to receive trailers")
|
||||
}
|
||||
if trailers.Get("X-Trailer") != "trailer-value" {
|
||||
t.Error("Expected header 'X-Trailer' to be proxied properly")
|
||||
}
|
||||
}
|
||||
|
||||
var requestReceived bool
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// read the body (even if it's empty) to make Go parse trailers
|
||||
io.Copy(ioutil.Discard, r.Body)
|
||||
verifyHeaders(r.Header, r.Trailer)
|
||||
|
||||
requestReceived = true
|
||||
|
||||
w.Header().Set("Trailer", "X-Trailer")
|
||||
w.Header().Set("X-Header", "header-value")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("Hello, client"))
|
||||
w.Header().Set("X-Trailer", "trailer-value")
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
|
@ -59,12 +81,21 @@ func TestReverseProxy(t *testing.T) {
|
|||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
r.ContentLength = -1 // force chunked encoding (required for trailers)
|
||||
r.Header.Set("X-Header", "header-value")
|
||||
r.Trailer = map[string][]string{
|
||||
"X-Trailer": {"trailer-value"},
|
||||
}
|
||||
|
||||
p.ServeHTTP(w, r)
|
||||
|
||||
if !requestReceived {
|
||||
t.Error("Expected backend to receive request, but it didn't")
|
||||
}
|
||||
|
||||
res := w.Result()
|
||||
verifyHeaders(res.Header, res.Trailer)
|
||||
|
||||
// Make sure {upstream} placeholder is set
|
||||
rr := httpserver.NewResponseRecorder(httptest.NewRecorder())
|
||||
rr.Replacer = httpserver.NewReplacer(r, rr, "-")
|
||||
|
@ -123,7 +154,7 @@ func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) {
|
|||
defer wsNop.Close()
|
||||
|
||||
// Get proxy to use for the test
|
||||
p := newWebSocketTestProxy(wsNop.URL)
|
||||
p := newWebSocketTestProxy(wsNop.URL, false)
|
||||
|
||||
// Create client request
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
@ -148,7 +179,7 @@ func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) {
|
|||
defer wsNop.Close()
|
||||
|
||||
// Get proxy to use for the test
|
||||
p := newWebSocketTestProxy(wsNop.URL)
|
||||
p := newWebSocketTestProxy(wsNop.URL, false)
|
||||
|
||||
// Create client request
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
@ -189,7 +220,7 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
|
|||
defer wsEcho.Close()
|
||||
|
||||
// Get proxy to use for the test
|
||||
p := newWebSocketTestProxy(wsEcho.URL)
|
||||
p := newWebSocketTestProxy(wsEcho.URL, false)
|
||||
|
||||
// This is a full end-end test, so the proxy handler
|
||||
// has to be part of a server listening on a port. Our
|
||||
|
@ -228,6 +259,52 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestWebSocketReverseProxyFromWSSClient(t *testing.T) {
|
||||
wsEcho := newTLSServer(websocket.Handler(func(ws *websocket.Conn) {
|
||||
io.Copy(ws, ws)
|
||||
}))
|
||||
defer wsEcho.Close()
|
||||
|
||||
p := newWebSocketTestProxy(wsEcho.URL, true)
|
||||
|
||||
echoProxy := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
p.ServeHTTP(w, r)
|
||||
}))
|
||||
defer echoProxy.Close()
|
||||
|
||||
// Set up WebSocket client
|
||||
url := strings.Replace(echoProxy.URL, "https://", "wss://", 1)
|
||||
wsCfg, err := websocket.NewConfig(url, echoProxy.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wsCfg.TlsConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
ws, err := websocket.DialConfig(wsCfg)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer ws.Close()
|
||||
|
||||
// Send test message
|
||||
trialMsg := "Is it working?"
|
||||
|
||||
if sendErr := websocket.Message.Send(ws, trialMsg); sendErr != nil {
|
||||
t.Fatal(sendErr)
|
||||
}
|
||||
|
||||
// It should be echoed back to us
|
||||
var actualMsg string
|
||||
|
||||
if rcvErr := websocket.Message.Receive(ws, &actualMsg); rcvErr != nil {
|
||||
t.Fatal(rcvErr)
|
||||
}
|
||||
|
||||
if actualMsg != trialMsg {
|
||||
t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnixSocketProxy(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
return
|
||||
|
@ -264,7 +341,7 @@ func TestUnixSocketProxy(t *testing.T) {
|
|||
defer ts.Close()
|
||||
|
||||
url := strings.Replace(ts.URL, "http://", "unix:", 1)
|
||||
p := newWebSocketTestProxy(url)
|
||||
p := newWebSocketTestProxy(url, false)
|
||||
|
||||
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
p.ServeHTTP(w, r)
|
||||
|
@ -982,10 +1059,14 @@ func (u *fakeUpstream) GetTryInterval() time.Duration { return 250 * time.
|
|||
// redirect to the specified backendAddr. The function
|
||||
// also sets up the rules/environment for testing WebSocket
|
||||
// proxy.
|
||||
func newWebSocketTestProxy(backendAddr string) *Proxy {
|
||||
func newWebSocketTestProxy(backendAddr string, insecure bool) *Proxy {
|
||||
return &Proxy{
|
||||
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
|
||||
Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: ""}},
|
||||
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
|
||||
Upstreams: []Upstream{&fakeWsUpstream{
|
||||
name: backendAddr,
|
||||
without: "",
|
||||
insecure: insecure,
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -997,8 +1078,9 @@ func newPrefixedWebSocketTestProxy(backendAddr string, prefix string) *Proxy {
|
|||
}
|
||||
|
||||
type fakeWsUpstream struct {
|
||||
name string
|
||||
without string
|
||||
name string
|
||||
without string
|
||||
insecure bool
|
||||
}
|
||||
|
||||
func (u *fakeWsUpstream) From() string {
|
||||
|
@ -1007,13 +1089,17 @@ func (u *fakeWsUpstream) From() string {
|
|||
|
||||
func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost {
|
||||
uri, _ := url.Parse(u.name)
|
||||
return &UpstreamHost{
|
||||
host := &UpstreamHost{
|
||||
Name: u.name,
|
||||
ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost),
|
||||
UpstreamHeaders: http.Header{
|
||||
"Connection": {"{>Connection}"},
|
||||
"Upgrade": {"{>Upgrade}"}},
|
||||
}
|
||||
if u.insecure {
|
||||
host.ReverseProxy.UseInsecureTransport()
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { return true }
|
||||
|
|
|
@ -27,10 +27,28 @@ import (
|
|||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
var bufferPool = sync.Pool{New: createBuffer}
|
||||
var (
|
||||
defaultDialer = &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
|
||||
bufferPool = sync.Pool{New: createBuffer}
|
||||
)
|
||||
|
||||
func createBuffer() interface{} {
|
||||
return make([]byte, 32*1024)
|
||||
return make([]byte, 0, 32*1024)
|
||||
}
|
||||
|
||||
func pooledIoCopy(dst io.Writer, src io.Reader) {
|
||||
buf := bufferPool.Get().([]byte)
|
||||
defer bufferPool.Put(buf)
|
||||
|
||||
// CopyBuffer only uses buf up to its length and panics if it's 0.
|
||||
// Due to that we extend buf's length to its capacity here and
|
||||
// ensure it's always non-zero.
|
||||
bufCap := cap(buf)
|
||||
io.CopyBuffer(dst, src, buf[0:bufCap:bufCap])
|
||||
}
|
||||
|
||||
// onExitFlushLoop is a callback set by tests to detect the state of the
|
||||
|
@ -135,11 +153,8 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
|
|||
// just use default transport, to avoid creating
|
||||
// a brand new transport
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
Dial: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).Dial,
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
Dial: defaultDialer.Dial,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
|
@ -148,7 +163,9 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
|
|||
} else {
|
||||
transport.MaxIdleConnsPerHost = keepalive
|
||||
}
|
||||
http2.ConfigureTransport(transport)
|
||||
if httpserver.HTTP2 {
|
||||
http2.ConfigureTransport(transport)
|
||||
}
|
||||
rp.Transport = transport
|
||||
}
|
||||
return rp
|
||||
|
@ -160,18 +177,20 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
|
|||
func (rp *ReverseProxy) UseInsecureTransport() {
|
||||
if rp.Transport == nil {
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
Dial: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).Dial,
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
Dial: defaultDialer.Dial,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
http2.ConfigureTransport(transport)
|
||||
if httpserver.HTTP2 {
|
||||
http2.ConfigureTransport(transport)
|
||||
}
|
||||
rp.Transport = transport
|
||||
} else if transport, ok := rp.Transport.(*http.Transport); ok {
|
||||
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
// No http2.ConfigureTransport() here.
|
||||
// For now this is only added in places where
|
||||
// an http.Transport is actually created.
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -186,20 +205,33 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
|
|||
}
|
||||
|
||||
rp.Director(outreq)
|
||||
outreq.Proto = "HTTP/1.1"
|
||||
outreq.ProtoMajor = 1
|
||||
outreq.ProtoMinor = 1
|
||||
outreq.Close = false
|
||||
|
||||
res, err := transport.RoundTrip(outreq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
isWebsocket := res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket"
|
||||
|
||||
// Remove hop-by-hop headers listed in the
|
||||
// "Connection" header of the response.
|
||||
if c := res.Header.Get("Connection"); c != "" {
|
||||
for _, f := range strings.Split(c, ",") {
|
||||
if f = strings.TrimSpace(f); f != "" {
|
||||
res.Header.Del(f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, h := range hopHeaders {
|
||||
res.Header.Del(h)
|
||||
}
|
||||
|
||||
if respUpdateFn != nil {
|
||||
respUpdateFn(res)
|
||||
}
|
||||
if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" {
|
||||
|
||||
if isWebsocket {
|
||||
res.Body.Close()
|
||||
hj, ok := rw.(http.Hijacker)
|
||||
if !ok {
|
||||
|
@ -228,27 +260,39 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
|
|||
}
|
||||
defer backendConn.Close()
|
||||
|
||||
go func() {
|
||||
io.Copy(backendConn, conn) // write tcp stream to backend.
|
||||
}()
|
||||
io.Copy(conn, backendConn) // read tcp stream from backend.
|
||||
go pooledIoCopy(backendConn, conn) // write tcp stream to backend
|
||||
pooledIoCopy(conn, backendConn) // read tcp stream from backend
|
||||
} else {
|
||||
defer res.Body.Close()
|
||||
for _, h := range hopHeaders {
|
||||
res.Header.Del(h)
|
||||
}
|
||||
copyHeader(rw.Header(), res.Header)
|
||||
|
||||
// The "Trailer" header isn't included in the Transport's response,
|
||||
// at least for *http.Transport. Build it up from Trailer.
|
||||
if len(res.Trailer) > 0 {
|
||||
trailerKeys := make([]string, 0, len(res.Trailer))
|
||||
for k := range res.Trailer {
|
||||
trailerKeys = append(trailerKeys, k)
|
||||
}
|
||||
rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
|
||||
}
|
||||
|
||||
rw.WriteHeader(res.StatusCode)
|
||||
if len(res.Trailer) > 0 {
|
||||
// Force chunking if we saw a response trailer.
|
||||
// This prevents net/http from calculating the length for short
|
||||
// bodies and adding a Content-Length.
|
||||
if fl, ok := rw.(http.Flusher); ok {
|
||||
fl.Flush()
|
||||
}
|
||||
}
|
||||
rp.copyResponse(rw, res.Body)
|
||||
res.Body.Close() // close now, instead of defer, to populate res.Trailer
|
||||
copyHeader(rw.Header(), res.Trailer)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
|
||||
buf := bufferPool.Get()
|
||||
defer bufferPool.Put(buf)
|
||||
|
||||
if rp.FlushInterval != 0 {
|
||||
if wf, ok := dst.(writeFlusher); ok {
|
||||
mlw := &maxLatencyWriter{
|
||||
|
@ -261,7 +305,7 @@ func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
|
|||
dst = mlw
|
||||
}
|
||||
}
|
||||
io.CopyBuffer(dst, src, buf.([]byte))
|
||||
pooledIoCopy(dst, src)
|
||||
}
|
||||
|
||||
// skip these headers if they already exist.
|
||||
|
@ -295,16 +339,17 @@ func copyHeader(dst, src http.Header) {
|
|||
// Hop-by-hop headers. These are removed when sent to the backend.
|
||||
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
|
||||
var hopHeaders = []string{
|
||||
"Alt-Svc",
|
||||
"Alternate-Protocol",
|
||||
"Connection",
|
||||
"Keep-Alive",
|
||||
"Proxy-Authenticate",
|
||||
"Proxy-Authorization",
|
||||
"Te", // canonicalized version of "TE"
|
||||
"Trailers",
|
||||
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
|
||||
"Te", // canonicalized version of "TE"
|
||||
"Trailer", // not Trailers per URL above; http://www.rfc-editor.org/errata_search.php?eid=4522
|
||||
"Transfer-Encoding",
|
||||
"Upgrade",
|
||||
"Alternate-Protocol",
|
||||
"Alt-Svc",
|
||||
}
|
||||
|
||||
type respUpdateFn func(resp *http.Response)
|
||||
|
@ -331,51 +376,169 @@ type connHijackerTransport struct {
|
|||
}
|
||||
|
||||
func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
Dial: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).Dial,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
t := &http.Transport{
|
||||
MaxIdleConnsPerHost: -1,
|
||||
}
|
||||
if base != nil {
|
||||
if baseTransport, ok := base.(*http.Transport); ok {
|
||||
transport.Proxy = baseTransport.Proxy
|
||||
transport.TLSClientConfig = baseTransport.TLSClientConfig
|
||||
transport.TLSHandshakeTimeout = baseTransport.TLSHandshakeTimeout
|
||||
transport.Dial = baseTransport.Dial
|
||||
transport.DialTLS = baseTransport.DialTLS
|
||||
transport.MaxIdleConnsPerHost = -1
|
||||
if b, _ := base.(*http.Transport); b != nil {
|
||||
tlsClientConfig := b.TLSClientConfig
|
||||
if tlsClientConfig.NextProtos != nil {
|
||||
tlsClientConfig = cloneTLSClientConfig(tlsClientConfig)
|
||||
tlsClientConfig.NextProtos = nil
|
||||
}
|
||||
|
||||
t.Proxy = b.Proxy
|
||||
t.TLSClientConfig = tlsClientConfig
|
||||
t.TLSHandshakeTimeout = b.TLSHandshakeTimeout
|
||||
t.Dial = b.Dial
|
||||
t.DialTLS = b.DialTLS
|
||||
} else {
|
||||
t.Proxy = http.ProxyFromEnvironment
|
||||
t.TLSHandshakeTimeout = 10 * time.Second
|
||||
}
|
||||
hjTransport := &connHijackerTransport{transport, nil, bufferPool.Get().([]byte)[:0]}
|
||||
oldDial := transport.Dial
|
||||
oldDialTLS := transport.DialTLS
|
||||
if oldDial == nil {
|
||||
oldDial = (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).Dial
|
||||
hj := &connHijackerTransport{t, nil, bufferPool.Get().([]byte)[:0]}
|
||||
|
||||
dial := getTransportDial(t)
|
||||
dialTLS := getTransportDialTLS(t)
|
||||
t.Dial = func(network, addr string) (net.Conn, error) {
|
||||
c, err := dial(network, addr)
|
||||
hj.Conn = c
|
||||
return &hijackedConn{c, hj}, err
|
||||
}
|
||||
hjTransport.Dial = func(network, addr string) (net.Conn, error) {
|
||||
c, err := oldDial(network, addr)
|
||||
hjTransport.Conn = c
|
||||
return &hijackedConn{c, hjTransport}, err
|
||||
t.DialTLS = func(network, addr string) (net.Conn, error) {
|
||||
c, err := dialTLS(network, addr)
|
||||
hj.Conn = c
|
||||
return &hijackedConn{c, hj}, err
|
||||
}
|
||||
if oldDialTLS != nil {
|
||||
hjTransport.DialTLS = func(network, addr string) (net.Conn, error) {
|
||||
c, err := oldDialTLS(network, addr)
|
||||
hjTransport.Conn = c
|
||||
return &hijackedConn{c, hjTransport}, err
|
||||
|
||||
return hj
|
||||
}
|
||||
|
||||
// getTransportDial always returns a plain Dialer
|
||||
// and defaults to the existing t.Dial.
|
||||
func getTransportDial(t *http.Transport) func(network, addr string) (net.Conn, error) {
|
||||
if t.Dial != nil {
|
||||
return t.Dial
|
||||
}
|
||||
return defaultDialer.Dial
|
||||
}
|
||||
|
||||
// getTransportDial always returns a TLS Dialer
|
||||
// and defaults to the existing t.DialTLS.
|
||||
func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn, error) {
|
||||
if t.DialTLS != nil {
|
||||
return t.DialTLS
|
||||
}
|
||||
|
||||
// newConnHijackerTransport will modify t.Dial after calling this method
|
||||
// => Create a backup reference.
|
||||
plainDial := getTransportDial(t)
|
||||
|
||||
// The following DialTLS implementation stems from the Go stdlib and
|
||||
// is identical to what happens if DialTLS is not provided.
|
||||
// Source: https://github.com/golang/go/blob/230a376b5a67f0e9341e1fa47e670ff762213c83/src/net/http/transport.go#L1018-L1051
|
||||
return func(network, addr string) (net.Conn, error) {
|
||||
plainConn, err := plainDial(network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsClientConfig := t.TLSClientConfig
|
||||
if tlsClientConfig == nil {
|
||||
tlsClientConfig = &tls.Config{}
|
||||
}
|
||||
if !tlsClientConfig.InsecureSkipVerify && tlsClientConfig.ServerName == "" {
|
||||
tlsClientConfig.ServerName = stripPort(addr)
|
||||
}
|
||||
|
||||
tlsConn := tls.Client(plainConn, tlsClientConfig)
|
||||
errc := make(chan error, 2)
|
||||
var timer *time.Timer
|
||||
if d := t.TLSHandshakeTimeout; d != 0 {
|
||||
timer = time.AfterFunc(d, func() {
|
||||
errc <- tlsHandshakeTimeoutError{}
|
||||
})
|
||||
}
|
||||
go func() {
|
||||
err := tlsConn.Handshake()
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
errc <- err
|
||||
}()
|
||||
if err := <-errc; err != nil {
|
||||
plainConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
if !tlsClientConfig.InsecureSkipVerify {
|
||||
hostname := tlsClientConfig.ServerName
|
||||
if hostname == "" {
|
||||
hostname = stripPort(addr)
|
||||
}
|
||||
if err := tlsConn.VerifyHostname(hostname); err != nil {
|
||||
plainConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return tlsConn, nil
|
||||
}
|
||||
}
|
||||
|
||||
// stripPort returns address without its port if it has one and
|
||||
// works with IP addresses as well as hostnames formatted as host:port.
|
||||
//
|
||||
// IPv6 addresses (excluding the port) must be enclosed in
|
||||
// square brackets similar to the requirements of Go's stdlib.
|
||||
func stripPort(address string) string {
|
||||
// Keep in mind that the address might be a IPv6 address
|
||||
// and thus contain a colon, but not have a port.
|
||||
portIdx := strings.LastIndex(address, ":")
|
||||
ipv6Idx := strings.LastIndex(address, "]")
|
||||
if portIdx > ipv6Idx {
|
||||
address = address[:portIdx]
|
||||
}
|
||||
return address
|
||||
}
|
||||
|
||||
type tlsHandshakeTimeoutError struct{}
|
||||
|
||||
func (tlsHandshakeTimeoutError) Timeout() bool { return true }
|
||||
func (tlsHandshakeTimeoutError) Temporary() bool { return true }
|
||||
func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" }
|
||||
|
||||
// cloneTLSClientConfig is like cloneTLSConfig but omits
|
||||
// the fields SessionTicketsDisabled and SessionTicketKey.
|
||||
// This makes it safe to call cloneTLSClientConfig on a config
|
||||
// in active use by a server.
|
||||
func cloneTLSClientConfig(cfg *tls.Config) *tls.Config {
|
||||
if cfg == nil {
|
||||
return &tls.Config{}
|
||||
}
|
||||
return &tls.Config{
|
||||
Rand: cfg.Rand,
|
||||
Time: cfg.Time,
|
||||
Certificates: cfg.Certificates,
|
||||
NameToCertificate: cfg.NameToCertificate,
|
||||
GetCertificate: cfg.GetCertificate,
|
||||
RootCAs: cfg.RootCAs,
|
||||
NextProtos: cfg.NextProtos,
|
||||
ServerName: cfg.ServerName,
|
||||
ClientAuth: cfg.ClientAuth,
|
||||
ClientCAs: cfg.ClientCAs,
|
||||
InsecureSkipVerify: cfg.InsecureSkipVerify,
|
||||
CipherSuites: cfg.CipherSuites,
|
||||
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
|
||||
ClientSessionCache: cfg.ClientSessionCache,
|
||||
MinVersion: cfg.MinVersion,
|
||||
MaxVersion: cfg.MaxVersion,
|
||||
CurvePreferences: cfg.CurvePreferences,
|
||||
DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled,
|
||||
Renegotiation: cfg.Renegotiation,
|
||||
}
|
||||
return hjTransport
|
||||
}
|
||||
|
||||
func requestIsWebsocket(req *http.Request) bool {
|
||||
return !(strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade"))
|
||||
return strings.ToLower(req.Header.Get("Upgrade")) == "websocket" && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade")
|
||||
}
|
||||
|
||||
type writeFlusher interface {
|
||||
|
|
Loading…
Reference in a new issue