diff --git a/modules/caddyhttp/headers/headers.go b/modules/caddyhttp/headers/headers.go index 71e575dd..813b9fec 100644 --- a/modules/caddyhttp/headers/headers.go +++ b/modules/caddyhttp/headers/headers.go @@ -79,15 +79,8 @@ func (h Handler) Validate() error { func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { repl := r.Context().Value(caddy.ReplacerCtxKey).(caddy.Replacer) - h.Request.applyTo(r.Header, repl) - - // request header's Host is handled specially by the - // Go standard library, so if that header was changed, - // change it in the Host field since the Header won't - // be used - if intendedHost := r.Header.Get("Host"); intendedHost != "" { - r.Host = intendedHost - r.Header.Del("Host") + if h.Request != nil { + h.Request.ApplyToRequest(r) } if h.Response != nil { @@ -99,7 +92,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhtt headerOps: h.Response.HeaderOps, } } else { - h.Response.applyTo(w.Header(), repl) + h.Response.ApplyTo(w.Header(), repl) } } @@ -160,11 +153,9 @@ type RespHeaderOps struct { Deferred bool `json:"deferred,omitempty"` } -func (ops *HeaderOps) applyTo(hdr http.Header, repl caddy.Replacer) { - if ops == nil { - return - } - +// ApplyTo applies ops to hdr using repl. +func (ops HeaderOps) ApplyTo(hdr http.Header, repl caddy.Replacer) { + // add for fieldName, vals := range ops.Add { fieldName = repl.ReplaceAll(fieldName, "") for _, v := range vals { @@ -172,22 +163,28 @@ func (ops *HeaderOps) applyTo(hdr http.Header, repl caddy.Replacer) { } } + // set for fieldName, vals := range ops.Set { fieldName = repl.ReplaceAll(fieldName, "") + var newVals []string for i := range vals { - vals[i] = repl.ReplaceAll(vals[i], "") + // append to new slice so we don't overwrite + // the original values in ops.Set + newVals = append(newVals, repl.ReplaceAll(vals[i], "")) } - hdr.Set(fieldName, strings.Join(vals, ",")) + hdr.Set(fieldName, strings.Join(newVals, ",")) } + // delete for _, fieldName := range ops.Delete { hdr.Del(repl.ReplaceAll(fieldName, "")) } + // replace for fieldName, replacements := range ops.Replace { fieldName = repl.ReplaceAll(fieldName, "") - // perform replacements across all fields + // all fields... if fieldName == "*" { for _, r := range replacements { search := repl.ReplaceAll(r.Search, "") @@ -205,7 +202,7 @@ func (ops *HeaderOps) applyTo(hdr http.Header, repl caddy.Replacer) { continue } - // perform replacements only with the named field + // ...or only with the named field for _, r := range replacements { search := repl.ReplaceAll(r.Search, "") replace := repl.ReplaceAll(r.Replace, "") @@ -220,6 +217,42 @@ func (ops *HeaderOps) applyTo(hdr http.Header, repl caddy.Replacer) { } } +// ApplyToRequest applies ops to r, specially handling the Host +// header which the standard library does not include with the +// header map with all the others. This method mutates r.Host. +func (ops HeaderOps) ApplyToRequest(r *http.Request) { + repl := r.Context().Value(caddy.ReplacerCtxKey).(caddy.Replacer) + + // capture the current Host header so we can + // reset to it when we're done + origHost, hadHost := r.Header["Host"] + + // append r.Host; this way, we know that our value + // was last in the list, and if an Add operation + // appended something else after it, that's probably + // fine because it's weird to have multiple Host + // headers anyway and presumably the one they added + // is the one they wanted + r.Header["Host"] = append(r.Header["Host"], r.Host) + + // apply header operations + ops.ApplyTo(r.Header, repl) + + // retrieve the last Host value (likely the one we appended) + if len(r.Header["Host"]) > 0 { + r.Host = r.Header["Host"][len(r.Header["Host"])-1] + } else { + r.Host = "" + } + + // reset the Host header slice + if hadHost { + r.Header["Host"] = origHost + } else { + delete(r.Header, "Host") + } +} + // responseWriterWrapper defers response header // operations until WriteHeader is called. type responseWriterWrapper struct { @@ -236,7 +269,9 @@ func (rww *responseWriterWrapper) WriteHeader(status int) { } rww.wroteHeader = true if rww.require == nil || rww.require.Match(status, rww.ResponseWriterWrapper.Header()) { - rww.headerOps.applyTo(rww.ResponseWriterWrapper.Header(), rww.replacer) + if rww.headerOps != nil { + rww.headerOps.ApplyTo(rww.ResponseWriterWrapper.Header(), rww.replacer) + } } rww.ResponseWriterWrapper.WriteHeader(status) }