caddy/caddyhttp/proxy/upstream.go
Angel Santiago 59bf71c293 proxy: Cleanly shutdown health checks on restart (#1524)
* Add a shutdown function and context to staticUpstream so that running goroutines can be cancelled. Add a GetShutdownFunc to Upstream interface to expose the shutdown function to the caddy Controller for performing it on restarts.

* Make fakeUpstream implement new Upstream methods.

Implement new Upstream method for fakeWSUpstream as well.

* Rename GetShutdownFunc to Stop(). Add a waitgroup to the staticUpstream for controlling individual object's goroutines. Add the Stop function to OnRestart and OnShutdown. Add tests for checking to see if healthchecks continue hitting a backend server after stop has been called.

* Go back to using a stop channel since the context adds no additional benefit.
Only register stop function for onShutdown since it's called as part of restart.

* Remove assignment to atomic value

* Incrementing WaitGroup outside of goroutine to avoid race condition. Loading atomic values in test.

* Linting: change counter to just use the default zero value instead of setting it

* Clarify Stop method comments, add comments to stop channel and waitgroup and remove out of date comment about handling stopping the proxy. Stop the ticker when the stop signal is sent
2017-04-02 14:58:15 -06:00

456 lines
10 KiB
Go

package proxy
import (
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"path"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/mholt/caddy/caddyfile"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
// supportedPolicies maps the name given to the "policy" property of a proxy
// block to a constructor for the corresponding Policy. RegisterPolicy adds
// entries; parseBlock looks them up.
var supportedPolicies = map[string]func() Policy{}
// staticUpstream is the upstream built by NewStaticUpstreams from one proxy
// block of the configuration: a fixed pool of hosts together with the
// settings parsed for that block.
type staticUpstream struct {
// from is the first argument of the proxy directive. From returns it and
// AllowedPath joins it with every ignored sub-path.
from string
// upstreamHeaders and downstreamHeaders collect the headers configured with
// header_upstream / header_downstream (and, for upstreamHeaders, the
// transparent and websocket presets). NewHost hands them to every host.
upstreamHeaders http.Header
downstreamHeaders http.Header
stop chan struct{} // Signals running goroutines to stop.
wg sync.WaitGroup // Used to wait for running goroutines to stop.
// Hosts is the pool of backends created by NewHost for each parsed address.
Hosts HostPool
// Policy picks a host in Select; NewStaticUpstreams defaults it to &Random{}.
Policy Policy
// KeepAlive is passed to NewSingleHostReverseProxy for every host; it
// defaults to http.DefaultMaxIdleConnsPerHost.
KeepAlive int
// FailTimeout is copied into every UpstreamHost created by NewHost.
FailTimeout time.Duration
// TryDuration and TryInterval are exposed through GetTryDuration and
// GetTryInterval. TryInterval defaults to 250ms; TryDuration to zero.
TryDuration time.Duration
TryInterval time.Duration
// MaxConns is copied into every UpstreamHost created by NewHost; the
// default is 0.
MaxConns int64
// HealthCheck holds the settings of the periodic health check. The worker
// is only started when Path is non-empty.
HealthCheck struct {
// Client performs the health check requests; it is created with Timeout.
Client http.Client
// Path is appended to each host's name to form the checked URL.
Path string
// Interval is the period of the health check ticker (default 30s).
Interval time.Duration
// Timeout is the HTTP client timeout for one check (default 60s).
Timeout time.Duration
}
// WithoutPathPrefix is set by the "without" property and passed to each
// host and its reverse proxy.
WithoutPathPrefix string
// IgnoredSubPaths is set by the "except" property; AllowedPath rejects
// request paths matching from joined with any of these.
IgnoredSubPaths []string
// insecureSkipVerify, set by "insecure_skip_verify", makes NewHost call
// UseInsecureTransport on the host's reverse proxy.
insecureSkipVerify bool
// MaxFails is the failure count at which a host's CheckDown reports it as
// down; it defaults to 1 and parseBlock rejects values below 1.
MaxFails int32
}
// NewStaticUpstreams parses the configuration input and sets up
// static upstreams for the proxy middleware.
//
// One staticUpstream is produced per proxy directive found in c. Backend
// addresses come from the directive's remaining arguments and from any
// "upstream" property inside its block; every other property is handled by
// parseBlock. When a health check path is configured, a worker goroutine is
// started for the upstream and tracked by its WaitGroup.
//
// On failure the upstreams completed so far are returned along with the error.
func NewStaticUpstreams(c caddyfile.Dispenser) ([]Upstream, error) {
	var result []Upstream

	for c.Next() {
		// Fields not listed here start at their zero value.
		su := &staticUpstream{
			stop:              make(chan struct{}),
			upstreamHeaders:   make(http.Header),
			downstreamHeaders: make(http.Header),
			Policy:            &Random{},
			MaxFails:          1,
			TryInterval:       250 * time.Millisecond,
			KeepAlive:         http.DefaultMaxIdleConnsPerHost,
		}

		if !c.Args(&su.from) {
			return result, c.ArgErr()
		}

		// Addresses given inline after the "from" path.
		var targets []string
		for _, arg := range c.RemainingArgs() {
			expanded, err := parseUpstream(arg)
			if err != nil {
				return result, err
			}
			targets = append(targets, expanded...)
		}

		// Properties inside the block; "upstream" adds further addresses.
		for c.NextBlock() {
			if c.Val() != "upstream" {
				if err := parseBlock(&c, su); err != nil {
					return result, err
				}
				continue
			}
			if !c.NextArg() {
				return result, c.ArgErr()
			}
			expanded, err := parseUpstream(c.Val())
			if err != nil {
				return result, err
			}
			targets = append(targets, expanded...)
		}

		if len(targets) == 0 {
			return result, c.ArgErr()
		}

		pool := make([]*UpstreamHost, 0, len(targets))
		for _, target := range targets {
			uh, err := su.NewHost(target)
			if err != nil {
				return result, err
			}
			pool = append(pool, uh)
		}
		su.Hosts = pool

		if su.HealthCheck.Path != "" {
			su.HealthCheck.Client = http.Client{Timeout: su.HealthCheck.Timeout}
			// Add before starting the goroutine so Stop cannot miss it.
			su.wg.Add(1)
			go func(worker *staticUpstream) {
				defer worker.wg.Done()
				worker.HealthCheckWorker(worker.stop)
			}(su)
		}

		result = append(result, su)
	}

	return result, nil
}
// From returns u.from, the first argument of the proxy directive this
// upstream was configured with.
func (u *staticUpstream) From() string {
return u.from
}
// NewHost creates an UpstreamHost for the backend address host using the
// settings currently stored on u (headers, fail timeout, max connections,
// path prefix to strip, keep-alive and TLS verification mode).
//
// An address without a scheme is treated as plain HTTP: unless host starts
// with "http://", "https://" or "unix:", "http://" is prepended.
//
// It returns an error when the resulting address cannot be parsed as a URL.
func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
	// Require the complete scheme separator. Matching on the bare "http"
	// prefix would mistake host names such as "httpd:8080" or "httpbin.org"
	// for addresses that already carry a scheme and leave them unusable.
	hasScheme := strings.HasPrefix(host, "http://") ||
		strings.HasPrefix(host, "https://") ||
		strings.HasPrefix(host, "unix:")
	if !hasScheme {
		host = "http://" + host
	}

	uh := &UpstreamHost{
		Name:              host,
		Conns:             0,
		Fails:             0,
		FailTimeout:       u.FailTimeout,
		Unhealthy:         0,
		UpstreamHeaders:   u.upstreamHeaders,
		DownstreamHeaders: u.downstreamHeaders,
		// A host is down when the health check flagged it or when it has
		// accumulated at least MaxFails failures. u.MaxFails is read at call
		// time, not captured at construction.
		CheckDown: func(uh *UpstreamHost) bool {
			if atomic.LoadInt32(&uh.Unhealthy) != 0 {
				return true
			}
			return atomic.LoadInt32(&uh.Fails) >= u.MaxFails
		},
		WithoutPathPrefix: u.WithoutPathPrefix,
		MaxConns:          u.MaxConns,
	}

	baseURL, err := url.Parse(uh.Name)
	if err != nil {
		return nil, err
	}

	uh.ReverseProxy = NewSingleHostReverseProxy(baseURL, uh.WithoutPathPrefix, u.KeepAlive)
	if u.insecureSkipVerify {
		uh.ReverseProxy.UseInsecureTransport()
	}

	return uh, nil
}
// parseUpstream expands one upstream address into the list of addresses it
// stands for.
//
// An address whose port part has the form "first-last" (exactly one dash) is
// expanded into one address per port in that inclusive range, keeping
// whatever precedes the port and whatever follows it from the next slash on.
// Every other address, including any that starts with "unix:", is returned
// unchanged as a single-element slice.
//
// An error is returned when either bound is not an integer or when last is
// not strictly greater than first.
func parseUpstream(u string) ([]string, error) {
	if strings.HasPrefix(u, "unix:") {
		return []string{u}, nil
	}

	// The last colon separates the port, unless it is the colon of "://".
	lastColon := strings.LastIndex(u, ":")
	if lastColon == -1 || lastColon == strings.Index(u, "://") {
		return []string{u}, nil
	}

	prefix := u[:lastColon]
	portSpec := u[lastColon+1:]
	suffix := ""
	if slash := strings.Index(portSpec, "/"); slash != -1 {
		portSpec, suffix = portSpec[:slash], portSpec[slash:]
	}

	if strings.Count(portSpec, "-") != 1 {
		return []string{u}, nil
	}

	dash := strings.Index(portSpec, "-")
	first, err := strconv.Atoi(portSpec[:dash])
	if err != nil {
		return nil, err
	}
	last, err := strconv.Atoi(portSpec[dash+1:])
	if err != nil {
		return nil, err
	}
	if last <= first {
		return nil, fmt.Errorf("port range [%s] is invalid", portSpec)
	}

	hosts := []string{}
	for port := first; port <= last; port++ {
		hosts = append(hosts, fmt.Sprintf("%s:%d%s", prefix, port, suffix))
	}
	return hosts, nil
}
// parseBlock applies the proxy block property the dispenser is currently
// positioned on to u, consuming the property's arguments from c.
//
// Recognised properties: policy, fail_timeout, max_fails, try_duration,
// try_interval, max_conns, health_check, health_check_interval,
// health_check_timeout, header_upstream, header_downstream, transparent,
// websocket, without, except, insecure_skip_verify and keepalive.
//
// It returns an error for an unknown property, a missing argument or a value
// that cannot be parsed or is out of range.
func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error {
	switch c.Val() {
	case "policy":
		if !c.NextArg() {
			return c.ArgErr()
		}
		policyCreateFunc, ok := supportedPolicies[c.Val()]
		if !ok {
			return c.ArgErr()
		}
		u.Policy = policyCreateFunc()
	case "fail_timeout":
		if !c.NextArg() {
			return c.ArgErr()
		}
		dur, err := time.ParseDuration(c.Val())
		if err != nil {
			return err
		}
		u.FailTimeout = dur
	case "max_fails":
		if !c.NextArg() {
			return c.ArgErr()
		}
		n, err := strconv.Atoi(c.Val())
		if err != nil {
			return err
		}
		if n < 1 {
			return c.Err("max_fails must be at least 1")
		}
		u.MaxFails = int32(n)
	case "try_duration":
		if !c.NextArg() {
			return c.ArgErr()
		}
		dur, err := time.ParseDuration(c.Val())
		if err != nil {
			return err
		}
		u.TryDuration = dur
	case "try_interval":
		if !c.NextArg() {
			return c.ArgErr()
		}
		interval, err := time.ParseDuration(c.Val())
		if err != nil {
			return err
		}
		u.TryInterval = interval
	case "max_conns":
		if !c.NextArg() {
			return c.ArgErr()
		}
		n, err := strconv.ParseInt(c.Val(), 10, 64)
		if err != nil {
			return err
		}
		u.MaxConns = n
	case "health_check":
		if !c.NextArg() {
			return c.ArgErr()
		}
		u.HealthCheck.Path = c.Val()
		// Set defaults, unless an explicit value was already parsed.
		if u.HealthCheck.Interval == 0 {
			u.HealthCheck.Interval = 30 * time.Second
		}
		if u.HealthCheck.Timeout == 0 {
			u.HealthCheck.Timeout = 60 * time.Second
		}
	case "health_check_interval":
		var interval string
		if !c.Args(&interval) {
			return c.ArgErr()
		}
		dur, err := time.ParseDuration(interval)
		if err != nil {
			return err
		}
		// The interval drives the health check ticker, and time.NewTicker
		// panics on a non-positive duration. Reject such values here so a bad
		// configuration yields an error instead of crashing the worker.
		if dur <= 0 {
			return c.Err("health_check_interval must be greater than zero")
		}
		u.HealthCheck.Interval = dur
	case "health_check_timeout":
		var interval string
		if !c.Args(&interval) {
			return c.ArgErr()
		}
		dur, err := time.ParseDuration(interval)
		if err != nil {
			return err
		}
		u.HealthCheck.Timeout = dur
	case "header_upstream":
		var header, value string
		if !c.Args(&header, &value) {
			// When removing a header, the value can be optional.
			if !strings.HasPrefix(header, "-") {
				return c.ArgErr()
			}
		}
		u.upstreamHeaders.Add(header, value)
	case "header_downstream":
		var header, value string
		if !c.Args(&header, &value) {
			// When removing a header, the value can be optional.
			if !strings.HasPrefix(header, "-") {
				return c.ArgErr()
			}
		}
		u.downstreamHeaders.Add(header, value)
	case "transparent":
		// Preset: forward the original host, client address and scheme.
		u.upstreamHeaders.Add("Host", "{host}")
		u.upstreamHeaders.Add("X-Real-IP", "{remote}")
		u.upstreamHeaders.Add("X-Forwarded-For", "{remote}")
		u.upstreamHeaders.Add("X-Forwarded-Proto", "{scheme}")
	case "websocket":
		// Preset: pass the client's Connection and Upgrade headers upstream.
		u.upstreamHeaders.Add("Connection", "{>Connection}")
		u.upstreamHeaders.Add("Upgrade", "{>Upgrade}")
	case "without":
		if !c.NextArg() {
			return c.ArgErr()
		}
		u.WithoutPathPrefix = c.Val()
	case "except":
		ignoredPaths := c.RemainingArgs()
		if len(ignoredPaths) == 0 {
			return c.ArgErr()
		}
		u.IgnoredSubPaths = ignoredPaths
	case "insecure_skip_verify":
		u.insecureSkipVerify = true
	case "keepalive":
		if !c.NextArg() {
			return c.ArgErr()
		}
		n, err := strconv.Atoi(c.Val())
		if err != nil {
			return err
		}
		if n < 0 {
			return c.ArgErr()
		}
		u.KeepAlive = n
	default:
		return c.Errf("unknown property '%s'", c.Val())
	}
	return nil
}
// healthCheck performs one round of health checks: for every host it issues a
// GET to the host's name followed by the configured health check path and
// atomically records the outcome in the host's Unhealthy flag. A host is
// healthy only when the request succeeds with a status code in [200, 400).
func (u *staticUpstream) healthCheck() {
	for _, host := range u.Hosts {
		// Assume the worst; cleared only by a successful 2xx/3xx response.
		var flag int32 = 1

		resp, err := u.HealthCheck.Client.Get(host.Name + u.HealthCheck.Path)
		if err == nil {
			// Drain and close the body; its content is irrelevant here.
			io.Copy(ioutil.Discard, resp.Body)
			resp.Body.Close()
			if resp.StatusCode >= 200 && resp.StatusCode < 400 {
				flag = 0
			}
		}

		atomic.StoreInt32(&host.Unhealthy, flag)
	}
}
// HealthCheckWorker runs a health check immediately and then once per
// u.HealthCheck.Interval until the stop channel is closed or receives a
// value, at which point the ticker is stopped and the worker returns.
func (u *staticUpstream) HealthCheckWorker(stop chan struct{}) {
	ticker := time.NewTicker(u.HealthCheck.Interval)
	defer ticker.Stop()

	for {
		u.healthCheck()

		// Block until the next tick or until asked to stop.
		select {
		case <-stop:
			return
		case <-ticker.C:
		}
	}
}
// Select picks the host that should serve r. It returns nil when no host in
// the pool is available. A pool of exactly one host bypasses the policy;
// otherwise the choice is delegated to u.Policy, falling back to a Random
// policy when none is set.
func (u *staticUpstream) Select(r *http.Request) *UpstreamHost {
	hosts := u.Hosts

	// Single host: nothing to choose between.
	if len(hosts) == 1 {
		if only := hosts[0]; only.Available() {
			return only
		}
		return nil
	}

	anyAvailable := false
	for i := range hosts {
		if hosts[i].Available() {
			anyAvailable = true
			break
		}
	}
	if !anyAvailable {
		return nil
	}

	policy := u.Policy
	if policy == nil {
		policy = &Random{}
	}
	return policy.Select(hosts, r)
}
// AllowedPath reports whether requestPath may be proxied by this upstream.
// It returns false when the cleaned request path matches the upstream's base
// path joined with any of the ignored sub-paths, and true otherwise.
func (u *staticUpstream) AllowedPath(requestPath string) bool {
	cleaned := httpserver.Path(path.Clean(requestPath))
	for _, sub := range u.IgnoredSubPaths {
		excluded := path.Join(u.From(), sub)
		if cleaned.Matches(excluded) {
			return false
		}
	}
	return true
}
// GetTryDuration returns u.TryDuration, the value set by the try_duration
// property. It is zero when the property was not given.
func (u *staticUpstream) GetTryDuration() time.Duration {
return u.TryDuration
}
// GetTryInterval returns u.TryInterval, the value set by the try_interval
// property. NewStaticUpstreams initialises it to 250ms.
func (u *staticUpstream) GetTryInterval() time.Duration {
return u.TryInterval
}
// GetHostCount returns the number of hosts in u.Hosts, regardless of whether
// they are currently available.
func (u *staticUpstream) GetHostCount() int {
return len(u.Hosts)
}
// Stop sends a signal to all goroutines started by this staticUpstream to exit
// and waits for them to finish before returning. It does so by closing u.stop
// and then blocking on u.wg. The returned error is always nil.
//
// NOTE(review): u.stop is closed unconditionally, so a second call would
// panic on the already-closed channel; callers presumably invoke Stop at most
// once per upstream — confirm.
func (u *staticUpstream) Stop() error {
// Closing the channel is observed by every goroutine selecting on it.
close(u.stop)
u.wg.Wait()
return nil
}
// RegisterPolicy adds a custom policy to the proxy. The policy function is
// stored in supportedPolicies under name, replacing any previous entry with
// the same name, and is invoked by parseBlock when the "policy" property
// selects that name.
//
// NOTE(review): the map is written without any locking, so this is presumably
// meant to be called during package initialisation, before any configuration
// is parsed — confirm.
func RegisterPolicy(name string, policy func() Policy) {
supportedPolicies[name] = policy
}