diff --git a/caddyhttp/proxy/policy.go b/caddyhttp/proxy/policy.go index 28c81bd8..f95c5b8f 100644 --- a/caddyhttp/proxy/policy.go +++ b/caddyhttp/proxy/policy.go @@ -18,12 +18,13 @@ type Policy interface { } func init() { - RegisterPolicy("random", func() Policy { return &Random{} }) - RegisterPolicy("least_conn", func() Policy { return &LeastConn{} }) - RegisterPolicy("round_robin", func() Policy { return &RoundRobin{} }) - RegisterPolicy("ip_hash", func() Policy { return &IPHash{} }) - RegisterPolicy("first", func() Policy { return &First{} }) - RegisterPolicy("uri_hash", func() Policy { return &URIHash{} }) + RegisterPolicy("random", func(arg string) Policy { return &Random{} }) + RegisterPolicy("least_conn", func(arg string) Policy { return &LeastConn{} }) + RegisterPolicy("round_robin", func(arg string) Policy { return &RoundRobin{} }) + RegisterPolicy("ip_hash", func(arg string) Policy { return &IPHash{} }) + RegisterPolicy("first", func(arg string) Policy { return &First{} }) + RegisterPolicy("uri_hash", func(arg string) Policy { return &URIHash{} }) + RegisterPolicy("header", func(arg string) Policy { return &Header{arg} }) } // Random is a policy that selects up hosts from a pool at random. @@ -160,3 +161,22 @@ func (r *First) Select(pool HostPool, request *http.Request) *UpstreamHost { } return nil } + +// Header is a policy that selects based on a hash of the given header +type Header struct { + // The name of the request header, the value of which will determine + // how the request is routed + Name string +} + +// Select selects the host based on hashing the header value +func (r *Header) Select(pool HostPool, request *http.Request) *UpstreamHost { + if r.Name == "" { + return nil + } + val := request.Header.Get(r.Name) + if val == "" { + return nil + } + return hostByHashing(pool, val) +} diff --git a/caddyhttp/proxy/policy_test.go b/caddyhttp/proxy/policy_test.go index 5cc7e85c..6acf1e08 100644 --- a/caddyhttp/proxy/policy_test.go +++ b/caddyhttp/proxy/policy_test.go @@ -302,3 +302,42 @@ func TestUriPolicy(t *testing.T) { t.Error("Expected uri policy policy host to be nil.") } } + +func TestHeaderPolicy(t *testing.T) { + pool := testPool() + tests := []struct { + Policy *Header + RequestHeaderName string + RequestHeaderValue string + NilHost bool + HostIndex int + }{ + {&Header{""}, "", "", true, 0}, + {&Header{""}, "Affinity", "somevalue", true, 0}, + {&Header{""}, "Affinity", "", true, 0}, + + {&Header{"Affinity"}, "", "", true, 0}, + {&Header{"Affinity"}, "Affinity", "somevalue", false, 1}, + {&Header{"Affinity"}, "Affinity", "somevalue2", false, 0}, + {&Header{"Affinity"}, "Affinity", "somevalue3", false, 2}, + {&Header{"Affinity"}, "Affinity", "", true, 0}, + } + + for idx, test := range tests { + request, _ := http.NewRequest("GET", "/", nil) + if test.RequestHeaderName != "" { + request.Header.Add(test.RequestHeaderName, test.RequestHeaderValue) + } + + host := test.Policy.Select(pool, request) + if test.NilHost && host != nil { + t.Errorf("%d: Expected host to be nil", idx) + } + if !test.NilHost && host == nil { + t.Errorf("%d: Did not expect host to be nil", idx) + } + if !test.NilHost && host != pool[test.HostIndex] { + t.Errorf("%d: Expected Header policy to be host %d", idx, test.HostIndex) + } + } +} diff --git a/caddyhttp/proxy/upstream.go b/caddyhttp/proxy/upstream.go index a33ffcdf..e7cc392b 100644 --- a/caddyhttp/proxy/upstream.go +++ b/caddyhttp/proxy/upstream.go @@ -22,7 +22,7 @@ import ( ) var ( - supportedPolicies = make(map[string]func() Policy) + supportedPolicies = make(map[string]func(string) Policy) ) type staticUpstream struct { @@ -243,7 +243,11 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error { if !ok { return c.ArgErr() } - u.Policy = policyCreateFunc() + arg := "" + if c.NextArg() { + arg = c.Val() + } + u.Policy = policyCreateFunc(arg) case "fail_timeout": if !c.NextArg() { return c.ArgErr() @@ -523,7 +527,7 @@ func (u *staticUpstream) Stop() error { } // RegisterPolicy adds a custom policy to the proxy. -func RegisterPolicy(name string, policy func() Policy) { +func RegisterPolicy(name string, policy func(string) Policy) { supportedPolicies[name] = policy } diff --git a/caddyhttp/proxy/upstream_test.go b/caddyhttp/proxy/upstream_test.go index b2773dd4..8d1ef719 100644 --- a/caddyhttp/proxy/upstream_test.go +++ b/caddyhttp/proxy/upstream_test.go @@ -106,7 +106,7 @@ func TestSelect(t *testing.T) { func TestRegisterPolicy(t *testing.T) { name := "custom" customPolicy := &customPolicy{} - RegisterPolicy(name, func() Policy { return customPolicy }) + RegisterPolicy(name, func(string) Policy { return customPolicy }) if _, ok := supportedPolicies[name]; !ok { t.Error("Expected supportedPolicies to have a custom policy.") }