mirror of
https://github.com/caddyserver/caddy.git
synced 2025-01-08 11:58:49 +03:00
Adjust proxy headers properly (fixes #916)
This commit is contained in:
parent
57710e8b0d
commit
6490ff6224
2 changed files with 25 additions and 52 deletions
|
@ -84,7 +84,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// this replacer is used to fill in header field values
|
// this replacer is used to fill in header field values
|
||||||
var replacer httpserver.Replacer
|
replacer := httpserver.NewReplacer(r, nil, "")
|
||||||
|
|
||||||
// outreq is the request that makes a roundtrip to the backend
|
// outreq is the request that makes a roundtrip to the backend
|
||||||
outreq := createUpstreamRequest(r)
|
outreq := createUpstreamRequest(r)
|
||||||
|
@ -119,16 +119,10 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||||
|
|
||||||
// set headers for request going upstream
|
// set headers for request going upstream
|
||||||
if host.UpstreamHeaders != nil {
|
if host.UpstreamHeaders != nil {
|
||||||
if replacer == nil {
|
|
||||||
replacer = httpserver.NewReplacer(r, nil, "")
|
|
||||||
}
|
|
||||||
if v, ok := host.UpstreamHeaders["Host"]; ok {
|
|
||||||
outreq.Host = replacer.Replace(v[len(v)-1])
|
|
||||||
}
|
|
||||||
// modify headers for request that will be sent to the upstream host
|
// modify headers for request that will be sent to the upstream host
|
||||||
upHeaders := createHeadersByRules(host.UpstreamHeaders, r.Header, replacer)
|
mutateHeadersByRules(outreq.Header, host.UpstreamHeaders, replacer)
|
||||||
for k, v := range upHeaders {
|
if hostHeaders, ok := outreq.Header["Host"]; ok && len(hostHeaders) > 0 {
|
||||||
outreq.Header[k] = v
|
outreq.Host = hostHeaders[len(hostHeaders)-1]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -136,9 +130,6 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||||
// headers coming back downstream
|
// headers coming back downstream
|
||||||
var downHeaderUpdateFn respUpdateFn
|
var downHeaderUpdateFn respUpdateFn
|
||||||
if host.DownstreamHeaders != nil {
|
if host.DownstreamHeaders != nil {
|
||||||
if replacer == nil {
|
|
||||||
replacer = httpserver.NewReplacer(r, nil, "")
|
|
||||||
}
|
|
||||||
downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer)
|
downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -185,6 +176,8 @@ func (p Proxy) match(r *http.Request) Upstream {
|
||||||
|
|
||||||
// createUpstremRequest shallow-copies r into a new request
|
// createUpstremRequest shallow-copies r into a new request
|
||||||
// that can be sent upstream.
|
// that can be sent upstream.
|
||||||
|
//
|
||||||
|
// Derived from reverseproxy.go in the standard Go httputil package.
|
||||||
func createUpstreamRequest(r *http.Request) *http.Request {
|
func createUpstreamRequest(r *http.Request) *http.Request {
|
||||||
outreq := new(http.Request)
|
outreq := new(http.Request)
|
||||||
*outreq = *r // includes shallow copies of maps, but okay
|
*outreq = *r // includes shallow copies of maps, but okay
|
||||||
|
@ -199,10 +192,14 @@ func createUpstreamRequest(r *http.Request) *http.Request {
|
||||||
// connection, regardless of what the client sent to us. This
|
// connection, regardless of what the client sent to us. This
|
||||||
// is modifying the same underlying map from r (shallow
|
// is modifying the same underlying map from r (shallow
|
||||||
// copied above) so we only copy it if necessary.
|
// copied above) so we only copy it if necessary.
|
||||||
|
var copiedHeaders bool
|
||||||
for _, h := range hopHeaders {
|
for _, h := range hopHeaders {
|
||||||
if outreq.Header.Get(h) != "" {
|
if outreq.Header.Get(h) != "" {
|
||||||
outreq.Header = make(http.Header)
|
if !copiedHeaders {
|
||||||
copyHeader(outreq.Header, r.Header)
|
outreq.Header = make(http.Header)
|
||||||
|
copyHeader(outreq.Header, r.Header)
|
||||||
|
copiedHeaders = true
|
||||||
|
}
|
||||||
outreq.Header.Del(h)
|
outreq.Header.Del(h)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -222,45 +219,20 @@ func createUpstreamRequest(r *http.Request) *http.Request {
|
||||||
|
|
||||||
func createRespHeaderUpdateFn(rules http.Header, replacer httpserver.Replacer) respUpdateFn {
|
func createRespHeaderUpdateFn(rules http.Header, replacer httpserver.Replacer) respUpdateFn {
|
||||||
return func(resp *http.Response) {
|
return func(resp *http.Response) {
|
||||||
newHeaders := createHeadersByRules(rules, resp.Header, replacer)
|
mutateHeadersByRules(resp.Header, rules, replacer)
|
||||||
for h, v := range newHeaders {
|
|
||||||
resp.Header[h] = v
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func createHeadersByRules(rules http.Header, base http.Header, repl httpserver.Replacer) http.Header {
|
func mutateHeadersByRules(headers, rules http.Header, repl httpserver.Replacer) {
|
||||||
newHeaders := make(http.Header)
|
for ruleField, ruleValues := range rules {
|
||||||
for header, values := range rules {
|
if strings.HasPrefix(ruleField, "+") {
|
||||||
if strings.HasPrefix(header, "+") {
|
for _, ruleValue := range ruleValues {
|
||||||
header = strings.TrimLeft(header, "+")
|
headers.Add(strings.TrimPrefix(ruleField, "+"), repl.Replace(ruleValue))
|
||||||
add(newHeaders, header, base[header])
|
|
||||||
applyEach(values, repl.Replace)
|
|
||||||
add(newHeaders, header, values)
|
|
||||||
} else if strings.HasPrefix(header, "-") {
|
|
||||||
base.Del(strings.TrimLeft(header, "-"))
|
|
||||||
} else if _, ok := base[header]; ok {
|
|
||||||
applyEach(values, repl.Replace)
|
|
||||||
for _, v := range values {
|
|
||||||
newHeaders.Set(header, v)
|
|
||||||
}
|
}
|
||||||
} else {
|
} else if strings.HasPrefix(ruleField, "-") {
|
||||||
applyEach(values, repl.Replace)
|
headers.Del(strings.TrimPrefix(ruleField, "-"))
|
||||||
add(newHeaders, header, values)
|
} else if len(ruleValues) > 0 {
|
||||||
add(newHeaders, header, base[header])
|
headers.Set(ruleField, repl.Replace(ruleValues[len(ruleValues)-1]))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return newHeaders
|
|
||||||
}
|
|
||||||
|
|
||||||
func applyEach(values []string, mapFn func(string) string) {
|
|
||||||
for i, v := range values {
|
|
||||||
values[i] = mapFn(v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func add(base http.Header, header string, values []string) {
|
|
||||||
for _, v := range values {
|
|
||||||
base.Add(header, v)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -177,10 +177,11 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, r
|
||||||
res, err := transport.RoundTrip(outreq)
|
res, err := transport.RoundTrip(outreq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
} else if respUpdateFn != nil {
|
|
||||||
respUpdateFn(res)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if respUpdateFn != nil {
|
||||||
|
respUpdateFn(res)
|
||||||
|
}
|
||||||
if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" {
|
if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" {
|
||||||
res.Body.Close()
|
res.Body.Close()
|
||||||
hj, ok := rw.(http.Hijacker)
|
hj, ok := rw.(http.Hijacker)
|
||||||
|
|
Loading…
Reference in a new issue