mirror of
https://github.com/caddyserver/caddy.git
synced 2025-01-07 11:28:48 +03:00
added ip_hash load balancing
updated tests fixed comment format fixed formatting, minor logic fix added newline to EOF updated logic, fixed tests added comment updated formatting updated test output fixed typo
This commit is contained in:
parent
72af3f8256
commit
88d3dcae42
6 changed files with 189 additions and 29 deletions
|
@ -1,8 +1,11 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"hash/fnv"
|
||||
"math"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
|
@ -11,20 +14,21 @@ type HostPool []*UpstreamHost
|
|||
|
||||
// Policy decides how a host will be selected from a pool.
|
||||
type Policy interface {
|
||||
Select(pool HostPool) *UpstreamHost
|
||||
Select(pool HostPool, r *http.Request) *UpstreamHost
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterPolicy("random", func() Policy { return &Random{} })
|
||||
RegisterPolicy("least_conn", func() Policy { return &LeastConn{} })
|
||||
RegisterPolicy("round_robin", func() Policy { return &RoundRobin{} })
|
||||
RegisterPolicy("ip_hash", func() Policy { return &IPHash{} })
|
||||
}
|
||||
|
||||
// Random is a policy that selects up hosts from a pool at random.
|
||||
type Random struct{}
|
||||
|
||||
// Select selects an up host at random from the specified pool.
|
||||
func (r *Random) Select(pool HostPool) *UpstreamHost {
|
||||
func (r *Random) Select(pool HostPool, request *http.Request) *UpstreamHost {
|
||||
|
||||
// Because the number of available hosts isn't known
|
||||
// up front, the host is selected via reservoir sampling
|
||||
|
@ -53,7 +57,7 @@ type LeastConn struct{}
|
|||
// Select selects the up host with the least number of connections in the
|
||||
// pool. If more than one host has the same least number of connections,
|
||||
// one of the hosts is chosen at random.
|
||||
func (r *LeastConn) Select(pool HostPool) *UpstreamHost {
|
||||
func (r *LeastConn) Select(pool HostPool, request *http.Request) *UpstreamHost {
|
||||
var bestHost *UpstreamHost
|
||||
count := 0
|
||||
leastConn := int64(math.MaxInt64)
|
||||
|
@ -86,7 +90,7 @@ type RoundRobin struct {
|
|||
}
|
||||
|
||||
// Select selects an up host from the pool using a round robin ordering scheme.
|
||||
func (r *RoundRobin) Select(pool HostPool) *UpstreamHost {
|
||||
func (r *RoundRobin) Select(pool HostPool, request *http.Request) *UpstreamHost {
|
||||
poolLen := uint32(len(pool))
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
|
@ -100,3 +104,35 @@ func (r *RoundRobin) Select(pool HostPool) *UpstreamHost {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IPHash is a policy that selects hosts based on hashing the request ip
|
||||
type IPHash struct{}
|
||||
|
||||
func hash(s string) uint32 {
|
||||
h := fnv.New32a()
|
||||
h.Write([]byte(s))
|
||||
return h.Sum32()
|
||||
}
|
||||
|
||||
// Select selects an up host from the pool using a round robin ordering scheme.
|
||||
func (r *IPHash) Select(pool HostPool, request *http.Request) *UpstreamHost {
|
||||
poolLen := uint32(len(pool))
|
||||
clientIP, _, err := net.SplitHostPort(request.RemoteAddr)
|
||||
if err != nil {
|
||||
clientIP = request.RemoteAddr
|
||||
}
|
||||
hash := hash(clientIP)
|
||||
for {
|
||||
if poolLen == 0 {
|
||||
break
|
||||
}
|
||||
index := hash % poolLen
|
||||
host := pool[index]
|
||||
if host.Available() {
|
||||
return host
|
||||
}
|
||||
pool = append(pool[:index], pool[index+1:]...)
|
||||
poolLen--
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ func TestMain(m *testing.M) {
|
|||
|
||||
type customPolicy struct{}
|
||||
|
||||
func (r *customPolicy) Select(pool HostPool) *UpstreamHost {
|
||||
func (r *customPolicy) Select(pool HostPool, request *http.Request) *UpstreamHost {
|
||||
return pool[0]
|
||||
}
|
||||
|
||||
|
@ -43,37 +43,39 @@ func testPool() HostPool {
|
|||
func TestRoundRobinPolicy(t *testing.T) {
|
||||
pool := testPool()
|
||||
rrPolicy := &RoundRobin{}
|
||||
h := rrPolicy.Select(pool)
|
||||
request, _ := http.NewRequest("GET", "/", nil)
|
||||
|
||||
h := rrPolicy.Select(pool, request)
|
||||
// First selected host is 1, because counter starts at 0
|
||||
// and increments before host is selected
|
||||
if h != pool[1] {
|
||||
t.Error("Expected first round robin host to be second host in the pool.")
|
||||
}
|
||||
h = rrPolicy.Select(pool)
|
||||
h = rrPolicy.Select(pool, request)
|
||||
if h != pool[2] {
|
||||
t.Error("Expected second round robin host to be third host in the pool.")
|
||||
}
|
||||
h = rrPolicy.Select(pool)
|
||||
h = rrPolicy.Select(pool, request)
|
||||
if h != pool[0] {
|
||||
t.Error("Expected third round robin host to be first host in the pool.")
|
||||
}
|
||||
// mark host as down
|
||||
pool[1].Unhealthy = true
|
||||
h = rrPolicy.Select(pool)
|
||||
h = rrPolicy.Select(pool, request)
|
||||
if h != pool[2] {
|
||||
t.Error("Expected to skip down host.")
|
||||
}
|
||||
// mark host as up
|
||||
pool[1].Unhealthy = false
|
||||
|
||||
h = rrPolicy.Select(pool)
|
||||
h = rrPolicy.Select(pool, request)
|
||||
if h == pool[2] {
|
||||
t.Error("Expected to balance evenly among healthy hosts")
|
||||
}
|
||||
// mark host as full
|
||||
pool[1].Conns = 1
|
||||
pool[1].MaxConns = 1
|
||||
h = rrPolicy.Select(pool)
|
||||
h = rrPolicy.Select(pool, request)
|
||||
if h != pool[2] {
|
||||
t.Error("Expected to skip full host.")
|
||||
}
|
||||
|
@ -82,14 +84,16 @@ func TestRoundRobinPolicy(t *testing.T) {
|
|||
func TestLeastConnPolicy(t *testing.T) {
|
||||
pool := testPool()
|
||||
lcPolicy := &LeastConn{}
|
||||
request, _ := http.NewRequest("GET", "/", nil)
|
||||
|
||||
pool[0].Conns = 10
|
||||
pool[1].Conns = 10
|
||||
h := lcPolicy.Select(pool)
|
||||
h := lcPolicy.Select(pool, request)
|
||||
if h != pool[2] {
|
||||
t.Error("Expected least connection host to be third host.")
|
||||
}
|
||||
pool[2].Conns = 100
|
||||
h = lcPolicy.Select(pool)
|
||||
h = lcPolicy.Select(pool, request)
|
||||
if h != pool[0] && h != pool[1] {
|
||||
t.Error("Expected least connection host to be first or second host.")
|
||||
}
|
||||
|
@ -98,8 +102,127 @@ func TestLeastConnPolicy(t *testing.T) {
|
|||
func TestCustomPolicy(t *testing.T) {
|
||||
pool := testPool()
|
||||
customPolicy := &customPolicy{}
|
||||
h := customPolicy.Select(pool)
|
||||
request, _ := http.NewRequest("GET", "/", nil)
|
||||
|
||||
h := customPolicy.Select(pool, request)
|
||||
if h != pool[0] {
|
||||
t.Error("Expected custom policy host to be the first host.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPHashPolicy(t *testing.T) {
|
||||
pool := testPool()
|
||||
ipHash := &IPHash{}
|
||||
request, _ := http.NewRequest("GET", "/", nil)
|
||||
// We should be able to predict where every request is routed.
|
||||
request.RemoteAddr = "172.0.0.1:80"
|
||||
h := ipHash.Select(pool, request)
|
||||
if h != pool[1] {
|
||||
t.Error("Expected ip hash policy host to be the second host.")
|
||||
}
|
||||
request.RemoteAddr = "172.0.0.2:80"
|
||||
h = ipHash.Select(pool, request)
|
||||
if h != pool[1] {
|
||||
t.Error("Expected ip hash policy host to be the second host.")
|
||||
}
|
||||
request.RemoteAddr = "172.0.0.3:80"
|
||||
h = ipHash.Select(pool, request)
|
||||
if h != pool[2] {
|
||||
t.Error("Expected ip hash policy host to be the third host.")
|
||||
}
|
||||
request.RemoteAddr = "172.0.0.4:80"
|
||||
h = ipHash.Select(pool, request)
|
||||
if h != pool[1] {
|
||||
t.Error("Expected ip hash policy host to be the second host.")
|
||||
}
|
||||
|
||||
// we should get the same results without a port
|
||||
request.RemoteAddr = "172.0.0.1"
|
||||
h = ipHash.Select(pool, request)
|
||||
if h != pool[1] {
|
||||
t.Error("Expected ip hash policy host to be the second host.")
|
||||
}
|
||||
request.RemoteAddr = "172.0.0.2"
|
||||
h = ipHash.Select(pool, request)
|
||||
if h != pool[1] {
|
||||
t.Error("Expected ip hash policy host to be the second host.")
|
||||
}
|
||||
request.RemoteAddr = "172.0.0.3"
|
||||
h = ipHash.Select(pool, request)
|
||||
if h != pool[2] {
|
||||
t.Error("Expected ip hash policy host to be the third host.")
|
||||
}
|
||||
request.RemoteAddr = "172.0.0.4"
|
||||
h = ipHash.Select(pool, request)
|
||||
if h != pool[1] {
|
||||
t.Error("Expected ip hash policy host to be the second host.")
|
||||
}
|
||||
|
||||
// we should get a healthy host if the original host is unhealthy and a
|
||||
// healthy host is available
|
||||
request.RemoteAddr = "172.0.0.1"
|
||||
pool[1].Unhealthy = true
|
||||
h = ipHash.Select(pool, request)
|
||||
if h != pool[0] {
|
||||
t.Error("Expected ip hash policy host to be the first host.")
|
||||
}
|
||||
|
||||
request.RemoteAddr = "172.0.0.2"
|
||||
h = ipHash.Select(pool, request)
|
||||
if h != pool[1] {
|
||||
t.Error("Expected ip hash policy host to be the second host.")
|
||||
}
|
||||
pool[1].Unhealthy = false
|
||||
|
||||
request.RemoteAddr = "172.0.0.3"
|
||||
pool[2].Unhealthy = true
|
||||
h = ipHash.Select(pool, request)
|
||||
if h != pool[0] {
|
||||
t.Error("Expected ip hash policy host to be the first host.")
|
||||
}
|
||||
request.RemoteAddr = "172.0.0.4"
|
||||
h = ipHash.Select(pool, request)
|
||||
if h != pool[0] {
|
||||
t.Error("Expected ip hash policy host to be the first host.")
|
||||
}
|
||||
|
||||
// We should be able to resize the host pool and still be able to predict
|
||||
// where a request will be routed with the same IP's used above
|
||||
pool = []*UpstreamHost{
|
||||
{
|
||||
Name: workableServer.URL, // this should resolve (healthcheck test)
|
||||
},
|
||||
{
|
||||
Name: "http://localhost:99998", // this shouldn't
|
||||
},
|
||||
}
|
||||
pool = HostPool(pool)
|
||||
request.RemoteAddr = "172.0.0.1:80"
|
||||
h = ipHash.Select(pool, request)
|
||||
if h != pool[0] {
|
||||
t.Error("Expected ip hash policy host to be the first host.")
|
||||
}
|
||||
request.RemoteAddr = "172.0.0.2:80"
|
||||
h = ipHash.Select(pool, request)
|
||||
if h != pool[1] {
|
||||
t.Error("Expected ip hash policy host to be the second host.")
|
||||
}
|
||||
request.RemoteAddr = "172.0.0.3:80"
|
||||
h = ipHash.Select(pool, request)
|
||||
if h != pool[0] {
|
||||
t.Error("Expected ip hash policy host to be the first host.")
|
||||
}
|
||||
request.RemoteAddr = "172.0.0.4:80"
|
||||
h = ipHash.Select(pool, request)
|
||||
if h != pool[1] {
|
||||
t.Error("Expected ip hash policy host to be the second host.")
|
||||
}
|
||||
|
||||
// We should get nil when there are no healthy hosts
|
||||
pool[0].Unhealthy = true
|
||||
pool[1].Unhealthy = true
|
||||
h = ipHash.Select(pool, request)
|
||||
if h != nil {
|
||||
t.Error("Expected ip hash policy host to be nil.")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -27,7 +27,7 @@ type Upstream interface {
|
|||
// The path this upstream host should be routed on
|
||||
From() string
|
||||
// Selects an upstream host to be routed to.
|
||||
Select() *UpstreamHost
|
||||
Select(*http.Request) *UpstreamHost
|
||||
// Checks if subpath is not an ignored path
|
||||
AllowedPath(string) bool
|
||||
}
|
||||
|
@ -93,7 +93,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
|||
// hosts until timeout (or until we get a nil host).
|
||||
start := time.Now()
|
||||
for time.Now().Sub(start) < tryDuration {
|
||||
host := upstream.Select()
|
||||
host := upstream.Select(r)
|
||||
if host == nil {
|
||||
return http.StatusBadGateway, errUnreachable
|
||||
}
|
||||
|
|
|
@ -736,7 +736,7 @@ func (u *fakeUpstream) From() string {
|
|||
return u.from
|
||||
}
|
||||
|
||||
func (u *fakeUpstream) Select() *UpstreamHost {
|
||||
func (u *fakeUpstream) Select(r *http.Request) *UpstreamHost {
|
||||
if u.host == nil {
|
||||
uri, err := url.Parse(u.name)
|
||||
if err != nil {
|
||||
|
@ -781,7 +781,7 @@ func (u *fakeWsUpstream) From() string {
|
|||
return "/"
|
||||
}
|
||||
|
||||
func (u *fakeWsUpstream) Select() *UpstreamHost {
|
||||
func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost {
|
||||
uri, _ := url.Parse(u.name)
|
||||
return &UpstreamHost{
|
||||
Name: u.name,
|
||||
|
|
|
@ -346,7 +346,7 @@ func (u *staticUpstream) HealthCheckWorker(stop chan struct{}) {
|
|||
}
|
||||
}
|
||||
|
||||
func (u *staticUpstream) Select() *UpstreamHost {
|
||||
func (u *staticUpstream) Select(r *http.Request) *UpstreamHost {
|
||||
pool := u.Hosts
|
||||
if len(pool) == 1 {
|
||||
if !pool[0].Available() {
|
||||
|
@ -364,11 +364,10 @@ func (u *staticUpstream) Select() *UpstreamHost {
|
|||
if allUnavailable {
|
||||
return nil
|
||||
}
|
||||
|
||||
if u.Policy == nil {
|
||||
return (&Random{}).Select(pool)
|
||||
return (&Random{}).Select(pool, r)
|
||||
}
|
||||
return u.Policy.Select(pool)
|
||||
return u.Policy.Select(pool, r)
|
||||
}
|
||||
|
||||
func (u *staticUpstream) AllowedPath(requestPath string) bool {
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"github.com/mholt/caddy/caddyfile"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mholt/caddy/caddyfile"
|
||||
)
|
||||
|
||||
func TestNewHost(t *testing.T) {
|
||||
|
@ -72,14 +72,15 @@ func TestSelect(t *testing.T) {
|
|||
FailTimeout: 10 * time.Second,
|
||||
MaxFails: 1,
|
||||
}
|
||||
r, _ := http.NewRequest("GET", "/", nil)
|
||||
upstream.Hosts[0].Unhealthy = true
|
||||
upstream.Hosts[1].Unhealthy = true
|
||||
upstream.Hosts[2].Unhealthy = true
|
||||
if h := upstream.Select(); h != nil {
|
||||
if h := upstream.Select(r); h != nil {
|
||||
t.Error("Expected select to return nil as all host are down")
|
||||
}
|
||||
upstream.Hosts[2].Unhealthy = false
|
||||
if h := upstream.Select(); h == nil {
|
||||
if h := upstream.Select(r); h == nil {
|
||||
t.Error("Expected select to not return nil")
|
||||
}
|
||||
upstream.Hosts[0].Conns = 1
|
||||
|
@ -88,11 +89,11 @@ func TestSelect(t *testing.T) {
|
|||
upstream.Hosts[1].MaxConns = 1
|
||||
upstream.Hosts[2].Conns = 1
|
||||
upstream.Hosts[2].MaxConns = 1
|
||||
if h := upstream.Select(); h != nil {
|
||||
if h := upstream.Select(r); h != nil {
|
||||
t.Error("Expected select to return nil as all hosts are full")
|
||||
}
|
||||
upstream.Hosts[2].Conns = 0
|
||||
if h := upstream.Select(); h == nil {
|
||||
if h := upstream.Select(r); h == nil {
|
||||
t.Error("Expected select to not return nil")
|
||||
}
|
||||
}
|
||||
|
@ -188,6 +189,7 @@ func TestParseBlockHealthCheck(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestParseBlock(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "/", nil)
|
||||
tests := []struct {
|
||||
config string
|
||||
}{
|
||||
|
@ -207,7 +209,7 @@ func TestParseBlock(t *testing.T) {
|
|||
t.Error("Expected no error. Got:", err.Error())
|
||||
}
|
||||
for _, upstream := range upstreams {
|
||||
headers := upstream.Select().UpstreamHeaders
|
||||
headers := upstream.Select(r).UpstreamHeaders
|
||||
|
||||
if _, ok := headers["Host"]; !ok {
|
||||
t.Errorf("Test %d: Could not find the Host header", i+1)
|
||||
|
|
Loading…
Reference in a new issue