From f8b59e77f83c05da87bd5e3780fb7522b863d462 Mon Sep 17 00:00:00 2001
From: Francis Lavoie <lavofr@gmail.com>
Date: Mon, 3 Apr 2023 23:31:47 -0400
Subject: [PATCH] reverseproxy: Add `query` and `client_ip_hash` lb policies
 (#5468)

---
 .../reverseproxy/selectionpolicies.go         |  85 +++++++
 .../reverseproxy/selectionpolicies_test.go    | 215 ++++++++++++++++++
 2 files changed, 300 insertions(+)

diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies.go b/modules/caddyhttp/reverseproxy/selectionpolicies.go
index 0b7f50cd9..4184df596 100644
--- a/modules/caddyhttp/reverseproxy/selectionpolicies.go
+++ b/modules/caddyhttp/reverseproxy/selectionpolicies.go
@@ -24,11 +24,13 @@ import (
 	"net"
 	"net/http"
 	"strconv"
+	"strings"
 	"sync/atomic"
 	"time"
 
 	"github.com/caddyserver/caddy/v2"
 	"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
+	"github.com/caddyserver/caddy/v2/modules/caddyhttp"
 )
 
 func init() {
@@ -38,7 +40,9 @@ func init() {
 	caddy.RegisterModule(RoundRobinSelection{})
 	caddy.RegisterModule(FirstSelection{})
 	caddy.RegisterModule(IPHashSelection{})
+	caddy.RegisterModule(ClientIPHashSelection{})
 	caddy.RegisterModule(URIHashSelection{})
+	caddy.RegisterModule(QueryHashSelection{})
 	caddy.RegisterModule(HeaderHashSelection{})
 	caddy.RegisterModule(CookieHashSelection{})
 
@@ -303,6 +307,39 @@ func (r *IPHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
 	return nil
 }
 
+// ClientIPHashSelection is a policy that selects a host by hashing the
+// client IP of the request (the caddyhttp.ClientIPVarKey request variable),
+// as determined by the HTTP app's trusted proxies settings.
+type ClientIPHashSelection struct{}
+
+// CaddyModule returns the Caddy module information (module ID and constructor).
+func (ClientIPHashSelection) CaddyModule() caddy.ModuleInfo {
+	return caddy.ModuleInfo{
+		ID:  "http.reverse_proxy.selection_policies.client_ip_hash",
+		New: func() caddy.Module { return new(ClientIPHashSelection) },
+	}
+}
+
+// Select returns an available host, if any. A missing or non-string client IP
+// variable is hashed as the empty string rather than panicking.
+func (ClientIPHashSelection) Select(pool UpstreamPool, req *http.Request, _ http.ResponseWriter) *Upstream {
+	address, _ := caddyhttp.GetVar(req.Context(), caddyhttp.ClientIPVarKey).(string)
+	if host, _, err := net.SplitHostPort(address); err == nil {
+		address = host // strip the port when one is present
+	}
+	return hostByHashing(pool, address)
+}
+
+// UnmarshalCaddyfile sets up the module from Caddyfile tokens; it accepts no arguments.
+func (r *ClientIPHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
+	for d.Next() {
+		if d.NextArg() {
+			return d.ArgErr() // this policy takes no arguments
+		}
+	}
+	return nil
+}
+
 // URIHashSelection is a policy that selects a
 // host by hashing the request URI.
 type URIHashSelection struct{}
@@ -330,6 +367,52 @@ func (r *URIHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
 	return nil
 }
 
+// QueryHashSelection is a policy that selects a host
+// by hashing the value(s) of a given request query parameter.
+type QueryHashSelection struct {
+	// The query key whose values are hashed for upstream selection. Required: Select returns nil when it is empty.
+	Key string `json:"key,omitempty"`
+}
+
+// CaddyModule returns the Caddy module information (module ID and constructor).
+func (QueryHashSelection) CaddyModule() caddy.ModuleInfo {
+	return caddy.ModuleInfo{
+		ID:  "http.reverse_proxy.selection_policies.query",
+		New: func() caddy.Module { return new(QueryHashSelection) },
+	}
+}
+
+// Select passes the joined values of the configured query key to
+// hostByHashing. It returns nil when no key is configured, and defers
+// to RandomSelection when the joined value is an empty string.
+func (s QueryHashSelection) Select(pool UpstreamPool, req *http.Request, _ http.ResponseWriter) *Upstream {
+	if s.Key == "" {
+		return nil
+	}
+
+	// A key may appear several times in the query, so all of its values
+	// are joined. This avoids a client controlling which upstream gets
+	// the request by sending multiple values for the key when the upstream
+	// only considers the first one. Reordering the values may select a
+	// different upstream, but order is significant, so that is another request.
+	values := req.URL.Query()[s.Key]
+	if joined := strings.Join(values, ","); joined != "" {
+		return hostByHashing(pool, joined)
+	}
+	return RandomSelection{}.Select(pool, req, nil)
+}
+
+// UnmarshalCaddyfile sets up the module from Caddyfile tokens; the query key argument is required.
+func (s *QueryHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
+	for d.Next() {
+		if !d.NextArg() {
+			return d.ArgErr() // the key argument is missing
+		}
+		s.Key = d.Val()
+	}
+	return nil
+}
+
 // HeaderHashSelection is a policy that selects
 // a host based on a given request header.
 type HeaderHashSelection struct {
@@ -552,7 +635,9 @@ var (
 	_ Selector = (*RoundRobinSelection)(nil)
 	_ Selector = (*FirstSelection)(nil)
 	_ Selector = (*IPHashSelection)(nil)
+	_ Selector = (*ClientIPHashSelection)(nil)
 	_ Selector = (*URIHashSelection)(nil)
+	_ Selector = (*QueryHashSelection)(nil)
 	_ Selector = (*HeaderHashSelection)(nil)
 	_ Selector = (*CookieHashSelection)(nil)
 
diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go
index 546a60d6e..d2b7b3d0b 100644
--- a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go
+++ b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go
@@ -15,9 +15,12 @@
 package reverseproxy
 
 import (
+	"context"
 	"net/http"
 	"net/http/httptest"
 	"testing"
+
+	"github.com/caddyserver/caddy/v2/modules/caddyhttp"
 )
 
 func testPool() UpstreamPool {
@@ -229,6 +232,149 @@ func TestIPHashPolicy(t *testing.T) {
 	}
 }
 
+func TestClientIPHashPolicy(t *testing.T) {
+	pool := testPool()
+	ipHash := new(ClientIPHashSelection)
+	req, _ := http.NewRequest("GET", "/", nil)
+	req = req.WithContext(context.WithValue(req.Context(), caddyhttp.VarsCtxKey, make(map[string]any))) // attach an empty vars map to the request context
+
+	// Each client address below must be routed to a predictable host.
+	caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.1:80")
+	h := ipHash.Select(pool, req, nil)
+	if h != pool[1] {
+		t.Error("Expected ip hash policy host to be the second host.")
+	}
+	caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.2:80")
+	h = ipHash.Select(pool, req, nil)
+	if h != pool[1] {
+		t.Error("Expected ip hash policy host to be the second host.")
+	}
+	caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.3:80")
+	h = ipHash.Select(pool, req, nil)
+	if h != pool[1] {
+		t.Error("Expected ip hash policy host to be the second host.")
+	}
+	caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.4:80")
+	h = ipHash.Select(pool, req, nil)
+	if h != pool[1] {
+		t.Error("Expected ip hash policy host to be the second host.")
+	}
+
+	// We should get the same results when the address has no port.
+	caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.1")
+	h = ipHash.Select(pool, req, nil)
+	if h != pool[1] {
+		t.Error("Expected ip hash policy host to be the second host.")
+	}
+	caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.2")
+	h = ipHash.Select(pool, req, nil)
+	if h != pool[1] {
+		t.Error("Expected ip hash policy host to be the second host.")
+	}
+	caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.3")
+	h = ipHash.Select(pool, req, nil)
+	if h != pool[1] {
+		t.Error("Expected ip hash policy host to be the second host.")
+	}
+	caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.4")
+	h = ipHash.Select(pool, req, nil)
+	if h != pool[1] {
+		t.Error("Expected ip hash policy host to be the second host.")
+	}
+
+	// We should get a healthy host if the original host is unhealthy and a
+	// healthy host is available.
+	caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.4")
+	pool[1].setHealthy(false)
+	h = ipHash.Select(pool, req, nil)
+	if h != pool[0] {
+		t.Error("Expected ip hash policy host to be the first host.")
+	}
+
+	caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.2")
+	h = ipHash.Select(pool, req, nil)
+	if h != pool[0] {
+		t.Error("Expected ip hash policy host to be the first host.")
+	}
+	pool[1].setHealthy(true)
+
+	caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.3")
+	pool[2].setHealthy(false)
+	h = ipHash.Select(pool, req, nil)
+	if h != pool[1] {
+		t.Error("Expected ip hash policy host to be the second host.")
+	}
+	caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.4")
+	h = ipHash.Select(pool, req, nil)
+	if h != pool[1] {
+		t.Error("Expected ip hash policy host to be the second host.")
+	}
+
+	// We should be able to resize the host pool and still be able to predict
+	// where a req will be routed with the same IPs used above.
+	pool = UpstreamPool{
+		{Host: new(Host), Dial: "0.0.0.2"},
+		{Host: new(Host), Dial: "0.0.0.3"},
+	}
+	caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.1:80")
+	h = ipHash.Select(pool, req, nil)
+	if h != pool[0] {
+		t.Error("Expected ip hash policy host to be the first host.")
+	}
+	caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.2:80")
+	h = ipHash.Select(pool, req, nil)
+	if h != pool[0] {
+		t.Error("Expected ip hash policy host to be the first host.")
+	}
+	caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.3:80")
+	h = ipHash.Select(pool, req, nil)
+	if h != pool[0] {
+		t.Error("Expected ip hash policy host to be the first host.")
+	}
+	caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, "172.0.0.4:80")
+	h = ipHash.Select(pool, req, nil)
+	if h != pool[0] {
+		t.Error("Expected ip hash policy host to be the first host.")
+	}
+
+	// We should get nil when there are no healthy hosts.
+	pool[0].setHealthy(false)
+	pool[1].setHealthy(false)
+	h = ipHash.Select(pool, req, nil)
+	if h != nil {
+		t.Error("Expected ip hash policy host to be nil.")
+	}
+
+	// Reproduce #4135: only the last of nine hosts is healthy.
+	pool = UpstreamPool{
+		{Host: new(Host)},
+		{Host: new(Host)},
+		{Host: new(Host)},
+		{Host: new(Host)},
+		{Host: new(Host)},
+		{Host: new(Host)},
+		{Host: new(Host)},
+		{Host: new(Host)},
+		{Host: new(Host)},
+	}
+	pool[0].setHealthy(false)
+	pool[1].setHealthy(false)
+	pool[2].setHealthy(false)
+	pool[3].setHealthy(false)
+	pool[4].setHealthy(false)
+	pool[5].setHealthy(false)
+	pool[6].setHealthy(false)
+	pool[7].setHealthy(false)
+	pool[8].setHealthy(true)
+
+	// We should get a result back when there is one healthy host left.
+	h = ipHash.Select(pool, req, nil)
+	if h == nil {
+		// If it is nil, it means we missed a host even though one is available.
+		t.Error("Expected ip hash policy host to not be nil, but it is nil.")
+	}
+}
+
 func TestFirstPolicy(t *testing.T) {
 	pool := testPool()
 	firstPolicy := new(FirstSelection)
@@ -246,6 +392,75 @@ func TestFirstPolicy(t *testing.T) {
 	}
 }
 
+func TestQueryHashPolicy(t *testing.T) {
+	pool := testPool()
+	queryPolicy := QueryHashSelection{Key: "foo"} // select on the "foo" query parameter
+
+	request := httptest.NewRequest(http.MethodGet, "/?foo=1", nil)
+	h := queryPolicy.Select(pool, request, nil)
+	if h != pool[0] {
+		t.Error("Expected query policy host to be the first host.")
+	}
+
+	request = httptest.NewRequest(http.MethodGet, "/?foo=100000", nil)
+	h = queryPolicy.Select(pool, request, nil)
+	if h != pool[0] {
+		t.Error("Expected query policy host to be the first host.")
+	}
+
+	request = httptest.NewRequest(http.MethodGet, "/?foo=1", nil)
+	pool[0].setHealthy(false) // the first host must no longer be selected
+	h = queryPolicy.Select(pool, request, nil)
+	if h != pool[1] {
+		t.Error("Expected query policy host to be the second host.")
+	}
+
+	request = httptest.NewRequest(http.MethodGet, "/?foo=100000", nil)
+	h = queryPolicy.Select(pool, request, nil)
+	if h != pool[2] {
+		t.Error("Expected query policy host to be the third host.")
+	}
+
+	// We should be able to resize the host pool and still be able to predict
+	// where a request will be routed with the same query used above.
+	pool = UpstreamPool{
+		{Host: new(Host)},
+		{Host: new(Host)},
+	}
+
+	request = httptest.NewRequest(http.MethodGet, "/?foo=1", nil)
+	h = queryPolicy.Select(pool, request, nil)
+	if h != pool[0] {
+		t.Error("Expected query policy host to be the first host.")
+	}
+
+	pool[0].setHealthy(false)
+	h = queryPolicy.Select(pool, request, nil)
+	if h != pool[1] {
+		t.Error("Expected query policy host to be the second host.")
+	}
+
+	request = httptest.NewRequest(http.MethodGet, "/?foo=4", nil)
+	h = queryPolicy.Select(pool, request, nil)
+	if h != pool[1] {
+		t.Error("Expected query policy host to be the second host.")
+	}
+
+	pool[0].setHealthy(false) // already unhealthy; repeated so both hosts are visibly down
+	pool[1].setHealthy(false)
+	h = queryPolicy.Select(pool, request, nil)
+	if h != nil {
+		t.Error("Expected query policy host to be nil.")
+	}
+
+	request = httptest.NewRequest(http.MethodGet, "/?foo=aa11&foo=bb22", nil) // same key sent twice
+	pool = testPool()
+	h = queryPolicy.Select(pool, request, nil)
+	if h != pool[0] {
+		t.Error("Expected query policy host to be the first host.")
+	}
+}
+
 func TestURIHashPolicy(t *testing.T) {
 	pool := testPool()
 	uriPolicy := new(URIHashSelection)