diff --git a/caddyhttp/header/header.go b/caddyhttp/header/header.go index 08ba4bd2..121264d4 100644 --- a/caddyhttp/header/header.go +++ b/caddyhttp/header/header.go @@ -21,22 +21,23 @@ type Headers struct { // setting headers on the response according to the configured rules. func (h Headers) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { replacer := httpserver.NewReplacer(r, nil, "") + rww := &responseWriterWrapper{w: w} for _, rule := range h.Rules { if httpserver.Path(r.URL.Path).Matches(rule.Path) { for _, header := range rule.Headers { // One can either delete a header, add multiple values to a header, or simply // set a header. if strings.HasPrefix(header.Name, "-") { - w.Header().Del(strings.TrimLeft(header.Name, "-")) + rww.delHeader(strings.TrimLeft(header.Name, "-")) } else if strings.HasPrefix(header.Name, "+") { - w.Header().Add(strings.TrimLeft(header.Name, "+"), replacer.Replace(header.Value)) + rww.addHeader(strings.TrimLeft(header.Name, "+"), replacer.Replace(header.Value)) } else { - w.Header().Set(header.Name, replacer.Replace(header.Value)) + rww.setHeader(header.Name, replacer.Replace(header.Value)) } } } } - return h.Next.ServeHTTP(w, r) + return h.Next.ServeHTTP(rww, r) } type ( @@ -53,3 +54,62 @@ type ( Value string } ) + +// headerOperation represents an operation on the header +type headerOperation func(http.Header) + +// responseWriterWrapper wraps the real ResponseWriter. +// It defers header operations until writeHeader +type responseWriterWrapper struct { + w http.ResponseWriter + ops []headerOperation + wroteHeader bool +} + +func (rww *responseWriterWrapper) Header() http.Header { + return rww.w.Header() +} + +func (rww *responseWriterWrapper) Write(d []byte) (int, error) { + if !rww.wroteHeader { + rww.WriteHeader(http.StatusOK) + } + return rww.w.Write(d) +} + +func (rww *responseWriterWrapper) WriteHeader(status int) { + if rww.wroteHeader { + return + } + rww.wroteHeader = true + // capture the original headers + h := rww.Header() + + // perform our revisions + for _, op := range rww.ops { + op(h) + } + + rww.w.WriteHeader(status) +} + +// addHeader registers a http.Header.Add operation +func (rww *responseWriterWrapper) addHeader(key, value string) { + rww.ops = append(rww.ops, func(h http.Header) { + h.Add(key, value) + }) +} + +// delHeader registers a http.Header.Del operation +func (rww *responseWriterWrapper) delHeader(key string) { + rww.ops = append(rww.ops, func(h http.Header) { + h.Del(key) + }) +} + +// setHeader registers a http.Header.Set operation +func (rww *responseWriterWrapper) setHeader(key, value string) { + rww.ops = append(rww.ops, func(h http.Header) { + h.Set(key, value) + }) +} diff --git a/caddyhttp/header/header_test.go b/caddyhttp/header/header_test.go index 787c4d7a..0e0aaa31 100644 --- a/caddyhttp/header/header_test.go +++ b/caddyhttp/header/header_test.go @@ -1,6 +1,7 @@ package header import ( + "fmt" "net/http" "net/http/httptest" "os" @@ -30,6 +31,8 @@ func TestHeader(t *testing.T) { } { he := Headers{ Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { + w.Header().Set("Bar", "Removed in /a") + fmt.Fprint(w, "This is a test") return 0, nil }), Rules: []Rule{ @@ -47,7 +50,6 @@ func TestHeader(t *testing.T) { } rec := httptest.NewRecorder() - rec.Header().Set("Bar", "Removed in /a") he.ServeHTTP(rec, req) @@ -61,6 +63,7 @@ func TestHeader(t *testing.T) { func TestMultipleHeaders(t *testing.T) { he := Headers{ Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { + fmt.Fprint(w, "This is a test") return 0, nil }), Rules: []Rule{