diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go index aab836bc..55abe2dc 100644 --- a/middleware/proxy/proxy_test.go +++ b/middleware/proxy/proxy_test.go @@ -98,11 +98,6 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) { // also sets up the rules/environment for testing WebSocket // proxy. func newWebSocketTestProxy(backendAddr string) *Proxy { - proxyHeaders = http.Header{ - "Connection": {"{>Connection}"}, - "Upgrade": {"{>Upgrade}"}, - } - return &Proxy{ Upstreams: []Upstream{&fakeUpstream{name: backendAddr}}, } @@ -121,7 +116,9 @@ func (u *fakeUpstream) Select() *UpstreamHost { return &UpstreamHost{ Name: u.name, ReverseProxy: NewSingleHostReverseProxy(uri, ""), - ExtraHeaders: proxyHeaders, + ExtraHeaders: http.Header{ + "Connection": {"{>Connection}"}, + "Upgrade": {"{>Upgrade}"}}, } } diff --git a/middleware/proxy/upstream.go b/middleware/proxy/upstream.go index 7916b3b6..5e9152e4 100644 --- a/middleware/proxy/upstream.go +++ b/middleware/proxy/upstream.go @@ -16,13 +16,13 @@ import ( var ( supportedPolicies = make(map[string]func() Policy) - proxyHeaders = make(http.Header) ) type staticUpstream struct { - from string - Hosts HostPool - Policy Policy + from string + proxyHeaders http.Header + Hosts HostPool + Policy Policy FailTimeout time.Duration MaxFails int32 @@ -72,7 +72,7 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { Fails: 0, FailTimeout: upstream.FailTimeout, Unhealthy: false, - ExtraHeaders: proxyHeaders, + ExtraHeaders: upstream.proxyHeaders, CheckDown: func(upstream *staticUpstream) UpstreamHostDownFunc { return func(uh *UpstreamHost) bool { if uh.Unhealthy { @@ -159,10 +159,16 @@ func parseBlock(c *parse.Dispenser, u *staticUpstream) error { if !c.Args(&header, &value) { return c.ArgErr() } - proxyHeaders.Add(header, value) + if u.proxyHeaders == nil { + u.proxyHeaders = make(http.Header) + } + u.proxyHeaders.Add(header, value) case "websocket": - proxyHeaders.Add("Connection", "{>Connection}") - proxyHeaders.Add("Upgrade", "{>Upgrade}") + if u.proxyHeaders == nil { + u.proxyHeaders = make(http.Header) + } + u.proxyHeaders.Add("Connection", "{>Connection}") + u.proxyHeaders.Add("Upgrade", "{>Upgrade}") case "without": if !c.NextArg() { return c.ArgErr()