mirror of
https://github.com/caddyserver/caddy.git
synced 2025-01-13 22:36:27 +03:00
reverseproxy: Add fallback
for some policies, instead of always random (#5488)
This commit is contained in:
parent
cdce452edc
commit
48598e1f2a
2 changed files with 239 additions and 41 deletions
|
@ -18,6 +18,7 @@ import (
|
|||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
weakrand "math/rand"
|
||||
|
@ -29,6 +30,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/caddyserver/caddy/v2"
|
||||
"github.com/caddyserver/caddy/v2/caddyconfig"
|
||||
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
|
||||
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
|
||||
)
|
||||
|
@ -372,6 +374,10 @@ func (r *URIHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
|||
type QueryHashSelection struct {
|
||||
// The query key whose value is to be hashed and used for upstream selection.
|
||||
Key string `json:"key,omitempty"`
|
||||
|
||||
// The fallback policy to use if the query key is not present. Defaults to `random`.
|
||||
FallbackRaw json.RawMessage `json:"fallback,omitempty" caddy:"namespace=http.reverse_proxy.selection_policies inline_key=policy"`
|
||||
fallback Selector
|
||||
}
|
||||
|
||||
// CaddyModule returns the Caddy module information.
|
||||
|
@ -382,12 +388,24 @@ func (QueryHashSelection) CaddyModule() caddy.ModuleInfo {
|
|||
}
|
||||
}
|
||||
|
||||
// Provision sets up the module.
|
||||
func (s *QueryHashSelection) Provision(ctx caddy.Context) error {
|
||||
if s.Key == "" {
|
||||
return fmt.Errorf("query key is required")
|
||||
}
|
||||
if s.FallbackRaw == nil {
|
||||
s.FallbackRaw = caddyconfig.JSONModuleObject(RandomSelection{}, "policy", "random", nil)
|
||||
}
|
||||
mod, err := ctx.LoadModule(s, "FallbackRaw")
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading fallback selection policy: %s", err)
|
||||
}
|
||||
s.fallback = mod.(Selector)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Select returns an available host, if any.
|
||||
func (s QueryHashSelection) Select(pool UpstreamPool, req *http.Request, _ http.ResponseWriter) *Upstream {
|
||||
if s.Key == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Since the query may have multiple values for the same key,
|
||||
// we'll join them to avoid a problem where the user can control
|
||||
// the upstream that the request goes to by sending multiple values
|
||||
|
@ -397,7 +415,7 @@ func (s QueryHashSelection) Select(pool UpstreamPool, req *http.Request, _ http.
|
|||
// different request, because the order of the values is significant.
|
||||
vals := strings.Join(req.URL.Query()[s.Key], ",")
|
||||
if vals == "" {
|
||||
return RandomSelection{}.Select(pool, req, nil)
|
||||
return s.fallback.Select(pool, req, nil)
|
||||
}
|
||||
return hostByHashing(pool, vals)
|
||||
}
|
||||
|
@ -410,6 +428,24 @@ func (s *QueryHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
|||
}
|
||||
s.Key = d.Val()
|
||||
}
|
||||
for nesting := d.Nesting(); d.NextBlock(nesting); {
|
||||
switch d.Val() {
|
||||
case "fallback":
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
if s.FallbackRaw != nil {
|
||||
return d.Err("fallback selection policy already specified")
|
||||
}
|
||||
mod, err := loadFallbackPolicy(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.FallbackRaw = mod
|
||||
default:
|
||||
return d.Errf("unrecognized option '%s'", d.Val())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -418,6 +454,10 @@ func (s *QueryHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
|||
type HeaderHashSelection struct {
|
||||
// The HTTP header field whose value is to be hashed and used for upstream selection.
|
||||
Field string `json:"field,omitempty"`
|
||||
|
||||
// The fallback policy to use if the header is not present. Defaults to `random`.
|
||||
FallbackRaw json.RawMessage `json:"fallback,omitempty" caddy:"namespace=http.reverse_proxy.selection_policies inline_key=policy"`
|
||||
fallback Selector
|
||||
}
|
||||
|
||||
// CaddyModule returns the Caddy module information.
|
||||
|
@ -428,12 +468,24 @@ func (HeaderHashSelection) CaddyModule() caddy.ModuleInfo {
|
|||
}
|
||||
}
|
||||
|
||||
// Provision sets up the module.
|
||||
func (s *HeaderHashSelection) Provision(ctx caddy.Context) error {
|
||||
if s.Field == "" {
|
||||
return fmt.Errorf("header field is required")
|
||||
}
|
||||
if s.FallbackRaw == nil {
|
||||
s.FallbackRaw = caddyconfig.JSONModuleObject(RandomSelection{}, "policy", "random", nil)
|
||||
}
|
||||
mod, err := ctx.LoadModule(s, "FallbackRaw")
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading fallback selection policy: %s", err)
|
||||
}
|
||||
s.fallback = mod.(Selector)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Select returns an available host, if any.
|
||||
func (s HeaderHashSelection) Select(pool UpstreamPool, req *http.Request, _ http.ResponseWriter) *Upstream {
|
||||
if s.Field == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// The Host header should be obtained from the req.Host field
|
||||
// since net/http removes it from the header map.
|
||||
if s.Field == "Host" && req.Host != "" {
|
||||
|
@ -442,7 +494,7 @@ func (s HeaderHashSelection) Select(pool UpstreamPool, req *http.Request, _ http
|
|||
|
||||
val := req.Header.Get(s.Field)
|
||||
if val == "" {
|
||||
return RandomSelection{}.Select(pool, req, nil)
|
||||
return s.fallback.Select(pool, req, nil)
|
||||
}
|
||||
return hostByHashing(pool, val)
|
||||
}
|
||||
|
@ -455,6 +507,24 @@ func (s *HeaderHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
|||
}
|
||||
s.Field = d.Val()
|
||||
}
|
||||
for nesting := d.Nesting(); d.NextBlock(nesting); {
|
||||
switch d.Val() {
|
||||
case "fallback":
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
if s.FallbackRaw != nil {
|
||||
return d.Err("fallback selection policy already specified")
|
||||
}
|
||||
mod, err := loadFallbackPolicy(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.FallbackRaw = mod
|
||||
default:
|
||||
return d.Errf("unrecognized option '%s'", d.Val())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -465,6 +535,10 @@ type CookieHashSelection struct {
|
|||
Name string `json:"name,omitempty"`
|
||||
// Secret to hash (Hmac256) chosen upstream in cookie
|
||||
Secret string `json:"secret,omitempty"`
|
||||
|
||||
// The fallback policy to use if the cookie is not present. Defaults to `random`.
|
||||
FallbackRaw json.RawMessage `json:"fallback,omitempty" caddy:"namespace=http.reverse_proxy.selection_policies inline_key=policy"`
|
||||
fallback Selector
|
||||
}
|
||||
|
||||
// CaddyModule returns the Caddy module information.
|
||||
|
@ -475,15 +549,48 @@ func (CookieHashSelection) CaddyModule() caddy.ModuleInfo {
|
|||
}
|
||||
}
|
||||
|
||||
// Select returns an available host, if any.
|
||||
func (s CookieHashSelection) Select(pool UpstreamPool, req *http.Request, w http.ResponseWriter) *Upstream {
|
||||
// Provision sets up the module.
|
||||
func (s *CookieHashSelection) Provision(ctx caddy.Context) error {
|
||||
if s.Name == "" {
|
||||
s.Name = "lb"
|
||||
}
|
||||
if s.FallbackRaw == nil {
|
||||
s.FallbackRaw = caddyconfig.JSONModuleObject(RandomSelection{}, "policy", "random", nil)
|
||||
}
|
||||
mod, err := ctx.LoadModule(s, "FallbackRaw")
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading fallback selection policy: %s", err)
|
||||
}
|
||||
s.fallback = mod.(Selector)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Select returns an available host, if any.
|
||||
func (s CookieHashSelection) Select(pool UpstreamPool, req *http.Request, w http.ResponseWriter) *Upstream {
|
||||
// selects a new Host using the fallback policy (typically random)
|
||||
// and write a sticky session cookie to the response.
|
||||
selectNewHost := func() *Upstream {
|
||||
upstream := s.fallback.Select(pool, req, w)
|
||||
if upstream == nil {
|
||||
return nil
|
||||
}
|
||||
sha, err := hashCookie(s.Secret, upstream.Dial)
|
||||
if err != nil {
|
||||
return upstream
|
||||
}
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: s.Name,
|
||||
Value: sha,
|
||||
Path: "/",
|
||||
Secure: false,
|
||||
})
|
||||
return upstream
|
||||
}
|
||||
|
||||
cookie, err := req.Cookie(s.Name)
|
||||
// If there's no cookie, select new random host
|
||||
// If there's no cookie, select a host using the fallback policy
|
||||
if err != nil || cookie == nil {
|
||||
return selectNewHostWithCookieHashSelection(pool, w, s.Secret, s.Name)
|
||||
return selectNewHost()
|
||||
}
|
||||
// If the cookie is present, loop over the available upstreams until we find a match
|
||||
cookieValue := cookie.Value
|
||||
|
@ -496,13 +603,15 @@ func (s CookieHashSelection) Select(pool UpstreamPool, req *http.Request, w http
|
|||
return upstream
|
||||
}
|
||||
}
|
||||
// If there is no matching host, select new random host
|
||||
return selectNewHostWithCookieHashSelection(pool, w, s.Secret, s.Name)
|
||||
// If there is no matching host, select a host using the fallback policy
|
||||
return selectNewHost()
|
||||
}
|
||||
|
||||
// UnmarshalCaddyfile sets up the module from Caddyfile tokens. Syntax:
|
||||
//
|
||||
// lb_policy cookie [<name> [<secret>]]
|
||||
// lb_policy cookie [<name> [<secret>]] {
|
||||
// fallback <policy>
|
||||
// }
|
||||
//
|
||||
// By default name is `lb`
|
||||
func (s *CookieHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
||||
|
@ -517,22 +626,25 @@ func (s *CookieHashSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
|||
default:
|
||||
return d.ArgErr()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Select a new Host randomly and add a sticky session cookie
|
||||
func selectNewHostWithCookieHashSelection(pool []*Upstream, w http.ResponseWriter, cookieSecret string, cookieName string) *Upstream {
|
||||
randomHost := selectRandomHost(pool)
|
||||
|
||||
if randomHost != nil {
|
||||
// Hash (HMAC with some key for privacy) the upstream.Dial string as the cookie value
|
||||
sha, err := hashCookie(cookieSecret, randomHost.Dial)
|
||||
if err == nil {
|
||||
// write the cookie.
|
||||
http.SetCookie(w, &http.Cookie{Name: cookieName, Value: sha, Path: "/", Secure: false})
|
||||
for nesting := d.Nesting(); d.NextBlock(nesting); {
|
||||
switch d.Val() {
|
||||
case "fallback":
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
if s.FallbackRaw != nil {
|
||||
return d.Err("fallback selection policy already specified")
|
||||
}
|
||||
mod, err := loadFallbackPolicy(d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.FallbackRaw = mod
|
||||
default:
|
||||
return d.Errf("unrecognized option '%s'", d.Val())
|
||||
}
|
||||
}
|
||||
return randomHost
|
||||
return nil
|
||||
}
|
||||
|
||||
// hashCookie hashes (HMAC 256) some data with the secret
|
||||
|
@ -627,6 +739,20 @@ func hash(s string) uint32 {
|
|||
return h.Sum32()
|
||||
}
|
||||
|
||||
func loadFallbackPolicy(d *caddyfile.Dispenser) (json.RawMessage, error) {
|
||||
name := d.Val()
|
||||
modID := "http.reverse_proxy.selection_policies." + name
|
||||
unm, err := caddyfile.UnmarshalModule(d, modID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sel, ok := unm.(Selector)
|
||||
if !ok {
|
||||
return nil, d.Errf("module %s (%T) is not a reverseproxy.Selector", modID, unm)
|
||||
}
|
||||
return caddyconfig.JSONModuleObject(sel, "policy", name, nil), nil
|
||||
}
|
||||
|
||||
// Interface guards
|
||||
var (
|
||||
_ Selector = (*RandomSelection)(nil)
|
||||
|
|
|
@ -20,6 +20,8 @@ import (
|
|||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/caddyserver/caddy/v2"
|
||||
"github.com/caddyserver/caddy/v2/caddyconfig"
|
||||
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
|
||||
)
|
||||
|
||||
|
@ -33,7 +35,7 @@ func testPool() UpstreamPool {
|
|||
|
||||
func TestRoundRobinPolicy(t *testing.T) {
|
||||
pool := testPool()
|
||||
rrPolicy := new(RoundRobinSelection)
|
||||
rrPolicy := RoundRobinSelection{}
|
||||
req, _ := http.NewRequest("GET", "/", nil)
|
||||
|
||||
h := rrPolicy.Select(pool, req, nil)
|
||||
|
@ -74,7 +76,7 @@ func TestRoundRobinPolicy(t *testing.T) {
|
|||
|
||||
func TestLeastConnPolicy(t *testing.T) {
|
||||
pool := testPool()
|
||||
lcPolicy := new(LeastConnSelection)
|
||||
lcPolicy := LeastConnSelection{}
|
||||
req, _ := http.NewRequest("GET", "/", nil)
|
||||
|
||||
pool[0].countRequest(10)
|
||||
|
@ -92,7 +94,7 @@ func TestLeastConnPolicy(t *testing.T) {
|
|||
|
||||
func TestIPHashPolicy(t *testing.T) {
|
||||
pool := testPool()
|
||||
ipHash := new(IPHashSelection)
|
||||
ipHash := IPHashSelection{}
|
||||
req, _ := http.NewRequest("GET", "/", nil)
|
||||
|
||||
// We should be able to predict where every request is routed.
|
||||
|
@ -234,7 +236,7 @@ func TestIPHashPolicy(t *testing.T) {
|
|||
|
||||
func TestClientIPHashPolicy(t *testing.T) {
|
||||
pool := testPool()
|
||||
ipHash := new(ClientIPHashSelection)
|
||||
ipHash := ClientIPHashSelection{}
|
||||
req, _ := http.NewRequest("GET", "/", nil)
|
||||
req = req.WithContext(context.WithValue(req.Context(), caddyhttp.VarsCtxKey, make(map[string]any)))
|
||||
|
||||
|
@ -377,7 +379,7 @@ func TestClientIPHashPolicy(t *testing.T) {
|
|||
|
||||
func TestFirstPolicy(t *testing.T) {
|
||||
pool := testPool()
|
||||
firstPolicy := new(FirstSelection)
|
||||
firstPolicy := FirstSelection{}
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
h := firstPolicy.Select(pool, req, nil)
|
||||
|
@ -393,8 +395,15 @@ func TestFirstPolicy(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestQueryHashPolicy(t *testing.T) {
|
||||
pool := testPool()
|
||||
ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
queryPolicy := QueryHashSelection{Key: "foo"}
|
||||
if err := queryPolicy.Provision(ctx); err != nil {
|
||||
t.Errorf("Provision error: %v", err)
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
pool := testPool()
|
||||
|
||||
request := httptest.NewRequest(http.MethodGet, "/?foo=1", nil)
|
||||
h := queryPolicy.Select(pool, request, nil)
|
||||
|
@ -463,7 +472,7 @@ func TestQueryHashPolicy(t *testing.T) {
|
|||
|
||||
func TestURIHashPolicy(t *testing.T) {
|
||||
pool := testPool()
|
||||
uriPolicy := new(URIHashSelection)
|
||||
uriPolicy := URIHashSelection{}
|
||||
|
||||
request := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
h := uriPolicy.Select(pool, request, nil)
|
||||
|
@ -552,8 +561,7 @@ func TestRandomChoicePolicy(t *testing.T) {
|
|||
pool[2].countRequest(30)
|
||||
|
||||
request := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
randomChoicePolicy := new(RandomChoiceSelection)
|
||||
randomChoicePolicy.Choose = 2
|
||||
randomChoicePolicy := RandomChoiceSelection{Choose: 2}
|
||||
|
||||
h := randomChoicePolicy.Select(pool, request, nil)
|
||||
|
||||
|
@ -568,6 +576,14 @@ func TestRandomChoicePolicy(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCookieHashPolicy(t *testing.T) {
|
||||
ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
cookieHashPolicy := CookieHashSelection{}
|
||||
if err := cookieHashPolicy.Provision(ctx); err != nil {
|
||||
t.Errorf("Provision error: %v", err)
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
pool := testPool()
|
||||
pool[0].Dial = "localhost:8080"
|
||||
pool[1].Dial = "localhost:8081"
|
||||
|
@ -577,7 +593,7 @@ func TestCookieHashPolicy(t *testing.T) {
|
|||
pool[2].setHealthy(false)
|
||||
request := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
cookieHashPolicy := new(CookieHashSelection)
|
||||
|
||||
h := cookieHashPolicy.Select(pool, request, w)
|
||||
cookieServer1 := w.Result().Cookies()[0]
|
||||
if cookieServer1 == nil {
|
||||
|
@ -614,3 +630,59 @@ func TestCookieHashPolicy(t *testing.T) {
|
|||
t.Error("Expected cookieHashPolicy to set a new cookie.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCookieHashPolicyWithFirstFallback(t *testing.T) {
|
||||
ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
cookieHashPolicy := CookieHashSelection{
|
||||
FallbackRaw: caddyconfig.JSONModuleObject(FirstSelection{}, "policy", "first", nil),
|
||||
}
|
||||
if err := cookieHashPolicy.Provision(ctx); err != nil {
|
||||
t.Errorf("Provision error: %v", err)
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
pool := testPool()
|
||||
pool[0].Dial = "localhost:8080"
|
||||
pool[1].Dial = "localhost:8081"
|
||||
pool[2].Dial = "localhost:8082"
|
||||
pool[0].setHealthy(true)
|
||||
pool[1].setHealthy(true)
|
||||
pool[2].setHealthy(true)
|
||||
request := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h := cookieHashPolicy.Select(pool, request, w)
|
||||
cookieServer1 := w.Result().Cookies()[0]
|
||||
if cookieServer1 == nil {
|
||||
t.Fatal("cookieHashPolicy should set a cookie")
|
||||
}
|
||||
if cookieServer1.Name != "lb" {
|
||||
t.Error("cookieHashPolicy should set a cookie with name lb")
|
||||
}
|
||||
if h != pool[0] {
|
||||
t.Errorf("Expected cookieHashPolicy host to be the first only available host, got %s", h)
|
||||
}
|
||||
request = httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w = httptest.NewRecorder()
|
||||
request.AddCookie(cookieServer1)
|
||||
h = cookieHashPolicy.Select(pool, request, w)
|
||||
if h != pool[0] {
|
||||
t.Errorf("Expected cookieHashPolicy host to stick to the first host (matching cookie), got %s", h)
|
||||
}
|
||||
s := w.Result().Cookies()
|
||||
if len(s) != 0 {
|
||||
t.Error("Expected cookieHashPolicy to not set a new cookie.")
|
||||
}
|
||||
pool[0].setHealthy(false)
|
||||
request = httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w = httptest.NewRecorder()
|
||||
request.AddCookie(cookieServer1)
|
||||
h = cookieHashPolicy.Select(pool, request, w)
|
||||
if h != pool[1] {
|
||||
t.Errorf("Expected cookieHashPolicy to select the next first available host, got %s", h)
|
||||
}
|
||||
if w.Result().Cookies() == nil {
|
||||
t.Error("Expected cookieHashPolicy to set a new cookie.")
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue