proxy: Add, remove, or replace upstream and downstream headers (closes #666) (PR #788)

* Overwrite proxy headers based on directive

Headers of the request sent by the proxy upstream can now be modified in
the following way:

Prefix header with `+`: Header will be added if it doesn't exist
otherwise, the values will be merge
Prefix header with `-': Header will be removed
No prefix: Header will be replaced with given value

* Add missing formating directive reported by go vet

* Overwrite up/down stream proxy headers

Add Up/DownStreamHeaders to UpstreamHost

Split `proxy_header` option in `proxy` directive into `header_upstream`
and `header_downstream`. By splitting into two, it makes it clear in
what direction the given headers must be applied.

`proxy_header` can still be used (to maintain backward compatability)
but its assumed to be `header_upstream`

Response headers received by the reverse proxy from the upstream host
are updated according the `header_downstream` rules.
The update occurs through a func given to the reverse proxy, which is
applied once a response is received.

Headers (for upstream and downstream) can now be modified in
the following way:

Prefix header with `+`: Header will be added if it doesn't exist
otherwise, the values will be merge
Prefix header with `-': Header will be removed
No prefix: Header will be replaced with given value

Updated branch with changes from master

* minor refactor to make intent clearer

* Make Up/Down stream headers naming consistent

* Fix error descriptions to be more clear

* Fix lint issue
This commit is contained in:
William Bezuidenhout 2016-04-30 21:41:30 +02:00 committed by Matt Holt
parent 96425f0f40
commit e2234497b7
4 changed files with 234 additions and 31 deletions

View file

@ -43,7 +43,8 @@ type UpstreamHost struct {
Fails int32 Fails int32
FailTimeout time.Duration FailTimeout time.Duration
Unhealthy bool Unhealthy bool
ExtraHeaders http.Header UpstreamHeaders http.Header
DownstreamHeaders http.Header
CheckDown UpstreamHostDownFunc CheckDown UpstreamHostDownFunc
WithoutPathPrefix string WithoutPathPrefix string
MaxConns int64 MaxConns int64
@ -99,26 +100,33 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
} }
outreq.Host = host.Name outreq.Host = host.Name
if host.ExtraHeaders != nil { if host.UpstreamHeaders != nil {
extraHeaders := make(http.Header)
if replacer == nil { if replacer == nil {
rHost := r.Host rHost := r.Host
replacer = middleware.NewReplacer(r, nil, "") replacer = middleware.NewReplacer(r, nil, "")
outreq.Host = rHost outreq.Host = rHost
} }
for header, values := range host.ExtraHeaders { if v, ok := host.UpstreamHeaders["Host"]; ok {
for _, value := range values { r.Host = replacer.Replace(v[len(v)-1])
extraHeaders.Add(header, replacer.Replace(value))
if header == "Host" {
outreq.Host = replacer.Replace(value)
} }
} // Modify headers for request that will be sent to the upstream host
} upHeaders := createHeadersByRules(host.UpstreamHeaders, r.Header, replacer)
for k, v := range extraHeaders { for k, v := range upHeaders {
outreq.Header[k] = v outreq.Header[k] = v
} }
} }
var downHeaderUpdateFn respUpdateFn
if host.DownstreamHeaders != nil {
if replacer == nil {
rHost := r.Host
replacer = middleware.NewReplacer(r, nil, "")
outreq.Host = rHost
}
//Creates a function that is used to update headers the response received by the reverse proxy
downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer)
}
proxy := host.ReverseProxy proxy := host.ReverseProxy
if baseURL, err := url.Parse(host.Name); err == nil { if baseURL, err := url.Parse(host.Name); err == nil {
r.Host = baseURL.Host r.Host = baseURL.Host
@ -130,7 +138,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
} }
atomic.AddInt64(&host.Conns, 1) atomic.AddInt64(&host.Conns, 1)
backendErr := proxy.ServeHTTP(w, outreq) backendErr := proxy.ServeHTTP(w, outreq, downHeaderUpdateFn)
atomic.AddInt64(&host.Conns, -1) atomic.AddInt64(&host.Conns, -1)
if backendErr == nil { if backendErr == nil {
return 0, nil return 0, nil
@ -182,3 +190,48 @@ func createUpstreamRequest(r *http.Request) *http.Request {
return outreq return outreq
} }
func createRespHeaderUpdateFn(rules http.Header, replacer middleware.Replacer) respUpdateFn {
return func(resp *http.Response) {
newHeaders := createHeadersByRules(rules, resp.Header, replacer)
for h, v := range newHeaders {
resp.Header[h] = v
}
}
}
func createHeadersByRules(rules http.Header, base http.Header, repl middleware.Replacer) http.Header {
newHeaders := make(http.Header)
for header, values := range rules {
if strings.HasPrefix(header, "+") {
header = strings.TrimLeft(header, "+")
add(newHeaders, header, base[header])
applyEach(values, repl.Replace)
add(newHeaders, header, values)
} else if strings.HasPrefix(header, "-") {
base.Del(strings.TrimLeft(header, "-"))
} else if _, ok := base[header]; ok {
applyEach(values, repl.Replace)
for _, v := range values {
newHeaders.Set(header, v)
}
} else {
applyEach(values, repl.Replace)
add(newHeaders, header, values)
add(newHeaders, header, base[header])
}
}
return newHeaders
}
func applyEach(values []string, mapFn func(string) string) {
for i, v := range values {
values[i] = mapFn(v)
}
}
func add(base http.Header, header string, values []string) {
for _, v := range values {
base.Add(header, v)
}
}

View file

@ -348,6 +348,141 @@ func TestUnixSocketProxyPaths(t *testing.T) {
} }
} }
func TestUpstreamHeadersUpdate(t *testing.T) {
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr)
var actualHeaders http.Header
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hello, client"))
actualHeaders = r.Header
}))
defer backend.Close()
upstream := newFakeUpstream(backend.URL, false)
upstream.host.UpstreamHeaders = http.Header{
"Connection": {"{>Connection}"},
"Upgrade": {"{>Upgrade}"},
"+Merge-Me": {"Merge-Value"},
"+Add-Me": {"Add-Value"},
"-Remove-Me": {""},
"Replace-Me": {"{hostname}"},
}
// set up proxy
p := &Proxy{
Upstreams: []Upstream{upstream},
}
// create request and response recorder
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
w := httptest.NewRecorder()
//add initial headers
r.Header.Add("Merge-Me", "Initial")
r.Header.Add("Remove-Me", "Remove-Value")
r.Header.Add("Replace-Me", "Replace-Value")
p.ServeHTTP(w, r)
replacer := middleware.NewReplacer(r, nil, "")
headerKey := "Merge-Me"
values, ok := actualHeaders[headerKey]
if !ok {
t.Errorf("Request sent to upstream backend does not contain expected %v header. Expected header to be added", headerKey)
} else if len(values) < 2 && (values[0] != "Initial" || values[1] != replacer.Replace("{hostname}")) {
t.Errorf("Values for proxy header `+Merge-Me` should be merged. Got %v", values)
}
headerKey = "Add-Me"
if _, ok := actualHeaders[headerKey]; !ok {
t.Errorf("Request sent to upstream backend does not contain expected %v header", headerKey)
}
headerKey = "Remove-Me"
if _, ok := actualHeaders[headerKey]; ok {
t.Errorf("Request sent to upstream backend should not contain %v header", headerKey)
}
headerKey = "Replace-Me"
headerValue := replacer.Replace("{hostname}")
value, ok := actualHeaders[headerKey]
if !ok {
t.Errorf("Request sent to upstream backend should not remove %v header", headerKey)
} else if len(value) > 0 && headerValue != value[0] {
t.Errorf("Request sent to upstream backend should replace value of %v header with %v. Instead value was %v", headerKey, headerValue, value)
}
}
func TestDownstreamHeadersUpdate(t *testing.T) {
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr)
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Merge-Me", "Initial")
w.Header().Add("Remove-Me", "Remove-Value")
w.Header().Add("Replace-Me", "Replace-Value")
w.Write([]byte("Hello, client"))
}))
defer backend.Close()
upstream := newFakeUpstream(backend.URL, false)
upstream.host.DownstreamHeaders = http.Header{
"+Merge-Me": {"Merge-Value"},
"+Add-Me": {"Add-Value"},
"-Remove-Me": {""},
"Replace-Me": {"{hostname}"},
}
// set up proxy
p := &Proxy{
Upstreams: []Upstream{upstream},
}
// create request and response recorder
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
w := httptest.NewRecorder()
p.ServeHTTP(w, r)
replacer := middleware.NewReplacer(r, nil, "")
actualHeaders := w.Header()
headerKey := "Merge-Me"
values, ok := actualHeaders[headerKey]
if !ok {
t.Errorf("Downstream response does not contain expected %v header. Expected header should be added", headerKey)
} else if len(values) < 2 && (values[0] != "Initial" || values[1] != replacer.Replace("{hostname}")) {
t.Errorf("Values for header `+Merge-Me` should be merged. Got %v", values)
}
headerKey = "Add-Me"
if _, ok := actualHeaders[headerKey]; !ok {
t.Errorf("Downstream response does not contain expected %v header", headerKey)
}
headerKey = "Remove-Me"
if _, ok := actualHeaders[headerKey]; ok {
t.Errorf("Downstream response should not contain %v header received from upstream", headerKey)
}
headerKey = "Replace-Me"
headerValue := replacer.Replace("{hostname}")
value, ok := actualHeaders[headerKey]
if !ok {
t.Errorf("Downstream response should contain %v header and not remove it", headerKey)
} else if len(value) > 0 && headerValue != value[0] {
t.Errorf("Downstream response should have header %v with value %v. Instead value was %v", headerKey, headerValue, value)
}
}
func newFakeUpstream(name string, insecure bool) *fakeUpstream { func newFakeUpstream(name string, insecure bool) *fakeUpstream {
uri, _ := url.Parse(name) uri, _ := url.Parse(name)
u := &fakeUpstream{ u := &fakeUpstream{
@ -410,7 +545,7 @@ func (u *fakeWsUpstream) Select() *UpstreamHost {
return &UpstreamHost{ return &UpstreamHost{
Name: u.name, Name: u.name,
ReverseProxy: NewSingleHostReverseProxy(uri, u.without), ReverseProxy: NewSingleHostReverseProxy(uri, u.without),
ExtraHeaders: http.Header{ UpstreamHeaders: http.Header{
"Connection": {"{>Connection}"}, "Connection": {"{>Connection}"},
"Upgrade": {"{>Upgrade}"}}, "Upgrade": {"{>Upgrade}"}},
} }

View file

@ -154,7 +154,9 @@ var InsecureTransport http.RoundTripper = &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
} }
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request) error { type respUpdateFn func(resp *http.Response)
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, respUpdateFn respUpdateFn) error {
transport := p.Transport transport := p.Transport
if transport == nil { if transport == nil {
transport = http.DefaultTransport transport = http.DefaultTransport
@ -169,6 +171,8 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request) e
res, err := transport.RoundTrip(outreq) res, err := transport.RoundTrip(outreq)
if err != nil { if err != nil {
return err return err
} else if respUpdateFn != nil {
respUpdateFn(res)
} }
if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" { if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" {

View file

@ -20,7 +20,8 @@ var (
type staticUpstream struct { type staticUpstream struct {
from string from string
proxyHeaders http.Header upstreamHeaders http.Header
downstreamHeaders http.Header
Hosts HostPool Hosts HostPool
Policy Policy Policy Policy
insecureSkipVerify bool insecureSkipVerify bool
@ -43,7 +44,8 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) {
for c.Next() { for c.Next() {
upstream := &staticUpstream{ upstream := &staticUpstream{
from: "", from: "",
proxyHeaders: make(http.Header), upstreamHeaders: make(http.Header),
downstreamHeaders: make(http.Header),
Hosts: nil, Hosts: nil,
Policy: &Random{}, Policy: &Random{},
FailTimeout: 10 * time.Second, FailTimeout: 10 * time.Second,
@ -102,7 +104,8 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
Fails: 0, Fails: 0,
FailTimeout: u.FailTimeout, FailTimeout: u.FailTimeout,
Unhealthy: false, Unhealthy: false,
ExtraHeaders: u.proxyHeaders, UpstreamHeaders: u.upstreamHeaders,
DownstreamHeaders: u.downstreamHeaders,
CheckDown: func(u *staticUpstream) UpstreamHostDownFunc { CheckDown: func(u *staticUpstream) UpstreamHostDownFunc {
return func(uh *UpstreamHost) bool { return func(uh *UpstreamHost) bool {
if uh.Unhealthy { if uh.Unhealthy {
@ -182,15 +185,23 @@ func parseBlock(c *parse.Dispenser, u *staticUpstream) error {
} }
u.HealthCheck.Interval = dur u.HealthCheck.Interval = dur
} }
case "header_upstream":
fallthrough
case "proxy_header": case "proxy_header":
var header, value string var header, value string
if !c.Args(&header, &value) { if !c.Args(&header, &value) {
return c.ArgErr() return c.ArgErr()
} }
u.proxyHeaders.Add(header, value) u.upstreamHeaders.Add(header, value)
case "header_downstream":
var header, value string
if !c.Args(&header, &value) {
return c.ArgErr()
}
u.downstreamHeaders.Add(header, value)
case "websocket": case "websocket":
u.proxyHeaders.Add("Connection", "{>Connection}") u.upstreamHeaders.Add("Connection", "{>Connection}")
u.proxyHeaders.Add("Upgrade", "{>Upgrade}") u.upstreamHeaders.Add("Upgrade", "{>Upgrade}")
case "without": case "without":
if !c.NextArg() { if !c.NextArg() {
return c.ArgErr() return c.ArgErr()