From e377eeff50e6785605d15670eaf8740bad889e8b Mon Sep 17 00:00:00 2001
From: Tw <tw19881113@gmail.com>
Date: Sat, 23 Sep 2017 08:10:48 +0800
Subject: [PATCH] proxy: websocket proxy exits immediately if backend is
 shutdown (#1869)

Signed-off-by: Tw <tw19881113@gmail.com>
---
 caddyhttp/proxy/proxy_test.go   | 34 +++++++++++++++++++++++++++++++++
 caddyhttp/proxy/reverseproxy.go | 15 +++++++++++++--
 2 files changed, 47 insertions(+), 2 deletions(-)

diff --git a/caddyhttp/proxy/proxy_test.go b/caddyhttp/proxy/proxy_test.go
index 3166e85f0..6412cc1bd 100644
--- a/caddyhttp/proxy/proxy_test.go
+++ b/caddyhttp/proxy/proxy_test.go
@@ -304,6 +304,40 @@ func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) {
 	p.ServeHTTP(nonHijacker, r)
 }
 
+func TestWebSocketReverseProxyBackendShutDown(t *testing.T) {
+	shutdown := make(chan struct{})
+	backend := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) {
+		shutdown <- struct{}{}
+	}))
+	defer backend.Close()
+
+	go func() {
+		<-shutdown
+		backend.Close()
+	}()
+
+	// Get proxy to use for the test
+	p := newWebSocketTestProxy(backend.URL, false)
+	backendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		p.ServeHTTP(w, r)
+	}))
+	defer backendProxy.Close()
+
+	// Set up WebSocket client
+	url := strings.Replace(backendProxy.URL, "http://", "ws://", 1)
+	ws, err := websocket.Dial(url, "", backendProxy.URL)
+
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer ws.Close()
+
+	var actualMsg string
+	if rcvErr := websocket.Message.Receive(ws, &actualMsg); rcvErr == nil {
+		t.Errorf("we don't get backend shutdown notification")
+	}
+}
+
 func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) {
 	// No-op websocket backend simply allows the WS connection to be
 	// accepted then it will be immediately closed. Perfect for testing.
diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go
index 496607490..c22cf1017 100644
--- a/caddyhttp/proxy/reverseproxy.go
+++ b/caddyhttp/proxy/reverseproxy.go
@@ -320,8 +320,13 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
 		}
 		defer backendConn.Close()
 
+		proxyDone := make(chan struct{}, 2)
+
 		// Proxy backend -> frontend.
-		go pooledIoCopy(conn, backendConn)
+		go func() {
+			pooledIoCopy(conn, backendConn)
+			proxyDone <- struct{}{}
+		}()
 
 		// Proxy frontend -> backend.
 		//
@@ -336,7 +341,13 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
 				backendConn.Write(rbuf)
 			}
 		}
-		pooledIoCopy(backendConn, conn)
+		go func() {
+			pooledIoCopy(backendConn, conn)
+			proxyDone <- struct{}{}
+		}()
+
+		// If one side is done, we are done.
+		<-proxyDone
 	} else {
 		// NOTE:
 		//   Closing the Body involves acquiring a mutex, which is a