diff --git a/middleware/proxy/policy.go b/middleware/proxy/policy.go index a2522bcb..96e382a5 100644 --- a/middleware/proxy/policy.go +++ b/middleware/proxy/policy.go @@ -25,11 +25,11 @@ type Random struct{} // Select selects an up host at random from the specified pool. func (r *Random) Select(pool HostPool) *UpstreamHost { // instead of just generating a random index - // this is done to prevent selecting a down host + // this is done to prevent selecting a unavailable host var randHost *UpstreamHost count := 0 for _, host := range pool { - if host.Down() { + if !host.Available() { continue } count++ @@ -56,7 +56,7 @@ func (r *LeastConn) Select(pool HostPool) *UpstreamHost { count := 0 leastConn := int64(1<<63 - 1) for _, host := range pool { - if host.Down() { + if !host.Available() { continue } hostConns := host.Conns @@ -90,11 +90,11 @@ func (r *RoundRobin) Select(pool HostPool) *UpstreamHost { poolLen := uint32(len(pool)) selection := atomic.AddUint32(&r.Robin, 1) % poolLen host := pool[selection] - // if the currently selected host is down, just ffwd to up host - for i := uint32(1); host.Down() && i < poolLen; i++ { + // if the currently selected host is not available, just ffwd to up host + for i := uint32(1); !host.Available() && i < poolLen; i++ { host = pool[(selection+i)%poolLen] } - if host.Down() { + if !host.Available() { return nil } return host diff --git a/middleware/proxy/policy_test.go b/middleware/proxy/policy_test.go index 8f4f1f79..4cc05f02 100644 --- a/middleware/proxy/policy_test.go +++ b/middleware/proxy/policy_test.go @@ -53,12 +53,23 @@ func TestRoundRobinPolicy(t *testing.T) { if h != pool[2] { t.Error("Expected second round robin host to be third host in the pool.") } - // mark host as down - pool[0].Unhealthy = true h = rrPolicy.Select(pool) - if h != pool[1] { + if h != pool[0] { t.Error("Expected third round robin host to be first host in the pool.") } + // mark host as down + pool[1].Unhealthy = true + h = rrPolicy.Select(pool) + if h != pool[2] { + t.Error("Expected to skip down host.") + } + // mark host as full + pool[2].Conns = 1 + pool[2].MaxConns = 1 + h = rrPolicy.Select(pool) + if h != pool[0] { + t.Error("Expected to skip full host.") + } } func TestLeastConnPolicy(t *testing.T) { diff --git a/middleware/proxy/proxy.go b/middleware/proxy/proxy.go index 7be8af2a..583f4a63 100644 --- a/middleware/proxy/proxy.go +++ b/middleware/proxy/proxy.go @@ -44,6 +44,7 @@ type UpstreamHost struct { ExtraHeaders http.Header CheckDown UpstreamHostDownFunc WithoutPathPrefix string + MaxConns int64 } // Down checks whether the upstream host is down or not. @@ -57,6 +58,16 @@ func (uh *UpstreamHost) Down() bool { return uh.CheckDown(uh) } +// Full checks whether the upstream host has reached its maximum connections +func (uh *UpstreamHost) Full() bool { + return uh.MaxConns > 0 && uh.Conns >= uh.MaxConns +} + +// Available checks whether the upstream host is available for proxying to +func (uh *UpstreamHost) Available() bool { + return !uh.Down() && !uh.Full() +} + // tryDuration is how long to try upstream hosts; failures result in // immediate retries until this duration ends or we get a nil host. var tryDuration = 60 * time.Second diff --git a/middleware/proxy/upstream.go b/middleware/proxy/upstream.go index 1182a0f4..35c95f1f 100644 --- a/middleware/proxy/upstream.go +++ b/middleware/proxy/upstream.go @@ -80,10 +80,6 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { ExtraHeaders: upstream.proxyHeaders, CheckDown: func(upstream *staticUpstream) UpstreamHostDownFunc { return func(uh *UpstreamHost) bool { - if upstream.MaxConns != 0 && - uh.Conns >= upstream.MaxConns { - return true - } if uh.Unhealthy { return true } @@ -95,6 +91,7 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { } }(upstream), WithoutPathPrefix: upstream.WithoutPathPrefix, + MaxConns: upstream.MaxConns, } if baseURL, err := url.Parse(uh.Name); err == nil { uh.ReverseProxy = NewSingleHostReverseProxy(baseURL, uh.WithoutPathPrefix) @@ -234,19 +231,19 @@ func (u *staticUpstream) HealthCheckWorker(stop chan struct{}) { func (u *staticUpstream) Select() *UpstreamHost { pool := u.Hosts if len(pool) == 1 { - if pool[0].Down() { + if !pool[0].Available() { return nil } return pool[0] } - allDown := true + allUnavailable := true for _, host := range pool { - if !host.Down() { - allDown = false + if host.Available() { + allUnavailable = false break } } - if allDown { + if allUnavailable { return nil } diff --git a/middleware/proxy/upstream_test.go b/middleware/proxy/upstream_test.go index 5b2fdb1d..32f6df9e 100644 --- a/middleware/proxy/upstream_test.go +++ b/middleware/proxy/upstream_test.go @@ -40,6 +40,19 @@ func TestSelect(t *testing.T) { if h := upstream.Select(); h == nil { t.Error("Expected select to not return nil") } + upstream.Hosts[0].Conns = 1 + upstream.Hosts[0].MaxConns = 1 + upstream.Hosts[1].Conns = 1 + upstream.Hosts[1].MaxConns = 1 + upstream.Hosts[2].Conns = 1 + upstream.Hosts[2].MaxConns = 1 + if h := upstream.Select(); h != nil { + t.Error("Expected select to return nil as all hosts are full") + } + upstream.Hosts[2].Conns = 0 + if h := upstream.Select(); h == nil { + t.Error("Expected select to not return nil") + } } func TestRegisterPolicy(t *testing.T) {