Replace our old faithful gracefulListener with Go 1.8's Shutdown()

This commit is contained in:
Matthew Holt 2017-01-24 20:05:53 -07:00
parent 16250da3f0
commit 139a3cfb13
No known key found for this signature in database
GPG key ID: 2A349DD577D586A5
2 changed files with 11 additions and 130 deletions

View file

@ -1,80 +0,0 @@
package httpserver
import (
"net"
"sync"
"syscall"
)
// TODO: Should this be a generic graceful listener available in its own package or something?
// Also, passing in a WaitGroup is a little awkward. Why can't this listener just keep
// the waitgroup internal to itself?
// newGracefulListener returns a gracefulListener that wraps l and
// uses wg (stored in the host server) to count connections.
func newGracefulListener(l net.Listener, wg *sync.WaitGroup) *gracefulListener {
gl := &gracefulListener{Listener: l, stop: make(chan error), connWg: wg}
go func() {
<-gl.stop
gl.Lock()
gl.stopped = true
gl.Unlock()
gl.stop <- gl.Listener.Close()
}()
return gl
}
// gracefuListener is a net.Listener which can
// count the number of connections on it. Its
// methods mainly wrap net.Listener to be graceful.
type gracefulListener struct {
net.Listener
stop chan error
stopped bool
sync.Mutex // protects the stopped flag
connWg *sync.WaitGroup // pointer to the host's wg used for counting connections
}
// Accept accepts a connection.
func (gl *gracefulListener) Accept() (c net.Conn, err error) {
c, err = gl.Listener.Accept()
if err != nil {
return
}
c = gracefulConn{Conn: c, connWg: gl.connWg}
gl.connWg.Add(1)
return
}
// Close immediately closes the listener.
func (gl *gracefulListener) Close() error {
gl.Lock()
if gl.stopped {
gl.Unlock()
return syscall.EINVAL
}
gl.Unlock()
gl.stop <- nil
return <-gl.stop
}
// gracefulConn represents a connection on a
// gracefulListener so that we can keep track
// of the number of connections, thus facilitating
// a graceful shutdown.
type gracefulConn struct {
net.Conn
connWg *sync.WaitGroup // pointer to the host server's connection waitgroup
}
// Close closes c's underlying connection while updating the wg count.
func (c gracefulConn) Close() error {
err := c.Conn.Close()
if err != nil {
return err
}
// close can fail on http2 connections (as of Oct. 2015, before http2 in std lib)
// so don't decrement count unless close succeeds
c.connWg.Done()
return nil
}

View file

@ -2,6 +2,7 @@
package httpserver package httpserver
import ( import (
"context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"io" "io"
@ -27,9 +28,8 @@ type Server struct {
listener net.Listener listener net.Listener
listenerMu sync.Mutex listenerMu sync.Mutex
sites []*SiteConfig sites []*SiteConfig
connTimeout time.Duration // max time to wait for a connection before force stop connTimeout time.Duration // max time to wait for a connection before force stop
connWg sync.WaitGroup // one increment per connection tlsGovChan chan struct{} // close to stop the TLS maintenance goroutine
tlsGovChan chan struct{} // close to stop the TLS maintenance goroutine
vhosts *vhostTrie vhosts *vhostTrie
} }
@ -46,16 +46,6 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) {
connTimeout: GracefulTimeout, connTimeout: GracefulTimeout,
} }
s.Server.Handler = s // this is weird, but whatever s.Server.Handler = s // this is weird, but whatever
s.Server.ConnState = func(c net.Conn, cs http.ConnState) {
if cs == http.StateIdle {
s.listenerMu.Lock()
// server stopped, close idle connection
if s.listener == nil {
c.Close()
}
s.listenerMu.Unlock()
}
}
// Disable HTTP/2 if desired // Disable HTTP/2 if desired
if !HTTP2 { if !HTTP2 {
@ -68,14 +58,6 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) {
s.Server.Handler = s.wrapWithSvcHeaders(s.Server.Handler) s.Server.Handler = s.wrapWithSvcHeaders(s.Server.Handler)
} }
// We have to bound our wg with one increment
// to prevent a "race condition" that is hard-coded
// into sync.WaitGroup.Wait() - basically, an add
// with a positive delta must be guaranteed to
// occur before Wait() is called on the wg.
// In a way, this kind of acts as a safety barrier.
s.connWg.Add(1)
// Set up TLS configuration // Set up TLS configuration
var tlsConfigs []*caddytls.Config var tlsConfigs []*caddytls.Config
for _, site := range group { for _, site := range group {
@ -154,8 +136,6 @@ func (s *Server) Serve(ln net.Listener) error {
ln = tcpKeepAliveListener{TCPListener: tcpLn} ln = tcpKeepAliveListener{TCPListener: tcpLn}
} }
ln = newGracefulListener(ln, &s.connWg)
s.listenerMu.Lock() s.listenerMu.Lock()
s.listener = ln s.listener = ln
s.listenerMu.Unlock() s.listenerMu.Unlock()
@ -300,40 +280,21 @@ func (s *Server) Address() string {
// Stop stops s gracefully (or forcefully after timeout) and // Stop stops s gracefully (or forcefully after timeout) and
// closes its listener. // closes its listener.
func (s *Server) Stop() (err error) { func (s *Server) Stop() error {
s.Server.SetKeepAlivesEnabled(false) ctx, cancel := context.WithTimeout(context.Background(), s.connTimeout)
defer cancel()
if runtime.GOOS != "windows" { err := s.Server.Shutdown(ctx)
// force connections to close after timeout if err != nil {
done := make(chan struct{}) return err
go func() {
s.connWg.Done() // decrement our initial increment used as a barrier
s.connWg.Wait()
close(done)
}()
// Wait for remaining connections to finish or
// force them all to close after timeout
select {
case <-time.After(s.connTimeout):
case <-done:
}
} }
// Close the listener now; this stops the server without delay // signal any TLS governor goroutines to exit
s.listenerMu.Lock()
if s.listener != nil {
err = s.listener.Close()
s.listener = nil
}
s.listenerMu.Unlock()
// Closing this signals any TLS governor goroutines to exit
if s.tlsGovChan != nil { if s.tlsGovChan != nil {
close(s.tlsGovChan) close(s.tlsGovChan)
} }
return return nil
} }
// sanitizePath collapses any ./ ../ /// madness // sanitizePath collapses any ./ ../ /// madness