diff --git a/middleware/proxy/reverseproxy.go b/middleware/proxy/reverseproxy.go index 1db1131e..f3a0390b 100644 --- a/middleware/proxy/reverseproxy.go +++ b/middleware/proxy/reverseproxy.go @@ -16,7 +16,7 @@ import ( "time" ) -const HTTPSwitchProtocols = 101 +const HTTPSwitchingProtocols = 101 // onExitFlushLoop is a callback set by tests to detect the state of the // flushLoop() goroutine. @@ -155,7 +155,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request, extr copyHeader(rw.Header(), res.Header) - if res.StatusCode == HTTPSwitchProtocols { + if res.StatusCode == HTTPSwitchingProtocols && outreq.Header.Get("Upgrade") == "websocket" { hj, ok := rw.(http.Hijacker) if !ok { return nil diff --git a/middleware/proxy/upstream.go b/middleware/proxy/upstream.go index a657a088..4c1b9fff 100644 --- a/middleware/proxy/upstream.go +++ b/middleware/proxy/upstream.go @@ -12,7 +12,10 @@ import ( "github.com/mholt/caddy/config/parse" ) -var supportedPolicies map[string]func() Policy = make(map[string]func() Policy) +var ( + supportedPolicies map[string]func() Policy = make(map[string]func() Policy) + proxyHeaders http.Header +) type staticUpstream struct { from string @@ -40,7 +43,7 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { FailTimeout: 10 * time.Second, MaxFails: 1, } - var proxyHeaders http.Header + if !c.Args(&upstream.from) { return upstreams, c.ArgErr() } @@ -97,10 +100,10 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { if !c.Args(&header, &value) { return upstreams, c.ArgErr() } - if proxyHeaders == nil { - proxyHeaders = make(map[string][]string) - } - proxyHeaders.Add(header, value) + addProxyHeader(header, value) + case "websocket": + addProxyHeader("Connection", "{>Connection}") + addProxyHeader("Upgrade", "{>Upgrade}") } } @@ -150,6 +153,14 @@ func RegisterPolicy(name string, policy func() Policy) { supportedPolicies[name] = policy } +// AddProxyHeader adds a proxy header. +func addProxyHeader(header, value string) { + if proxyHeaders == nil { + proxyHeaders = make(map[string][]string) + } + proxyHeaders.Add(header, value) +} + func (u *staticUpstream) From() string { return u.from }