Fix deferred header ops

This commit is contained in:
Matthew Holt 2019-05-20 22:00:54 -06:00
parent a969872850
commit b84cb05848
2 changed files with 24 additions and 10 deletions

View file

@ -17,16 +17,16 @@ func init() {
// Headers is a middleware which can mutate HTTP headers.
type Headers struct {
Request HeaderOps
Response RespHeaderOps
Request HeaderOps `json:"request"`
Response RespHeaderOps `json:"response"`
}
// HeaderOps defines some operations to
// perform on HTTP headers.
type HeaderOps struct {
Add http.Header
Set http.Header
Delete []string
Add http.Header `json:"add"`
Set http.Header `json:"set"`
Delete []string `json:"delete"`
}
// RespHeaderOps is like HeaderOps, but
@ -67,10 +67,22 @@ func apply(ops HeaderOps, hdr http.Header) {
// operations until WriteHeader is called.
type responseWriterWrapper struct {
*caddyhttp.ResponseWriterWrapper
headerOps HeaderOps
headerOps HeaderOps
wroteHeader bool
}
func (rww *responseWriterWrapper) Write(d []byte) (int, error) {
if !rww.wroteHeader {
rww.WriteHeader(http.StatusOK)
}
return rww.ResponseWriterWrapper.Write(d)
}
func (rww *responseWriterWrapper) WriteHeader(status int) {
if rww.wroteHeader {
return
}
rww.wroteHeader = true
apply(rww.headerOps, rww.ResponseWriterWrapper.Header())
rww.ResponseWriterWrapper.WriteHeader(status)
}

View file

@ -17,7 +17,7 @@ func init() {
// Static implements a simple responder for static responses.
type Static struct {
StatusCode int `json:"status_code"`
StatusCode int `json:"status_code"` // TODO: should we turn this into a string so that only one field is needed? (string allows replacements)
StatusCodeStr string `json:"status_code_str"`
Headers http.Header `json:"headers"`
Body string `json:"body"`
@ -30,7 +30,7 @@ func (s Static) ServeHTTP(w http.ResponseWriter, r *http.Request) error {
// close the connection after responding
r.Close = s.Close
// set all headers, with replacements
// set all headers
for field, vals := range s.Headers {
field = repl.ReplaceAll(field, "")
for i := range vals {
@ -39,7 +39,7 @@ func (s Static) ServeHTTP(w http.ResponseWriter, r *http.Request) error {
w.Header()[field] = vals
}
// write the headers with a status code
// get the status code
statusCode := s.StatusCode
if statusCode == 0 && s.StatusCodeStr != "" {
intVal, err := strconv.Atoi(repl.ReplaceAll(s.StatusCodeStr, ""))
@ -50,9 +50,11 @@ func (s Static) ServeHTTP(w http.ResponseWriter, r *http.Request) error {
if statusCode == 0 {
statusCode = http.StatusOK
}
// write headers
w.WriteHeader(statusCode)
// write the response body, with replacements
// write response body
if s.Body != "" {
fmt.Fprint(w, repl.ReplaceAll(s.Body, ""))
}