diff --git a/caddyhttp/errors/errors.go b/caddyhttp/errors/errors.go index 568bab87..4120780b 100644 --- a/caddyhttp/errors/errors.go +++ b/caddyhttp/errors/errors.go @@ -18,8 +18,10 @@ package errors import ( "fmt" "io" + "mime" "net/http" "os" + "path/filepath" "runtime" "strings" @@ -83,9 +85,13 @@ func (h ErrorHandler) errorPage(w http.ResponseWriter, r *http.Request, code int return } defer errorPage.Close() - + // Get content type by extension + contentType := mime.TypeByExtension(filepath.Ext(pagePath)) + if contentType == "" { + contentType = "text/html; charset=utf-8" + } // Copy the page body into the response - w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.Header().Set("Content-Type", contentType) w.WriteHeader(code) _, err = io.Copy(w, errorPage) diff --git a/caddyhttp/errors/errors_test.go b/caddyhttp/errors/errors_test.go index 0409eaa0..afc43357 100644 --- a/caddyhttp/errors/errors_test.go +++ b/caddyhttp/errors/errors_test.go @@ -167,6 +167,7 @@ func TestVisibleErrorWithPanic(t *testing.T) { func TestGenericErrorPage(t *testing.T) { // create temporary generic error page const genericErrorContent = "This is a generic error page" + const badRequestErrorJSONContent = `{"message":"This is a error json message"}` genericErrorPagePath, err := createErrorPageFile("generic_error_test.html", genericErrorContent) if err != nil { @@ -183,11 +184,18 @@ func TestGenericErrorPage(t *testing.T) { } defer os.Remove(notFoundErrorPagePath) + badRequestErrorJSONPath, err := createErrorPageFile("not_found.json", badRequestErrorJSONContent) + if err != nil { + t.Fatal(err) + } + defer os.Remove(badRequestErrorJSONPath) + buf := bytes.Buffer{} em := ErrorHandler{ GenericErrorPage: genericErrorPagePath, ErrorPages: map[int]string{ http.StatusNotFound: notFoundErrorPagePath, + http.StatusBadRequest:badRequestErrorJSONPath, }, Log: httpserver.NewTestLogger(&buf), } @@ -198,6 +206,7 @@ func TestGenericErrorPage(t *testing.T) { expectedBody string expectedLog string expectedErr error + expectedContentType string }{ { next: genErrorHandler(http.StatusNotFound, nil, ""), @@ -205,6 +214,15 @@ func TestGenericErrorPage(t *testing.T) { expectedBody: notFoundErrorContent, expectedLog: "", expectedErr: nil, + expectedContentType: "text/html; charset=utf-8", + }, + { + next: genErrorHandler(http.StatusBadRequest, nil, ""), + expectedCode: 0, + expectedBody: badRequestErrorJSONContent, + expectedLog: "", + expectedErr: nil, + expectedContentType: "application/json", }, { next: genErrorHandler(http.StatusInternalServerError, nil, ""), @@ -212,6 +230,7 @@ func TestGenericErrorPage(t *testing.T) { expectedBody: genericErrorContent, expectedLog: "", expectedErr: nil, + expectedContentType:"text/html; charset=utf-8", }, } @@ -238,6 +257,10 @@ func TestGenericErrorPage(t *testing.T) { t.Errorf("Test %d: Expected body %q, but got %q", i, test.expectedBody, body) } + if contentType := rec.Header().Get("Content-Type"); contentType != test.expectedContentType{ + t.Errorf("Test %d: Expected Content-Type %s, but got %s", + i, test.expectedContentType, contentType) + } if log := buf.String(); !strings.Contains(log, test.expectedLog) { t.Errorf("Test %d: Expected log %q, but got %q", i, test.expectedLog, log)