errors: Fix low risk race condition at server close

See issue #1371 for more information.
This commit is contained in:
Matthew Holt 2017-01-24 19:09:44 -07:00
parent 45a0e4cf49
commit 16250da3f0
No known key found for this signature in database
GPG key ID: 2A349DD577D586A5
4 changed files with 34 additions and 7 deletions

View file

@ -9,6 +9,7 @@ import (
"os"
"runtime"
"strings"
"sync"
"time"
"github.com/mholt/caddy"
@ -30,8 +31,9 @@ type ErrorHandler struct {
LogFile string
Log *log.Logger
LogRoller *httpserver.LogRoller
Debug bool // if true, errors are written out to client rather than to a log
file *os.File // a log file to close when done
Debug bool // if true, errors are written out to client rather than to a log
file *os.File // a log file to close when done
fileMu *sync.RWMutex // like with log middleware, os.File can't "safely" be closed in a different goroutine
}
func (h ErrorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
@ -48,7 +50,9 @@ func (h ErrorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, er
fmt.Fprintln(w, errMsg)
return 0, err // returning 0 signals that a response has been written
}
h.fileMu.RLock()
h.Log.Println(errMsg)
h.fileMu.RUnlock()
}
if status >= 400 {
@ -69,8 +73,10 @@ func (h ErrorHandler) errorPage(w http.ResponseWriter, r *http.Request, code int
errorPage, err := os.Open(pagePath)
if err != nil {
// An additional error handling an error... <insert grumpy cat here>
h.fileMu.RLock()
h.Log.Printf("%s [NOTICE %d %s] could not load error page: %v",
time.Now().Format(timeFormat), code, r.URL.String(), err)
h.fileMu.RUnlock()
httpserver.DefaultErrorFunc(w, r, code)
return
}
@ -83,8 +89,10 @@ func (h ErrorHandler) errorPage(w http.ResponseWriter, r *http.Request, code int
if err != nil {
// Epic fail... sigh.
h.fileMu.RLock()
h.Log.Printf("%s [NOTICE %d %s] could not respond with %s: %v",
time.Now().Format(timeFormat), code, r.URL.String(), pagePath, err)
h.fileMu.RUnlock()
httpserver.DefaultErrorFunc(w, r, code)
}
@ -146,7 +154,9 @@ func (h ErrorHandler) recovery(w http.ResponseWriter, r *http.Request) {
httpserver.WriteTextResponse(w, http.StatusInternalServerError, fmt.Sprintf("%s\n\n%s", panicMsg, stack))
} else {
// Currently we don't use the function name, since file:line is more conventional
h.fileMu.RLock()
h.Log.Printf(panicMsg)
h.fileMu.RUnlock()
h.errorPage(w, r, http.StatusInternalServerError)
}
}

View file

@ -11,6 +11,7 @@ import (
"path/filepath"
"strconv"
"strings"
"sync"
"testing"
"github.com/mholt/caddy/caddyhttp/httpserver"
@ -32,7 +33,8 @@ func TestErrors(t *testing.T) {
http.StatusNotFound: path,
http.StatusForbidden: "not_exist_file",
},
Log: log.New(&buf, "", 0),
Log: log.New(&buf, "", 0),
fileMu: new(sync.RWMutex),
}
_, notExistErr := os.Open("not_exist_file")
@ -121,6 +123,7 @@ func TestVisibleErrorWithPanic(t *testing.T) {
Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
panic(panicMsg)
}),
fileMu: new(sync.RWMutex),
}
req, err := http.NewRequest("GET", "/", nil)
@ -176,7 +179,8 @@ func TestGenericErrorPage(t *testing.T) {
ErrorPages: map[int]string{
http.StatusNotFound: notFoundErrorPagePath,
},
Log: log.New(&buf, "", 0),
Log: log.New(&buf, "", 0),
fileMu: new(sync.RWMutex),
}
tests := []struct {

View file

@ -6,6 +6,7 @@ import (
"os"
"path/filepath"
"strconv"
"sync"
"github.com/hashicorp/go-syslog"
"github.com/mholt/caddy"
@ -64,7 +65,9 @@ func setup(c *caddy.Controller) error {
// When server stops, close any open log file
c.OnShutdown(func() error {
if handler.file != nil {
handler.fileMu.Lock()
handler.file.Close()
handler.fileMu.Unlock()
}
return nil
})
@ -81,7 +84,7 @@ func errorsParse(c *caddy.Controller) (*ErrorHandler, error) {
// Very important that we make a pointer because the startup
// function that opens the log file must have access to the
// same instance of the handler, not a copy.
handler := &ErrorHandler{ErrorPages: make(map[int]string)}
handler := &ErrorHandler{ErrorPages: make(map[int]string), fileMu: new(sync.RWMutex)}
cfg := httpserver.GetConfig(c)

View file

@ -3,6 +3,7 @@ package errors
import (
"path/filepath"
"reflect"
"sync"
"testing"
"github.com/mholt/caddy"
@ -58,18 +59,22 @@ func TestErrorsParse(t *testing.T) {
}{
{`errors`, false, ErrorHandler{
ErrorPages: map[int]string{},
fileMu: new(sync.RWMutex),
}},
{`errors errors.txt`, false, ErrorHandler{
ErrorPages: map[int]string{},
LogFile: "errors.txt",
fileMu: new(sync.RWMutex),
}},
{`errors visible`, false, ErrorHandler{
ErrorPages: map[int]string{},
Debug: true,
fileMu: new(sync.RWMutex),
}},
{`errors { log visible }`, false, ErrorHandler{
ErrorPages: map[int]string{},
Debug: true,
fileMu: new(sync.RWMutex),
}},
{`errors { log errors.txt
404 404.html
@ -80,6 +85,7 @@ func TestErrorsParse(t *testing.T) {
404: "404.html",
500: "500.html",
},
fileMu: new(sync.RWMutex),
}},
{`errors { log errors.txt { size 2 age 10 keep 3 } }`, false, ErrorHandler{
LogFile: "errors.txt",
@ -90,6 +96,7 @@ func TestErrorsParse(t *testing.T) {
LocalTime: true,
},
ErrorPages: map[int]string{},
fileMu: new(sync.RWMutex),
}},
{`errors { log errors.txt {
size 3
@ -110,6 +117,7 @@ func TestErrorsParse(t *testing.T) {
MaxBackups: 5,
LocalTime: true,
},
fileMu: new(sync.RWMutex),
}},
{`errors { log errors.txt
* generic_error.html
@ -122,6 +130,7 @@ func TestErrorsParse(t *testing.T) {
404: "404.html",
503: "503.html",
},
fileMu: new(sync.RWMutex),
}},
// test absolute file path
{`errors {
@ -131,16 +140,17 @@ func TestErrorsParse(t *testing.T) {
ErrorPages: map[int]string{
404: testAbs,
},
fileMu: new(sync.RWMutex),
}},
// Next two test cases is the detection of duplicate status codes
{`errors {
503 503.html
503 503.html
}`, true, ErrorHandler{ErrorPages: map[int]string{}}},
}`, true, ErrorHandler{ErrorPages: map[int]string{}, fileMu: new(sync.RWMutex)}},
{`errors {
* generic_error.html
* generic_error.html
}`, true, ErrorHandler{ErrorPages: map[int]string{}}},
}`, true, ErrorHandler{ErrorPages: map[int]string{}, fileMu: new(sync.RWMutex)}},
}
for i, test := range tests {
actualErrorsRule, err := errorsParse(caddy.NewTestController("http", test.inputErrorsRules))