diff --git a/middleware/proxy/policy.go b/middleware/proxy/policy.go index a2522bcb1..96e382a5c 100644 --- a/middleware/proxy/policy.go +++ b/middleware/proxy/policy.go @@ -25,11 +25,11 @@ type Random struct{} // Select selects an up host at random from the specified pool. func (r *Random) Select(pool HostPool) *UpstreamHost { // instead of just generating a random index - // this is done to prevent selecting a down host + // this is done to prevent selecting a unavailable host var randHost *UpstreamHost count := 0 for _, host := range pool { - if host.Down() { + if !host.Available() { continue } count++ @@ -56,7 +56,7 @@ func (r *LeastConn) Select(pool HostPool) *UpstreamHost { count := 0 leastConn := int64(1<<63 - 1) for _, host := range pool { - if host.Down() { + if !host.Available() { continue } hostConns := host.Conns @@ -90,11 +90,11 @@ func (r *RoundRobin) Select(pool HostPool) *UpstreamHost { poolLen := uint32(len(pool)) selection := atomic.AddUint32(&r.Robin, 1) % poolLen host := pool[selection] - // if the currently selected host is down, just ffwd to up host - for i := uint32(1); host.Down() && i < poolLen; i++ { + // if the currently selected host is not available, just ffwd to up host + for i := uint32(1); !host.Available() && i < poolLen; i++ { host = pool[(selection+i)%poolLen] } - if host.Down() { + if !host.Available() { return nil } return host diff --git a/middleware/proxy/policy_test.go b/middleware/proxy/policy_test.go index 8f4f1f792..4cc05f029 100644 --- a/middleware/proxy/policy_test.go +++ b/middleware/proxy/policy_test.go @@ -53,12 +53,23 @@ func TestRoundRobinPolicy(t *testing.T) { if h != pool[2] { t.Error("Expected second round robin host to be third host in the pool.") } - // mark host as down - pool[0].Unhealthy = true h = rrPolicy.Select(pool) - if h != pool[1] { + if h != pool[0] { t.Error("Expected third round robin host to be first host in the pool.") } + // mark host as down + pool[1].Unhealthy = true + h = rrPolicy.Select(pool) + if h != pool[2] { + t.Error("Expected to skip down host.") + } + // mark host as full + pool[2].Conns = 1 + pool[2].MaxConns = 1 + h = rrPolicy.Select(pool) + if h != pool[0] { + t.Error("Expected to skip full host.") + } } func TestLeastConnPolicy(t *testing.T) { diff --git a/middleware/proxy/proxy.go b/middleware/proxy/proxy.go index 7be8af2ad..583f4a635 100644 --- a/middleware/proxy/proxy.go +++ b/middleware/proxy/proxy.go @@ -44,6 +44,7 @@ type UpstreamHost struct { ExtraHeaders http.Header CheckDown UpstreamHostDownFunc WithoutPathPrefix string + MaxConns int64 } // Down checks whether the upstream host is down or not. @@ -57,6 +58,16 @@ func (uh *UpstreamHost) Down() bool { return uh.CheckDown(uh) } +// Full checks whether the upstream host has reached its maximum connections +func (uh *UpstreamHost) Full() bool { + return uh.MaxConns > 0 && uh.Conns >= uh.MaxConns +} + +// Available checks whether the upstream host is available for proxying to +func (uh *UpstreamHost) Available() bool { + return !uh.Down() && !uh.Full() +} + // tryDuration is how long to try upstream hosts; failures result in // immediate retries until this duration ends or we get a nil host. var tryDuration = 60 * time.Second diff --git a/middleware/proxy/upstream.go b/middleware/proxy/upstream.go index faa11cd92..29b3d2d8f 100644 --- a/middleware/proxy/upstream.go +++ b/middleware/proxy/upstream.go @@ -27,6 +27,7 @@ type staticUpstream struct { FailTimeout time.Duration MaxFails int32 + MaxConns int64 HealthCheck struct { Path string Interval time.Duration @@ -47,6 +48,7 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { Policy: &Random{}, FailTimeout: 10 * time.Second, MaxFails: 1, + MaxConns: 0, } if !c.Args(&upstream.from) { @@ -65,37 +67,8 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { upstream.Hosts = make([]*UpstreamHost, len(to)) for i, host := range to { - if !strings.HasPrefix(host, "http") && - !strings.HasPrefix(host, "unix:") { - host = "http://" + host - } - uh := &UpstreamHost{ - Name: host, - Conns: 0, - Fails: 0, - FailTimeout: upstream.FailTimeout, - Unhealthy: false, - ExtraHeaders: upstream.proxyHeaders, - CheckDown: func(upstream *staticUpstream) UpstreamHostDownFunc { - return func(uh *UpstreamHost) bool { - if uh.Unhealthy { - return true - } - if uh.Fails >= upstream.MaxFails && - upstream.MaxFails != 0 { - return true - } - return false - } - }(upstream), - WithoutPathPrefix: upstream.WithoutPathPrefix, - } - if baseURL, err := url.Parse(uh.Name); err == nil { - uh.ReverseProxy = NewSingleHostReverseProxy(baseURL, uh.WithoutPathPrefix) - if upstream.insecureSkipVerify { - uh.ReverseProxy.Transport = InsecureTransport - } - } else { + uh, err := upstream.NewHost(host) + if err != nil { return upstreams, err } upstream.Hosts[i] = uh @@ -118,6 +91,46 @@ func (u *staticUpstream) From() string { return u.from } +func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) { + if !strings.HasPrefix(host, "http") && + !strings.HasPrefix(host, "unix:") { + host = "http://" + host + } + uh := &UpstreamHost{ + Name: host, + Conns: 0, + Fails: 0, + FailTimeout: u.FailTimeout, + Unhealthy: false, + ExtraHeaders: u.proxyHeaders, + CheckDown: func(u *staticUpstream) UpstreamHostDownFunc { + return func(uh *UpstreamHost) bool { + if uh.Unhealthy { + return true + } + if uh.Fails >= u.MaxFails && + u.MaxFails != 0 { + return true + } + return false + } + }(u), + WithoutPathPrefix: u.WithoutPathPrefix, + MaxConns: u.MaxConns, + } + + baseURL, err := url.Parse(uh.Name) + if err != nil { + return nil, err + } + + uh.ReverseProxy = NewSingleHostReverseProxy(baseURL, uh.WithoutPathPrefix) + if u.insecureSkipVerify { + uh.ReverseProxy.Transport = InsecureTransport + } + return uh, nil +} + func parseBlock(c *parse.Dispenser, u *staticUpstream) error { switch c.Val() { case "policy": @@ -147,6 +160,15 @@ func parseBlock(c *parse.Dispenser, u *staticUpstream) error { return err } u.MaxFails = int32(n) + case "max_conns": + if !c.NextArg() { + return c.ArgErr() + } + n, err := strconv.ParseInt(c.Val(), 10, 64) + if err != nil { + return err + } + u.MaxConns = n case "health_check": if !c.NextArg() { return c.ArgErr() @@ -219,19 +241,19 @@ func (u *staticUpstream) HealthCheckWorker(stop chan struct{}) { func (u *staticUpstream) Select() *UpstreamHost { pool := u.Hosts if len(pool) == 1 { - if pool[0].Down() { + if !pool[0].Available() { return nil } return pool[0] } - allDown := true + allUnavailable := true for _, host := range pool { - if !host.Down() { - allDown = false + if host.Available() { + allUnavailable = false break } } - if allDown { + if allUnavailable { return nil } diff --git a/middleware/proxy/upstream_test.go b/middleware/proxy/upstream_test.go index 5b2fdb1da..0370b6a43 100644 --- a/middleware/proxy/upstream_test.go +++ b/middleware/proxy/upstream_test.go @@ -5,6 +5,45 @@ import ( "time" ) +func TestNewHost(t *testing.T) { + upstream := &staticUpstream{ + FailTimeout: 10 * time.Second, + MaxConns: 1, + MaxFails: 1, + } + + uh, err := upstream.NewHost("example.com") + if err != nil { + t.Error("Expected no error") + } + if uh.Name != "http://example.com" { + t.Error("Expected default schema to be added to Name.") + } + if uh.FailTimeout != upstream.FailTimeout { + t.Error("Expected default FailTimeout to be set.") + } + if uh.MaxConns != upstream.MaxConns { + t.Error("Expected default MaxConns to be set.") + } + if uh.CheckDown == nil { + t.Error("Expected default CheckDown to be set.") + } + if uh.CheckDown(uh) { + t.Error("Expected new host not to be down.") + } + // mark Unhealthy + uh.Unhealthy = true + if !uh.CheckDown(uh) { + t.Error("Expected unhealthy host to be down.") + } + // mark with Fails + uh.Unhealthy = false + uh.Fails = 1 + if !uh.CheckDown(uh) { + t.Error("Expected failed host to be down.") + } +} + func TestHealthCheck(t *testing.T) { upstream := &staticUpstream{ from: "", @@ -40,6 +79,19 @@ func TestSelect(t *testing.T) { if h := upstream.Select(); h == nil { t.Error("Expected select to not return nil") } + upstream.Hosts[0].Conns = 1 + upstream.Hosts[0].MaxConns = 1 + upstream.Hosts[1].Conns = 1 + upstream.Hosts[1].MaxConns = 1 + upstream.Hosts[2].Conns = 1 + upstream.Hosts[2].MaxConns = 1 + if h := upstream.Select(); h != nil { + t.Error("Expected select to return nil as all hosts are full") + } + upstream.Hosts[2].Conns = 0 + if h := upstream.Select(); h == nil { + t.Error("Expected select to not return nil") + } } func TestRegisterPolicy(t *testing.T) {