Fix deferred header ops

This commit is contained in:
Matthew Holt 2019-05-20 22:00:54 -06:00
parent a969872850
commit b84cb05848
2 changed files with 24 additions and 10 deletions

View file

@ -17,16 +17,16 @@ func init() {
// Headers is a middleware which can mutate HTTP headers. // Headers is a middleware which can mutate HTTP headers.
type Headers struct { type Headers struct {
Request HeaderOps Request HeaderOps `json:"request"`
Response RespHeaderOps Response RespHeaderOps `json:"response"`
} }
// HeaderOps defines some operations to // HeaderOps defines some operations to
// perform on HTTP headers. // perform on HTTP headers.
type HeaderOps struct { type HeaderOps struct {
Add http.Header Add http.Header `json:"add"`
Set http.Header Set http.Header `json:"set"`
Delete []string Delete []string `json:"delete"`
} }
// RespHeaderOps is like HeaderOps, but // RespHeaderOps is like HeaderOps, but
@ -67,10 +67,22 @@ func apply(ops HeaderOps, hdr http.Header) {
// operations until WriteHeader is called. // operations until WriteHeader is called.
type responseWriterWrapper struct { type responseWriterWrapper struct {
*caddyhttp.ResponseWriterWrapper *caddyhttp.ResponseWriterWrapper
headerOps HeaderOps headerOps HeaderOps
wroteHeader bool
}
func (rww *responseWriterWrapper) Write(d []byte) (int, error) {
if !rww.wroteHeader {
rww.WriteHeader(http.StatusOK)
}
return rww.ResponseWriterWrapper.Write(d)
} }
func (rww *responseWriterWrapper) WriteHeader(status int) { func (rww *responseWriterWrapper) WriteHeader(status int) {
if rww.wroteHeader {
return
}
rww.wroteHeader = true
apply(rww.headerOps, rww.ResponseWriterWrapper.Header()) apply(rww.headerOps, rww.ResponseWriterWrapper.Header())
rww.ResponseWriterWrapper.WriteHeader(status) rww.ResponseWriterWrapper.WriteHeader(status)
} }

View file

@ -17,7 +17,7 @@ func init() {
// Static implements a simple responder for static responses. // Static implements a simple responder for static responses.
type Static struct { type Static struct {
StatusCode int `json:"status_code"` StatusCode int `json:"status_code"` // TODO: should we turn this into a string so that only one field is needed? (string allows replacements)
StatusCodeStr string `json:"status_code_str"` StatusCodeStr string `json:"status_code_str"`
Headers http.Header `json:"headers"` Headers http.Header `json:"headers"`
Body string `json:"body"` Body string `json:"body"`
@ -30,7 +30,7 @@ func (s Static) ServeHTTP(w http.ResponseWriter, r *http.Request) error {
// close the connection after responding // close the connection after responding
r.Close = s.Close r.Close = s.Close
// set all headers, with replacements // set all headers
for field, vals := range s.Headers { for field, vals := range s.Headers {
field = repl.ReplaceAll(field, "") field = repl.ReplaceAll(field, "")
for i := range vals { for i := range vals {
@ -39,7 +39,7 @@ func (s Static) ServeHTTP(w http.ResponseWriter, r *http.Request) error {
w.Header()[field] = vals w.Header()[field] = vals
} }
// write the headers with a status code // get the status code
statusCode := s.StatusCode statusCode := s.StatusCode
if statusCode == 0 && s.StatusCodeStr != "" { if statusCode == 0 && s.StatusCodeStr != "" {
intVal, err := strconv.Atoi(repl.ReplaceAll(s.StatusCodeStr, "")) intVal, err := strconv.Atoi(repl.ReplaceAll(s.StatusCodeStr, ""))
@ -50,9 +50,11 @@ func (s Static) ServeHTTP(w http.ResponseWriter, r *http.Request) error {
if statusCode == 0 { if statusCode == 0 {
statusCode = http.StatusOK statusCode = http.StatusOK
} }
// write headers
w.WriteHeader(statusCode) w.WriteHeader(statusCode)
// write the response body, with replacements // write response body
if s.Body != "" { if s.Body != "" {
fmt.Fprint(w, repl.ReplaceAll(s.Body, "")) fmt.Fprint(w, repl.ReplaceAll(s.Body, ""))
} }