reverseproxy: Dynamic upstreams (with SRV and A/AAAA support) (#4470)

* reverseproxy: Begin refactor to enable dynamic upstreams

Streamed here: https://www.youtube.com/watch?v=hj7yzXb11jU

* Implement SRV and A/AAAA upstream sources

Also get upstreams at every retry loop iteration instead of just once
before the loop. See #4442.

* Minor tweaks from review

* Limit size of upstreams caches

* Add doc notes deprecating LookupSRV

* Provision dynamic upstreams

Still WIP, preparing to preserve health checker functionality

* Rejigger health checks

Move active health check results into handler-specific Upstreams.

Improve documentation regarding health checks and upstreams.

* Deprecation notice

* Add Caddyfile support, use `caddy.Duration`

* Interface guards

* Implement custom resolvers, add resolvers to http transport Caddyfile

* SRV: fix Caddyfile `name` inline arg, remove proto condition

* Use pointer receiver

* Add debug logs

Co-authored-by: Francis Lavoie <lavofr@gmail.com>
Matt Holt 2022-03-06 17:43:39 -07:00 committed by GitHub
parent c50094fc9d
commit ab0455922a
10 changed files with 1063 additions and 288 deletions


@ -0,0 +1,116 @@
:8884 {
reverse_proxy {
dynamic a foo 9000
}
reverse_proxy {
dynamic a {
name foo
port 9000
refresh 5m
resolvers 8.8.8.8 8.8.4.4
dial_timeout 2s
dial_fallback_delay 300ms
}
}
}
:8885 {
reverse_proxy {
dynamic srv _api._tcp.example.com
}
reverse_proxy {
dynamic srv {
service api
proto tcp
name example.com
refresh 5m
resolvers 8.8.8.8 8.8.4.4
dial_timeout 1s
dial_fallback_delay -1s
}
}
}
----------
{
"apps": {
"http": {
"servers": {
"srv0": {
"listen": [
":8884"
],
"routes": [
{
"handle": [
{
"dynamic_upstreams": {
"name": "foo",
"port": "9000",
"source": "a"
},
"handler": "reverse_proxy"
},
{
"dynamic_upstreams": {
"dial_fallback_delay": 300000000,
"dial_timeout": 2000000000,
"name": "foo",
"port": "9000",
"refresh": 300000000000,
"resolver": {
"addresses": [
"8.8.8.8",
"8.8.4.4"
]
},
"source": "a"
},
"handler": "reverse_proxy"
}
]
}
]
},
"srv1": {
"listen": [
":8885"
],
"routes": [
{
"handle": [
{
"dynamic_upstreams": {
"name": "_api._tcp.example.com",
"source": "srv"
},
"handler": "reverse_proxy"
},
{
"dynamic_upstreams": {
"dial_fallback_delay": -1000000000,
"dial_timeout": 1000000000,
"name": "example.com",
"proto": "tcp",
"refresh": 300000000000,
"resolver": {
"addresses": [
"8.8.8.8",
"8.8.4.4"
]
},
"service": "api",
"source": "srv"
},
"handler": "reverse_proxy"
}
]
}
]
}
}
}
}
}


@ -17,6 +17,7 @@ https://example.com {
dial_fallback_delay 5s
response_header_timeout 8s
expect_continue_timeout 9s
resolvers 8.8.8.8 8.8.4.4
versions h2c 2
compression off
@ -88,6 +89,12 @@ https://example.com {
"max_response_header_size": 30000000,
"protocol": "http",
"read_buffer_size": 10000000,
"resolver": {
"addresses": [
"8.8.8.8",
"8.8.4.4"
]
},
"response_header_timeout": 8000000000,
"versions": [
"h2c",


@ -87,7 +87,7 @@ func (adminUpstreams) handleUpstreams(w http.ResponseWriter, r *http.Request) er
return false
}
upstream, ok := val.(*upstreamHost)
upstream, ok := val.(*Host)
if !ok {
rangeErr = caddy.APIError{
HTTPStatus: http.StatusInternalServerError,
@ -98,7 +98,6 @@ func (adminUpstreams) handleUpstreams(w http.ResponseWriter, r *http.Request) er
results = append(results, upstreamStatus{
Address: address,
Healthy: !upstream.Unhealthy(),
NumRequests: upstream.NumRequests(),
Fails: upstream.Fails(),
})


@ -53,6 +53,7 @@ func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error)
// reverse_proxy [<matcher>] [<upstreams...>] {
// # upstreams
// to <upstreams...>
// dynamic <name> [...]
//
// # load balancing
// lb_policy <name> [<options...>]
@ -190,6 +191,25 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
}
}
case "dynamic":
if !d.NextArg() {
return d.ArgErr()
}
if h.DynamicUpstreams != nil {
return d.Err("dynamic upstreams already specified")
}
dynModule := d.Val()
modID := "http.reverse_proxy.upstreams." + dynModule
unm, err := caddyfile.UnmarshalModule(d, modID)
if err != nil {
return err
}
source, ok := unm.(UpstreamSource)
if !ok {
return d.Errf("module %s (%T) is not an UpstreamSource", modID, unm)
}
h.DynamicUpstreamsRaw = caddyconfig.JSONModuleObject(source, "source", dynModule, nil)
case "lb_policy":
if !d.NextArg() {
return d.ArgErr()
@ -749,6 +769,7 @@ func (h *Handler) FinalizeUnmarshalCaddyfile(helper httpcaddyfile.Helper) error
// dial_fallback_delay <duration>
// response_header_timeout <duration>
// expect_continue_timeout <duration>
// resolvers <resolvers...>
// tls
// tls_client_auth <automate_name> | <cert_file> <key_file>
// tls_insecure_skip_verify
@ -839,6 +860,15 @@ func (h *HTTPTransport) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
}
h.ExpectContinueTimeout = caddy.Duration(dur)
case "resolvers":
if h.Resolver == nil {
h.Resolver = new(UpstreamResolver)
}
h.Resolver.Addresses = d.RemainingArgs()
if len(h.Resolver.Addresses) == 0 {
return d.Errf("must specify at least one resolver address")
}
case "tls_client_auth":
if h.TLS == nil {
h.TLS = new(TLSConfig)
@ -989,10 +1019,201 @@ func (h *HTTPTransport) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
return nil
}
// UnmarshalCaddyfile deserializes Caddyfile tokens into u.
//
// dynamic srv [<name>] {
// service <service>
// proto <proto>
// name <name>
// refresh <interval>
// resolvers <resolvers...>
// dial_timeout <timeout>
// dial_fallback_delay <timeout>
// }
//
func (u *SRVUpstreams) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
for d.Next() {
args := d.RemainingArgs()
if len(args) > 1 {
return d.ArgErr()
}
if len(args) > 0 {
u.Name = args[0]
}
for d.NextBlock(0) {
switch d.Val() {
case "service":
if !d.NextArg() {
return d.ArgErr()
}
if u.Service != "" {
return d.Errf("srv service has already been specified")
}
u.Service = d.Val()
case "proto":
if !d.NextArg() {
return d.ArgErr()
}
if u.Proto != "" {
return d.Errf("srv proto has already been specified")
}
u.Proto = d.Val()
case "name":
if !d.NextArg() {
return d.ArgErr()
}
if u.Name != "" {
return d.Errf("srv name has already been specified")
}
u.Name = d.Val()
case "refresh":
if !d.NextArg() {
return d.ArgErr()
}
dur, err := caddy.ParseDuration(d.Val())
if err != nil {
return d.Errf("parsing refresh interval duration: %v", err)
}
u.Refresh = caddy.Duration(dur)
case "resolvers":
if u.Resolver == nil {
u.Resolver = new(UpstreamResolver)
}
u.Resolver.Addresses = d.RemainingArgs()
if len(u.Resolver.Addresses) == 0 {
return d.Errf("must specify at least one resolver address")
}
case "dial_timeout":
if !d.NextArg() {
return d.ArgErr()
}
dur, err := caddy.ParseDuration(d.Val())
if err != nil {
return d.Errf("bad timeout value '%s': %v", d.Val(), err)
}
u.DialTimeout = caddy.Duration(dur)
case "dial_fallback_delay":
if !d.NextArg() {
return d.ArgErr()
}
dur, err := caddy.ParseDuration(d.Val())
if err != nil {
return d.Errf("bad delay value '%s': %v", d.Val(), err)
}
u.FallbackDelay = caddy.Duration(dur)
default:
return d.Errf("unrecognized srv option '%s'", d.Val())
}
}
}
return nil
}
// UnmarshalCaddyfile deserializes Caddyfile tokens into u.
//
// dynamic a [<name> <port>] {
// name <name>
// port <port>
// refresh <interval>
// resolvers <resolvers...>
// dial_timeout <timeout>
// dial_fallback_delay <timeout>
// }
//
func (u *AUpstreams) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
for d.Next() {
args := d.RemainingArgs()
if len(args) > 2 {
return d.ArgErr()
}
if len(args) > 0 {
u.Name = args[0]
}
if len(args) == 2 {
u.Port = args[1]
}
for d.NextBlock(0) {
switch d.Val() {
case "name":
if !d.NextArg() {
return d.ArgErr()
}
if u.Name != "" {
return d.Errf("a name has already been specified")
}
u.Name = d.Val()
case "port":
if !d.NextArg() {
return d.ArgErr()
}
if u.Port != "" {
return d.Errf("a port has already been specified")
}
u.Port = d.Val()
case "refresh":
if !d.NextArg() {
return d.ArgErr()
}
dur, err := caddy.ParseDuration(d.Val())
if err != nil {
return d.Errf("parsing refresh interval duration: %v", err)
}
u.Refresh = caddy.Duration(dur)
case "resolvers":
if u.Resolver == nil {
u.Resolver = new(UpstreamResolver)
}
u.Resolver.Addresses = d.RemainingArgs()
if len(u.Resolver.Addresses) == 0 {
return d.Errf("must specify at least one resolver address")
}
case "dial_timeout":
if !d.NextArg() {
return d.ArgErr()
}
dur, err := caddy.ParseDuration(d.Val())
if err != nil {
return d.Errf("bad timeout value '%s': %v", d.Val(), err)
}
u.DialTimeout = caddy.Duration(dur)
case "dial_fallback_delay":
if !d.NextArg() {
return d.ArgErr()
}
dur, err := caddy.ParseDuration(d.Val())
if err != nil {
return d.Errf("bad delay value '%s': %v", d.Val(), err)
}
u.FallbackDelay = caddy.Duration(dur)
default:
return d.Errf("unrecognized srv option '%s'", d.Val())
}
}
}
return nil
}
const matcherPrefix = "@"
// Interface guards
var (
_ caddyfile.Unmarshaler = (*Handler)(nil)
_ caddyfile.Unmarshaler = (*HTTPTransport)(nil)
_ caddyfile.Unmarshaler = (*SRVUpstreams)(nil)
_ caddyfile.Unmarshaler = (*AUpstreams)(nil)
)

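As context for the `dynamic` subdirective above: it dispatches to whatever module is registered under the `http.reverse_proxy.upstreams` namespace and requires it to implement UpstreamSource, so third-party sources can plug in alongside the SRV and A/AAAA sources added in this commit. Below is a minimal sketch of such a module; the name `fixed`, the FixedUpstreams type, and its Addresses field are hypothetical and not part of this change.

// Hypothetical example, not part of this commit: a custom upstream source
// registered in the same namespace the `dynamic` subdirective dispatches to.
package fixedupstreams

import (
	"net/http"

	"github.com/caddyserver/caddy/v2"
	"github.com/caddyserver/caddy/v2/modules/caddyhttp/reverseproxy"
)

func init() {
	caddy.RegisterModule(FixedUpstreams{})
}

// FixedUpstreams returns a fixed list of dial addresses.
type FixedUpstreams struct {
	// Backend addresses to return, e.g. "10.0.0.1:8080".
	Addresses []string `json:"addresses,omitempty"`
}

// CaddyModule returns the Caddy module information. The last label of the
// ID is what follows `dynamic` in the Caddyfile and what the "source"
// inline key is set to in the JSON config.
func (FixedUpstreams) CaddyModule() caddy.ModuleInfo {
	return caddy.ModuleInfo{
		ID:  "http.reverse_proxy.upstreams.fixed",
		New: func() caddy.Module { return new(FixedUpstreams) },
	}
}

// GetUpstreams implements reverseproxy.UpstreamSource. It is called on every
// iteration of the proxy loop, so it should be fast and return a stable list.
func (f FixedUpstreams) GetUpstreams(_ *http.Request) ([]*reverseproxy.Upstream, error) {
	upstreams := make([]*reverseproxy.Upstream, len(f.Addresses))
	for i, addr := range f.Addresses {
		upstreams[i] = &reverseproxy.Upstream{Dial: addr}
	}
	return upstreams, nil
}

// Interface guard
var _ reverseproxy.UpstreamSource = (*FixedUpstreams)(nil)

To support the Caddyfile form `dynamic fixed`, the type would additionally need to implement caddyfile.Unmarshaler, since the subdirective parses its block via caddyfile.UnmarshalModule.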

@ -18,7 +18,6 @@ import (
"context"
"fmt"
"io"
"log"
"net"
"net/http"
"net/url"
@ -37,12 +36,32 @@ import (
type HealthChecks struct {
// Active health checks run in the background on a timer. To
// minimally enable active health checks, set either path or
// port (or both).
// port (or both). Note that active health check status
// (healthy/unhealthy) is stored per-proxy-handler, not
// globally; this allows different handlers to use different
// criteria to decide what defines a healthy backend.
//
// Active health checks do not run for dynamic upstreams.
Active *ActiveHealthChecks `json:"active,omitempty"`
// Passive health checks monitor proxied requests for errors or timeouts.
// To minimally enable passive health checks, specify at least an empty
// config object.
// config object. Passive health check state is shared (stored globally),
// so a failure from one handler will be counted by all handlers; but
// the tolerances or standards for what defines healthy/unhealthy backends
// are configured per-proxy-handler.
//
// Passive health checks technically do operate on dynamic upstreams,
// but are only effective for very busy proxies where the list of
// upstreams is mostly stable. This is because the shared/global
// state of upstreams is cleaned up when the upstreams are no longer
// used. Since dynamic upstreams are allocated dynamically at each
// request (specifically, each iteration of the proxy loop per request),
// they are also cleaned up after every request. Thus, if there is a
// moment when no requests are actively referring to a particular
// upstream host, the passive health check state will be reset because
// it will be garbage-collected. It is usually better for the dynamic
// upstream module to only return healthy, available backends instead.
Passive *PassiveHealthChecks `json:"passive,omitempty"`
}
@ -50,8 +69,7 @@ type HealthChecks struct {
// health checks (that is, health checks which occur in a
// background goroutine independently).
type ActiveHealthChecks struct {
// The path to use for health checks.
// DEPRECATED: Use 'uri' instead.
// DEPRECATED: Use 'uri' instead. This field will be removed. TODO: remove this field
Path string `json:"path,omitempty"`
// The URI (path and query) to use for health checks
@ -132,7 +150,9 @@ type CircuitBreaker interface {
func (h *Handler) activeHealthChecker() {
defer func() {
if err := recover(); err != nil {
log.Printf("[PANIC] active health checks: %v\n%s", err, debug.Stack())
h.HealthChecks.Active.logger.Error("active health checker panicked",
zap.Any("error", err),
zap.ByteString("stack", debug.Stack()))
}
}()
ticker := time.NewTicker(time.Duration(h.HealthChecks.Active.Interval))
@ -155,7 +175,9 @@ func (h *Handler) doActiveHealthCheckForAllHosts() {
go func(upstream *Upstream) {
defer func() {
if err := recover(); err != nil {
log.Printf("[PANIC] active health check: %v\n%s", err, debug.Stack())
h.HealthChecks.Active.logger.Error("active health check panicked",
zap.Any("error", err),
zap.ByteString("stack", debug.Stack()))
}
}()
@ -195,7 +217,7 @@ func (h *Handler) doActiveHealthCheckForAllHosts() {
// so use a fake Host value instead; unix sockets are usually local
hostAddr = "localhost"
}
err = h.doActiveHealthCheck(DialInfo{Network: addr.Network, Address: dialAddr}, hostAddr, upstream.Host)
err = h.doActiveHealthCheck(DialInfo{Network: addr.Network, Address: dialAddr}, hostAddr, upstream)
if err != nil {
h.HealthChecks.Active.logger.Error("active health check failed",
zap.String("address", hostAddr),
@ -206,14 +228,14 @@ func (h *Handler) doActiveHealthCheckForAllHosts() {
}
}
// doActiveHealthCheck performs a health check to host which
// doActiveHealthCheck performs a health check to upstream which
// can be reached at address hostAddr. The actual address for
// the request will be built according to active health checker
// config. The health status of the host will be updated
// according to whether it passes the health check. An error is
// returned only if the health check fails to occur or if marking
// the host's health status fails.
func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, host Host) error {
func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, upstream *Upstream) error {
// create the URL for the request that acts as a health check
scheme := "http"
if ht, ok := h.Transport.(TLSTransport); ok && ht.TLSEnabled() {
@ -269,10 +291,7 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, host H
zap.String("host", hostAddr),
zap.Error(err),
)
_, err2 := host.SetHealthy(false)
if err2 != nil {
return fmt.Errorf("marking unhealthy: %v", err2)
}
upstream.setHealthy(false)
return nil
}
var body io.Reader = resp.Body
@ -292,10 +311,7 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, host H
zap.Int("status_code", resp.StatusCode),
zap.String("host", hostAddr),
)
_, err := host.SetHealthy(false)
if err != nil {
return fmt.Errorf("marking unhealthy: %v", err)
}
upstream.setHealthy(false)
return nil
}
} else if resp.StatusCode < 200 || resp.StatusCode >= 400 {
@ -303,10 +319,7 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, host H
zap.Int("status_code", resp.StatusCode),
zap.String("host", hostAddr),
)
_, err := host.SetHealthy(false)
if err != nil {
return fmt.Errorf("marking unhealthy: %v", err)
}
upstream.setHealthy(false)
return nil
}
@ -318,33 +331,21 @@ func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, host H
zap.String("host", hostAddr),
zap.Error(err),
)
_, err := host.SetHealthy(false)
if err != nil {
return fmt.Errorf("marking unhealthy: %v", err)
}
upstream.setHealthy(false)
return nil
}
if !h.HealthChecks.Active.bodyRegexp.Match(bodyBytes) {
h.HealthChecks.Active.logger.Info("response body failed expectations",
zap.String("host", hostAddr),
)
_, err := host.SetHealthy(false)
if err != nil {
return fmt.Errorf("marking unhealthy: %v", err)
}
upstream.setHealthy(false)
return nil
}
}
// passed health check parameters, so mark as healthy
swapped, err := host.SetHealthy(true)
if swapped {
h.HealthChecks.Active.logger.Info("host is up",
zap.String("host", hostAddr),
)
}
if err != nil {
return fmt.Errorf("marking healthy: %v", err)
if upstream.setHealthy(true) {
h.HealthChecks.Active.logger.Info("host is up", zap.String("host", hostAddr))
}
return nil
@ -366,7 +367,7 @@ func (h *Handler) countFailure(upstream *Upstream) {
}
// count failure immediately
err := upstream.Host.CountFail(1)
err := upstream.Host.countFail(1)
if err != nil {
h.HealthChecks.Passive.logger.Error("could not count failure",
zap.String("host", upstream.Dial),
@ -375,14 +376,23 @@ func (h *Handler) countFailure(upstream *Upstream) {
}
// forget it later
go func(host Host, failDuration time.Duration) {
go func(host *Host, failDuration time.Duration) {
defer func() {
if err := recover(); err != nil {
log.Printf("[PANIC] health check failure forgetter: %v\n%s", err, debug.Stack())
h.HealthChecks.Passive.logger.Error("passive health check failure forgetter panicked",
zap.Any("error", err),
zap.ByteString("stack", debug.Stack()))
}
}()
time.Sleep(failDuration)
err := host.CountFail(-1)
timer := time.NewTimer(failDuration)
select {
case <-h.ctx.Done():
if !timer.Stop() {
<-timer.C
}
case <-timer.C:
}
err := host.countFail(-1)
if err != nil {
h.HealthChecks.Passive.logger.Error("could not forget failure",
zap.String("host", upstream.Dial),

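One detail worth noting in the hunk above: the failure forgetter no longer sleeps unconditionally; it waits on a timer or on the handler's context, and it drains the timer channel when Stop reports the timer has already fired. A standalone sketch of that wait pattern follows; the package and function names are illustrative.

package healthx

import (
	"context"
	"time"
)

// sleepCtx waits for d to elapse or for ctx to be cancelled, whichever comes
// first, and reports whether the full duration elapsed.
func sleepCtx(ctx context.Context, d time.Duration) bool {
	timer := time.NewTimer(d)
	select {
	case <-ctx.Done():
		// Stop returns false if the timer already fired; drain the channel
		// in that case so the timer can be garbage-collected promptly.
		if !timer.Stop() {
			<-timer.C
		}
		return false
	case <-timer.C:
		return true
	}
}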

@ -26,44 +26,14 @@ import (
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
)
// Host represents a remote host which can be proxied to.
// Its methods must be safe for concurrent use.
type Host interface {
// NumRequests returns the number of requests
// currently in process with the host.
NumRequests() int
// Fails returns the count of recent failures.
Fails() int
// Unhealthy returns true if the backend is unhealthy.
Unhealthy() bool
// CountRequest atomically counts the given number of
// requests as currently in process with the host. The
// count should not go below 0.
CountRequest(int) error
// CountFail atomically counts the given number of
// failures with the host. The count should not go
// below 0.
CountFail(int) error
// SetHealthy atomically marks the host as either
// healthy (true) or unhealthy (false). If the given
// status is the same, this should be a no-op and
// return false. It returns true if the status was
// changed; i.e. if it is now different from before.
SetHealthy(bool) (bool, error)
}
// UpstreamPool is a collection of upstreams.
type UpstreamPool []*Upstream
// Upstream bridges this proxy's configuration to the
// state of the backend host it is correlated with.
// Upstream values must not be copied.
type Upstream struct {
Host `json:"-"`
*Host `json:"-"`
// The [network address](/docs/conventions#network-addresses)
// to dial to connect to the upstream. Must represent precisely
@ -77,6 +47,10 @@ type Upstream struct {
// backends is down. Also be aware of open proxy vulnerabilities.
Dial string `json:"dial,omitempty"`
// DEPRECATED: Use the SRVUpstreams module instead
// (http.reverse_proxy.upstreams.srv). This field will be
// removed in a future version of Caddy. TODO: Remove this field.
//
// If DNS SRV records are used for service discovery with this
// upstream, specify the DNS name for which to look up SRV
// records here, instead of specifying a dial address.
@ -95,6 +69,7 @@ type Upstream struct {
activeHealthCheckPort int
healthCheckPolicy *PassiveHealthChecks
cb CircuitBreaker
unhealthy int32 // accessed atomically; status from active health checker
}
func (u Upstream) String() string {
@ -117,7 +92,7 @@ func (u *Upstream) Available() bool {
// is currently known to be healthy or "up".
// It consults the circuit breaker, if any.
func (u *Upstream) Healthy() bool {
healthy := !u.Host.Unhealthy()
healthy := u.healthy()
if healthy && u.healthCheckPolicy != nil {
healthy = u.Host.Fails() < u.healthCheckPolicy.MaxFails
}
@ -142,7 +117,7 @@ func (u *Upstream) fillDialInfo(r *http.Request) (DialInfo, error) {
var addr caddy.NetworkAddress
if u.LookupSRV != "" {
// perform DNS lookup for SRV records and choose one
// perform DNS lookup for SRV records and choose one - TODO: deprecated
srvName := repl.ReplaceAll(u.LookupSRV, "")
_, records, err := net.DefaultResolver.LookupSRV(r.Context(), "", "", srvName)
if err != nil {
@ -174,59 +149,67 @@ func (u *Upstream) fillDialInfo(r *http.Request) (DialInfo, error) {
}, nil
}
// upstreamHost is the basic, in-memory representation
// of the state of a remote host. It implements the
// Host interface.
type upstreamHost struct {
func (u *Upstream) fillHost() {
host := new(Host)
existingHost, loaded := hosts.LoadOrStore(u.String(), host)
if loaded {
host = existingHost.(*Host)
}
u.Host = host
}
// Host is the basic, in-memory representation of the state of a remote host.
// Its fields are accessed atomically and Host values must not be copied.
type Host struct {
numRequests int64 // must be 64-bit aligned on 32-bit systems (see https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
fails int64
unhealthy int32
}
// NumRequests returns the number of active requests to the upstream.
func (uh *upstreamHost) NumRequests() int {
return int(atomic.LoadInt64(&uh.numRequests))
func (h *Host) NumRequests() int {
return int(atomic.LoadInt64(&h.numRequests))
}
// Fails returns the number of recent failures with the upstream.
func (uh *upstreamHost) Fails() int {
return int(atomic.LoadInt64(&uh.fails))
func (h *Host) Fails() int {
return int(atomic.LoadInt64(&h.fails))
}
// Unhealthy returns whether the upstream is unhealthy.
func (uh *upstreamHost) Unhealthy() bool {
return atomic.LoadInt32(&uh.unhealthy) == 1
}
// CountRequest mutates the active request count by
// countRequest mutates the active request count by
// delta. It returns an error if the adjustment fails.
func (uh *upstreamHost) CountRequest(delta int) error {
result := atomic.AddInt64(&uh.numRequests, int64(delta))
func (h *Host) countRequest(delta int) error {
result := atomic.AddInt64(&h.numRequests, int64(delta))
if result < 0 {
return fmt.Errorf("count below 0: %d", result)
}
return nil
}
// CountFail mutates the recent failures count by
// countFail mutates the recent failures count by
// delta. It returns an error if the adjustment fails.
func (uh *upstreamHost) CountFail(delta int) error {
result := atomic.AddInt64(&uh.fails, int64(delta))
func (h *Host) countFail(delta int) error {
result := atomic.AddInt64(&h.fails, int64(delta))
if result < 0 {
return fmt.Errorf("count below 0: %d", result)
}
return nil
}
// healthy returns true if the upstream is not actively marked as unhealthy.
// (This returns the status only from the "active" health checks.)
func (u *Upstream) healthy() bool {
return atomic.LoadInt32(&u.unhealthy) == 0
}
// SetHealthy sets the upstream as healthy or unhealthy
// and returns true if the new value is different.
func (uh *upstreamHost) SetHealthy(healthy bool) (bool, error) {
// and returns true if the new value is different. This
// sets the status only for the "active" health checks.
func (u *Upstream) setHealthy(healthy bool) bool {
var unhealthy, compare int32 = 1, 0
if healthy {
unhealthy, compare = 0, 1
}
swapped := atomic.CompareAndSwapInt32(&uh.unhealthy, compare, unhealthy)
return swapped, nil
return atomic.CompareAndSwapInt32(&u.unhealthy, compare, unhealthy)
}
// DialInfo contains information needed to dial a

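The hunk above replaces the old Host interface with a concrete Host struct whose state is shared between handlers through a package-level map keyed by the upstream's address (see fillHost and the `hosts` LoadOrStore call). A generic sketch of that pattern with illustrative names, separate from the package's actual code:

package hostpool

import (
	"sync"
	"sync/atomic"
)

// state mirrors the idea behind the Host struct above: per-backend counters
// that are accessed atomically and never copied.
type state struct {
	numRequests int64 // 64-bit fields first, for atomic alignment on 32-bit systems
	fails       int64
}

// shared deduplicates state by dial address, so every handler proxying to
// the same backend increments the same counters.
var shared sync.Map // map[string]*state

// forAddr returns the shared state for addr, creating it on first use.
func forAddr(addr string) *state {
	s := new(state)
	if existing, loaded := shared.LoadOrStore(addr, s); loaded {
		return existing.(*state)
	}
	return s
}

// countRequest adjusts the in-flight request counter by delta.
func (s *state) countRequest(delta int) {
	atomic.AddInt64(&s.numRequests, int64(delta))
}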

@ -168,15 +168,9 @@ func (h *HTTPTransport) NewTransport(ctx caddy.Context) (*http.Transport, error)
}
if h.Resolver != nil {
for _, v := range h.Resolver.Addresses {
addr, err := caddy.ParseNetworkAddress(v)
if err != nil {
return nil, err
}
if addr.PortRangeSize() != 1 {
return nil, fmt.Errorf("resolver address must have exactly one address; cannot call %v", addr)
}
h.Resolver.netAddrs = append(h.Resolver.netAddrs, addr)
err := h.Resolver.ParseAddresses()
if err != nil {
return nil, err
}
d := &net.Dialer{
Timeout: time.Duration(h.DialTimeout),
@ -406,18 +400,6 @@ func (t TLSConfig) MakeTLSClientConfig(ctx caddy.Context) (*tls.Config, error) {
return cfg, nil
}
// UpstreamResolver holds the set of addresses of DNS resolvers of
// upstream addresses
type UpstreamResolver struct {
// The addresses of DNS resolvers to use when looking up the addresses of proxy upstreams.
// It accepts [network addresses](/docs/conventions#network-addresses)
// with port range of only 1. If the host is an IP address, it will be dialed directly to resolve the upstream server.
// If the host is not an IP address, the addresses are resolved using the [name resolution convention](https://golang.org/pkg/net/#hdr-Name_Resolution) of the Go standard library.
// If the array contains more than 1 resolver address, one is chosen at random.
Addresses []string `json:"addresses,omitempty"`
netAddrs []caddy.NetworkAddress
}
// KeepAlive holds configuration pertaining to HTTP Keep-Alive.
type KeepAlive struct {
// Whether HTTP Keep-Alive is enabled. Default: true


@ -78,9 +78,20 @@ type Handler struct {
// up or down. Down backends will not be proxied to.
HealthChecks *HealthChecks `json:"health_checks,omitempty"`
// Upstreams is the list of backends to proxy to.
// Upstreams is the static list of backends to proxy to.
Upstreams UpstreamPool `json:"upstreams,omitempty"`
// A module for retrieving the list of upstreams dynamically. Dynamic
// upstreams are retrieved at every iteration of the proxy loop for
// each request (i.e. before every proxy attempt within every request).
// Active health checks do not work on dynamic upstreams, and passive
// health checks are only effective on dynamic upstreams if the proxy
// server is busy enough that concurrent requests to the same backends
// are continuous. Instead of health checks for dynamic upstreams, it
// is recommended that the dynamic upstream module only return available
// backends in the first place.
DynamicUpstreamsRaw json.RawMessage `json:"dynamic_upstreams,omitempty" caddy:"namespace=http.reverse_proxy.upstreams inline_key=source"`
// Adjusts how often to flush the response buffer. By default,
// no periodic flushing is done. A negative value disables
// response buffering, and flushes immediately after each
@ -137,8 +148,9 @@ type Handler struct {
// - `{http.reverse_proxy.header.*}` The headers from the response
HandleResponse []caddyhttp.ResponseHandler `json:"handle_response,omitempty"`
Transport http.RoundTripper `json:"-"`
CB CircuitBreaker `json:"-"`
Transport http.RoundTripper `json:"-"`
CB CircuitBreaker `json:"-"`
DynamicUpstreams UpstreamSource `json:"-"`
// Holds the parsed CIDR ranges from TrustedProxies
trustedProxies []*net.IPNet
@ -166,7 +178,7 @@ func (h *Handler) Provision(ctx caddy.Context) error {
h.ctx = ctx
h.logger = ctx.Logger(h)
// verify SRV compatibility
// verify SRV compatibility - TODO: LookupSRV deprecated; will be removed
for i, v := range h.Upstreams {
if v.LookupSRV == "" {
continue
@ -201,6 +213,13 @@ func (h *Handler) Provision(ctx caddy.Context) error {
}
h.CB = mod.(CircuitBreaker)
}
if h.DynamicUpstreamsRaw != nil {
mod, err := ctx.LoadModule(h, "DynamicUpstreamsRaw")
if err != nil {
return fmt.Errorf("loading upstream source module: %v", err)
}
h.DynamicUpstreams = mod.(UpstreamSource)
}
// parse trusted proxy CIDRs ahead of time
for _, str := range h.TrustedProxies {
@ -270,38 +289,8 @@ func (h *Handler) Provision(ctx caddy.Context) error {
}
// set up upstreams
for _, upstream := range h.Upstreams {
// create or get the host representation for this upstream
var host Host = new(upstreamHost)
existingHost, loaded := hosts.LoadOrStore(upstream.String(), host)
if loaded {
host = existingHost.(Host)
}
upstream.Host = host
// give it the circuit breaker, if any
upstream.cb = h.CB
// if the passive health checker has a non-zero UnhealthyRequestCount
// but the upstream has no MaxRequests set (they are the same thing,
// but the passive health checker is a default value for upstreams
// without MaxRequests), copy the value into this upstream, since the
// value in the upstream (MaxRequests) is what is used during
// availability checks
if h.HealthChecks != nil && h.HealthChecks.Passive != nil {
h.HealthChecks.Passive.logger = h.logger.Named("health_checker.passive")
if h.HealthChecks.Passive.UnhealthyRequestCount > 0 &&
upstream.MaxRequests == 0 {
upstream.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount
}
}
// upstreams need independent access to the passive
// health check policy because passive health checks
// run without access to h.
if h.HealthChecks != nil {
upstream.healthCheckPolicy = h.HealthChecks.Passive
}
for _, u := range h.Upstreams {
h.provisionUpstream(u)
}
if h.HealthChecks != nil {
@ -413,79 +402,127 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
repl.Set("http.reverse_proxy.duration", time.Since(start))
}()
// in the proxy loop, each iteration is an attempt to proxy the request,
// and because we may retry some number of times, carry over the error
// from previous tries because of the nuances of load balancing & retries
var proxyErr error
for {
// choose an available upstream
upstream := h.LoadBalancing.SelectionPolicy.Select(h.Upstreams, clonedReq, w)
if upstream == nil {
if proxyErr == nil {
proxyErr = fmt.Errorf("no upstreams available")
}
if !h.LoadBalancing.tryAgain(h.ctx, start, proxyErr, clonedReq) {
break
}
continue
}
// the dial address may vary per-request if placeholders are
// used, so perform those replacements here; the resulting
// DialInfo struct should have valid network address syntax
dialInfo, err := upstream.fillDialInfo(clonedReq)
if err != nil {
return statusError(fmt.Errorf("making dial info: %v", err))
}
// attach to the request information about how to dial the upstream;
// this is necessary because the information cannot be sufficiently
// or satisfactorily represented in a URL
caddyhttp.SetVar(r.Context(), dialInfoVarKey, dialInfo)
// set placeholders with information about this upstream
repl.Set("http.reverse_proxy.upstream.address", dialInfo.String())
repl.Set("http.reverse_proxy.upstream.hostport", dialInfo.Address)
repl.Set("http.reverse_proxy.upstream.host", dialInfo.Host)
repl.Set("http.reverse_proxy.upstream.port", dialInfo.Port)
repl.Set("http.reverse_proxy.upstream.requests", upstream.Host.NumRequests())
repl.Set("http.reverse_proxy.upstream.max_requests", upstream.MaxRequests)
repl.Set("http.reverse_proxy.upstream.fails", upstream.Host.Fails())
// mutate request headers according to this upstream;
// because we're in a retry loop, we have to copy
// headers (and the Host value) from the original
// so that each retry is identical to the first
if h.Headers != nil && h.Headers.Request != nil {
clonedReq.Header = make(http.Header)
copyHeader(clonedReq.Header, reqHeader)
clonedReq.Host = reqHost
h.Headers.Request.ApplyToRequest(clonedReq)
}
// proxy the request to that upstream
proxyErr = h.reverseProxy(w, clonedReq, repl, dialInfo, next)
if proxyErr == nil || proxyErr == context.Canceled {
// context.Canceled happens when the downstream client
// cancels the request, which is not our failure
return nil
}
// if the roundtrip was successful, don't retry the request or
// ding the health status of the upstream (an error can still
// occur after the roundtrip if, for example, a response handler
// after the roundtrip returns an error)
if succ, ok := proxyErr.(roundtripSucceeded); ok {
return succ.error
}
// remember this failure (if enabled)
h.countFailure(upstream)
// if we've tried long enough, break
if !h.LoadBalancing.tryAgain(h.ctx, start, proxyErr, clonedReq) {
var done bool
done, proxyErr = h.proxyLoopIteration(clonedReq, w, proxyErr, start, repl, reqHeader, reqHost, next)
if done {
break
}
}
return statusError(proxyErr)
if proxyErr != nil {
return statusError(proxyErr)
}
return nil
}
// proxyLoopIteration implements an iteration of the proxy loop. Despite the enormous amount of local state
// that has to be passed in, we brought this into its own method so that we could run defer more easily.
// It returns true when the loop is done and should break; false otherwise. The error value returned should
// be assigned to the proxyErr value for the next iteration of the loop (or the error handled after break).
func (h *Handler) proxyLoopIteration(r *http.Request, w http.ResponseWriter, proxyErr error, start time.Time,
repl *caddy.Replacer, reqHeader http.Header, reqHost string, next caddyhttp.Handler) (bool, error) {
// get the updated list of upstreams
upstreams := h.Upstreams
if h.DynamicUpstreams != nil {
dUpstreams, err := h.DynamicUpstreams.GetUpstreams(r)
if err != nil {
h.logger.Error("failed getting dynamic upstreams; falling back to static upstreams", zap.Error(err))
} else {
upstreams = dUpstreams
for _, dUp := range dUpstreams {
h.provisionUpstream(dUp)
}
h.logger.Debug("provisioned dynamic upstreams", zap.Int("count", len(dUpstreams)))
defer func() {
// these upstreams are dynamic, so they are only used for this iteration
// of the proxy loop; be sure to let them go away when we're done with them
for _, upstream := range dUpstreams {
_, _ = hosts.Delete(upstream.String())
}
}()
}
}
// choose an available upstream
upstream := h.LoadBalancing.SelectionPolicy.Select(upstreams, r, w)
if upstream == nil {
if proxyErr == nil {
proxyErr = fmt.Errorf("no upstreams available")
}
if !h.LoadBalancing.tryAgain(h.ctx, start, proxyErr, r) {
return true, proxyErr
}
return false, proxyErr
}
// the dial address may vary per-request if placeholders are
// used, so perform those replacements here; the resulting
// DialInfo struct should have valid network address syntax
dialInfo, err := upstream.fillDialInfo(r)
if err != nil {
return true, fmt.Errorf("making dial info: %v", err)
}
h.logger.Debug("selected upstream",
zap.String("dial", dialInfo.Address),
zap.Int("total_upstreams", len(upstreams)))
// attach to the request information about how to dial the upstream;
// this is necessary because the information cannot be sufficiently
// or satisfactorily represented in a URL
caddyhttp.SetVar(r.Context(), dialInfoVarKey, dialInfo)
// set placeholders with information about this upstream
repl.Set("http.reverse_proxy.upstream.address", dialInfo.String())
repl.Set("http.reverse_proxy.upstream.hostport", dialInfo.Address)
repl.Set("http.reverse_proxy.upstream.host", dialInfo.Host)
repl.Set("http.reverse_proxy.upstream.port", dialInfo.Port)
repl.Set("http.reverse_proxy.upstream.requests", upstream.Host.NumRequests())
repl.Set("http.reverse_proxy.upstream.max_requests", upstream.MaxRequests)
repl.Set("http.reverse_proxy.upstream.fails", upstream.Host.Fails())
// mutate request headers according to this upstream;
// because we're in a retry loop, we have to copy
// headers (and the r.Host value) from the original
// so that each retry is identical to the first
if h.Headers != nil && h.Headers.Request != nil {
r.Header = make(http.Header)
copyHeader(r.Header, reqHeader)
r.Host = reqHost
h.Headers.Request.ApplyToRequest(r)
}
// proxy the request to that upstream
proxyErr = h.reverseProxy(w, r, repl, dialInfo, next)
if proxyErr == nil || proxyErr == context.Canceled {
// context.Canceled happens when the downstream client
// cancels the request, which is not our failure
return true, nil
}
// if the roundtrip was successful, don't retry the request or
// ding the health status of the upstream (an error can still
// occur after the roundtrip if, for example, a response handler
// after the roundtrip returns an error)
if succ, ok := proxyErr.(roundtripSucceeded); ok {
return true, succ.error
}
// remember this failure (if enabled)
h.countFailure(upstream)
// if we've tried long enough, break
if !h.LoadBalancing.tryAgain(h.ctx, start, proxyErr, r) {
return true, proxyErr
}
return false, proxyErr
}
// prepareRequest clones req so that it can be safely modified without
@ -651,9 +688,9 @@ func (h Handler) addForwardedHeaders(req *http.Request) error {
// (This method is mostly the beginning of what was borrowed from the net/http/httputil package in the
// Go standard library which was used as the foundation.)
func (h *Handler) reverseProxy(rw http.ResponseWriter, req *http.Request, repl *caddy.Replacer, di DialInfo, next caddyhttp.Handler) error {
_ = di.Upstream.Host.CountRequest(1)
_ = di.Upstream.Host.countRequest(1)
//nolint:errcheck
defer di.Upstream.Host.CountRequest(-1)
defer di.Upstream.Host.countRequest(-1)
// point the request to this upstream
h.directRequest(req, di)
@ -905,6 +942,35 @@ func (Handler) directRequest(req *http.Request, di DialInfo) {
req.URL.Host = reqHost
}
func (h Handler) provisionUpstream(upstream *Upstream) {
// create or get the host representation for this upstream
upstream.fillHost()
// give it the circuit breaker, if any
upstream.cb = h.CB
// if the passive health checker has a non-zero UnhealthyRequestCount
// but the upstream has no MaxRequests set (they are the same thing,
// but the passive health checker is a default value for upstreams
// without MaxRequests), copy the value into this upstream, since the
// value in the upstream (MaxRequests) is what is used during
// availability checks
if h.HealthChecks != nil && h.HealthChecks.Passive != nil {
h.HealthChecks.Passive.logger = h.logger.Named("health_checker.passive")
if h.HealthChecks.Passive.UnhealthyRequestCount > 0 &&
upstream.MaxRequests == 0 {
upstream.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount
}
}
// upstreams need independent access to the passive
// health check policy because passive health checks
// run without access to h.
if h.HealthChecks != nil {
upstream.healthCheckPolicy = h.HealthChecks.Passive
}
}
// bufferedBody reads originalBody into a buffer, then returns a reader for the buffer.
// Always close the return value when done with it, just like if it was the original body!
func (h Handler) bufferedBody(originalBody io.ReadCloser) io.ReadCloser {
@ -1085,6 +1151,20 @@ type Selector interface {
Select(UpstreamPool, *http.Request, http.ResponseWriter) *Upstream
}
// UpstreamSource gets the list of upstreams that can be used when
// proxying a request. Returned upstreams will be load balanced and
// health-checked. This should be a very fast function -- instant
// if possible -- and the return value must be as stable as possible.
// In other words, the list of upstreams should ideally not change much
// across successive calls. If the list of upstreams changes or the
// ordering is not stable, load balancing will suffer. This function
// may be called during each retry, multiple times per request, and as
// such, needs to be instantaneous. The returned slice will not be
// modified.
type UpstreamSource interface {
GetUpstreams(*http.Request) ([]*Upstream, error)
}
// Hop-by-hop headers. These are removed when sent to the backend.
// As of RFC 7230, hop-by-hop headers are required to appear in the
// Connection header field. These are the headers defined by the

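A note on the proxyLoopIteration refactor above: a defer placed directly inside ServeHTTP's for loop would only run once, when ServeHTTP returns, but the dynamic-upstream cleanup (deleting the per-request hosts entries) has to happen after every attempt. Extracting the body into its own method gives each iteration its own defer scope. A generic sketch of that shape, with illustrative names:

package proxyish

// retryLoop keeps only the loop control; the body lives in attempt so that
// its defers run at the end of each attempt rather than when the loop ends.
func retryLoop(attempts int) error {
	var lastErr error
	for i := 0; i < attempts; i++ {
		done, err := attempt(i)
		lastErr = err
		if done {
			break
		}
	}
	return lastErr
}

func attempt(i int) (done bool, err error) {
	cleanup := setUp(i)
	// Fires when this attempt returns; a defer in the caller's loop body
	// would not fire until the caller itself returned.
	defer cleanup()
	// ... do the work; return (true, nil) on success, (false, err) to
	// retry, or (true, err) to give up and surface the error.
	return true, nil
}

func setUp(int) func() { return func() {} }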

@ -22,9 +22,9 @@ import (
func testPool() UpstreamPool {
return UpstreamPool{
{Host: new(upstreamHost)},
{Host: new(upstreamHost)},
{Host: new(upstreamHost)},
{Host: new(Host)},
{Host: new(Host)},
{Host: new(Host)},
}
}
@ -48,20 +48,20 @@ func TestRoundRobinPolicy(t *testing.T) {
t.Error("Expected third round robin host to be first host in the pool.")
}
// mark host as down
pool[1].SetHealthy(false)
pool[1].setHealthy(false)
h = rrPolicy.Select(pool, req, nil)
if h != pool[2] {
t.Error("Expected to skip down host.")
}
// mark host as up
pool[1].SetHealthy(true)
pool[1].setHealthy(true)
h = rrPolicy.Select(pool, req, nil)
if h == pool[2] {
t.Error("Expected to balance evenly among healthy hosts")
}
// mark host as full
pool[1].CountRequest(1)
pool[1].countRequest(1)
pool[1].MaxRequests = 1
h = rrPolicy.Select(pool, req, nil)
if h != pool[2] {
@ -74,13 +74,13 @@ func TestLeastConnPolicy(t *testing.T) {
lcPolicy := new(LeastConnSelection)
req, _ := http.NewRequest("GET", "/", nil)
pool[0].CountRequest(10)
pool[1].CountRequest(10)
pool[0].countRequest(10)
pool[1].countRequest(10)
h := lcPolicy.Select(pool, req, nil)
if h != pool[2] {
t.Error("Expected least connection host to be third host.")
}
pool[2].CountRequest(100)
pool[2].countRequest(100)
h = lcPolicy.Select(pool, req, nil)
if h != pool[0] && h != pool[1] {
t.Error("Expected least connection host to be first or second host.")
@ -139,7 +139,7 @@ func TestIPHashPolicy(t *testing.T) {
// we should get a healthy host if the original host is unhealthy and a
// healthy host is available
req.RemoteAddr = "172.0.0.1"
pool[1].SetHealthy(false)
pool[1].setHealthy(false)
h = ipHash.Select(pool, req, nil)
if h != pool[2] {
t.Error("Expected ip hash policy host to be the third host.")
@ -150,10 +150,10 @@ func TestIPHashPolicy(t *testing.T) {
if h != pool[2] {
t.Error("Expected ip hash policy host to be the third host.")
}
pool[1].SetHealthy(true)
pool[1].setHealthy(true)
req.RemoteAddr = "172.0.0.3"
pool[2].SetHealthy(false)
pool[2].setHealthy(false)
h = ipHash.Select(pool, req, nil)
if h != pool[0] {
t.Error("Expected ip hash policy host to be the first host.")
@ -167,8 +167,8 @@ func TestIPHashPolicy(t *testing.T) {
// We should be able to resize the host pool and still be able to predict
// where a req will be routed with the same IP's used above
pool = UpstreamPool{
{Host: new(upstreamHost)},
{Host: new(upstreamHost)},
{Host: new(Host)},
{Host: new(Host)},
}
req.RemoteAddr = "172.0.0.1:80"
h = ipHash.Select(pool, req, nil)
@ -192,8 +192,8 @@ func TestIPHashPolicy(t *testing.T) {
}
// We should get nil when there are no healthy hosts
pool[0].SetHealthy(false)
pool[1].SetHealthy(false)
pool[0].setHealthy(false)
pool[1].setHealthy(false)
h = ipHash.Select(pool, req, nil)
if h != nil {
t.Error("Expected ip hash policy host to be nil.")
@ -201,25 +201,25 @@ func TestIPHashPolicy(t *testing.T) {
// Reproduce #4135
pool = UpstreamPool{
{Host: new(upstreamHost)},
{Host: new(upstreamHost)},
{Host: new(upstreamHost)},
{Host: new(upstreamHost)},
{Host: new(upstreamHost)},
{Host: new(upstreamHost)},
{Host: new(upstreamHost)},
{Host: new(upstreamHost)},
{Host: new(upstreamHost)},
{Host: new(Host)},
{Host: new(Host)},
{Host: new(Host)},
{Host: new(Host)},
{Host: new(Host)},
{Host: new(Host)},
{Host: new(Host)},
{Host: new(Host)},
{Host: new(Host)},
}
pool[0].SetHealthy(false)
pool[1].SetHealthy(false)
pool[2].SetHealthy(false)
pool[3].SetHealthy(false)
pool[4].SetHealthy(false)
pool[5].SetHealthy(false)
pool[6].SetHealthy(false)
pool[7].SetHealthy(false)
pool[8].SetHealthy(true)
pool[0].setHealthy(false)
pool[1].setHealthy(false)
pool[2].setHealthy(false)
pool[3].setHealthy(false)
pool[4].setHealthy(false)
pool[5].setHealthy(false)
pool[6].setHealthy(false)
pool[7].setHealthy(false)
pool[8].setHealthy(true)
// We should get a result back when there is one healthy host left.
h = ipHash.Select(pool, req, nil)
@ -239,7 +239,7 @@ func TestFirstPolicy(t *testing.T) {
t.Error("Expected first policy host to be the first host.")
}
pool[0].SetHealthy(false)
pool[0].setHealthy(false)
h = firstPolicy.Select(pool, req, nil)
if h != pool[1] {
t.Error("Expected first policy host to be the second host.")
@ -256,7 +256,7 @@ func TestURIHashPolicy(t *testing.T) {
t.Error("Expected uri policy host to be the first host.")
}
pool[0].SetHealthy(false)
pool[0].setHealthy(false)
h = uriPolicy.Select(pool, request, nil)
if h != pool[1] {
t.Error("Expected uri policy host to be the first host.")
@ -271,8 +271,8 @@ func TestURIHashPolicy(t *testing.T) {
// We should be able to resize the host pool and still be able to predict
// where a request will be routed with the same URI's used above
pool = UpstreamPool{
{Host: new(upstreamHost)},
{Host: new(upstreamHost)},
{Host: new(Host)},
{Host: new(Host)},
}
request = httptest.NewRequest(http.MethodGet, "/test", nil)
@ -281,7 +281,7 @@ func TestURIHashPolicy(t *testing.T) {
t.Error("Expected uri policy host to be the first host.")
}
pool[0].SetHealthy(false)
pool[0].setHealthy(false)
h = uriPolicy.Select(pool, request, nil)
if h != pool[1] {
t.Error("Expected uri policy host to be the first host.")
@ -293,8 +293,8 @@ func TestURIHashPolicy(t *testing.T) {
t.Error("Expected uri policy host to be the second host.")
}
pool[0].SetHealthy(false)
pool[1].SetHealthy(false)
pool[0].setHealthy(false)
pool[1].setHealthy(false)
h = uriPolicy.Select(pool, request, nil)
if h != nil {
t.Error("Expected uri policy policy host to be nil.")
@ -306,12 +306,12 @@ func TestLeastRequests(t *testing.T) {
pool[0].Dial = "localhost:8080"
pool[1].Dial = "localhost:8081"
pool[2].Dial = "localhost:8082"
pool[0].SetHealthy(true)
pool[1].SetHealthy(true)
pool[2].SetHealthy(true)
pool[0].CountRequest(10)
pool[1].CountRequest(20)
pool[2].CountRequest(30)
pool[0].setHealthy(true)
pool[1].setHealthy(true)
pool[2].setHealthy(true)
pool[0].countRequest(10)
pool[1].countRequest(20)
pool[2].countRequest(30)
result := leastRequests(pool)
@ -329,12 +329,12 @@ func TestRandomChoicePolicy(t *testing.T) {
pool[0].Dial = "localhost:8080"
pool[1].Dial = "localhost:8081"
pool[2].Dial = "localhost:8082"
pool[0].SetHealthy(false)
pool[1].SetHealthy(true)
pool[2].SetHealthy(true)
pool[0].CountRequest(10)
pool[1].CountRequest(20)
pool[2].CountRequest(30)
pool[0].setHealthy(false)
pool[1].setHealthy(true)
pool[2].setHealthy(true)
pool[0].countRequest(10)
pool[1].countRequest(20)
pool[2].countRequest(30)
request := httptest.NewRequest(http.MethodGet, "/test", nil)
randomChoicePolicy := new(RandomChoiceSelection)
@ -357,9 +357,9 @@ func TestCookieHashPolicy(t *testing.T) {
pool[0].Dial = "localhost:8080"
pool[1].Dial = "localhost:8081"
pool[2].Dial = "localhost:8082"
pool[0].SetHealthy(true)
pool[1].SetHealthy(false)
pool[2].SetHealthy(false)
pool[0].setHealthy(true)
pool[1].setHealthy(false)
pool[2].setHealthy(false)
request := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
cookieHashPolicy := new(CookieHashSelection)
@ -374,8 +374,8 @@ func TestCookieHashPolicy(t *testing.T) {
if h != pool[0] {
t.Error("Expected cookieHashPolicy host to be the first only available host.")
}
pool[1].SetHealthy(true)
pool[2].SetHealthy(true)
pool[1].setHealthy(true)
pool[2].setHealthy(true)
request = httptest.NewRequest(http.MethodGet, "/test", nil)
w = httptest.NewRecorder()
request.AddCookie(cookieServer1)
@ -387,7 +387,7 @@ func TestCookieHashPolicy(t *testing.T) {
if len(s) != 0 {
t.Error("Expected cookieHashPolicy to not set a new cookie.")
}
pool[0].SetHealthy(false)
pool[0].setHealthy(false)
request = httptest.NewRequest(http.MethodGet, "/test", nil)
w = httptest.NewRecorder()
request.AddCookie(cookieServer1)


@ -0,0 +1,377 @@
package reverseproxy
import (
"context"
"fmt"
weakrand "math/rand"
"net"
"net/http"
"strconv"
"sync"
"time"
"github.com/caddyserver/caddy/v2"
"go.uber.org/zap"
)
func init() {
caddy.RegisterModule(SRVUpstreams{})
caddy.RegisterModule(AUpstreams{})
}
// SRVUpstreams provides upstreams from SRV lookups.
// The lookup DNS name can be configured either by
// its individual parts (that is, specifying the
// service, protocol, and name separately) to form
// the standard "_service._proto.name" domain, or
// the domain can be specified directly in name by
// leaving service and proto empty. See RFC 2782.
//
// Lookups are cached and refreshed at the configured
// refresh interval.
//
// Returned upstreams are sorted by priority and weight.
type SRVUpstreams struct {
// The service label.
Service string `json:"service,omitempty"`
// The protocol label; either tcp or udp.
Proto string `json:"proto,omitempty"`
// The name label; or, if service and proto are
// empty, the entire domain name to look up.
Name string `json:"name,omitempty"`
// The interval at which to refresh the SRV lookup.
// Results are cached between lookups. Default: 1m
Refresh caddy.Duration `json:"refresh,omitempty"`
// Configures the DNS resolver used to resolve the
// SRV address to SRV records.
Resolver *UpstreamResolver `json:"resolver,omitempty"`
// If Resolver is configured, how long to wait before
// timing out trying to connect to the DNS server.
DialTimeout caddy.Duration `json:"dial_timeout,omitempty"`
// If Resolver is configured, how long to wait before
// spawning an RFC 6555 Fast Fallback connection.
// A negative value disables this.
FallbackDelay caddy.Duration `json:"dial_fallback_delay,omitempty"`
resolver *net.Resolver
logger *zap.Logger
}
// CaddyModule returns the Caddy module information.
func (SRVUpstreams) CaddyModule() caddy.ModuleInfo {
return caddy.ModuleInfo{
ID: "http.reverse_proxy.upstreams.srv",
New: func() caddy.Module { return new(SRVUpstreams) },
}
}
// String returns the RFC 2782 representation of the SRV domain.
func (su SRVUpstreams) String() string {
return fmt.Sprintf("_%s._%s.%s", su.Service, su.Proto, su.Name)
}
func (su *SRVUpstreams) Provision(ctx caddy.Context) error {
su.logger = ctx.Logger(su)
if su.Refresh == 0 {
su.Refresh = caddy.Duration(time.Minute)
}
if su.Resolver != nil {
err := su.Resolver.ParseAddresses()
if err != nil {
return err
}
d := &net.Dialer{
Timeout: time.Duration(su.DialTimeout),
FallbackDelay: time.Duration(su.FallbackDelay),
}
su.resolver = &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
//nolint:gosec
addr := su.Resolver.netAddrs[weakrand.Intn(len(su.Resolver.netAddrs))]
return d.DialContext(ctx, addr.Network, addr.JoinHostPort(0))
},
}
}
if su.resolver == nil {
su.resolver = net.DefaultResolver
}
return nil
}
func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
suStr := su.String()
// first, use a cheap read-lock to return a cached result quickly
srvsMu.RLock()
cached := srvs[suStr]
srvsMu.RUnlock()
if cached.isFresh() {
return cached.upstreams, nil
}
// otherwise, obtain a write-lock to update the cached value
srvsMu.Lock()
defer srvsMu.Unlock()
// check to see if it's still stale, since we're now in a different
// lock from when we first checked freshness; another goroutine might
// have refreshed it in the meantime before we re-obtained our lock
cached = srvs[suStr]
if cached.isFresh() {
return cached.upstreams, nil
}
// prepare parameters and perform the SRV lookup
repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
service := repl.ReplaceAll(su.Service, "")
proto := repl.ReplaceAll(su.Proto, "")
name := repl.ReplaceAll(su.Name, "")
su.logger.Debug("refreshing SRV upstreams",
zap.String("service", service),
zap.String("proto", proto),
zap.String("name", name))
_, records, err := su.resolver.LookupSRV(r.Context(), service, proto, name)
if err != nil {
// From LookupSRV docs: "If the response contains invalid names, those records are filtered
// out and an error will be returned alongside the remaining results, if any." Thus, we
// only return an error if no records were also returned.
if len(records) == 0 {
return nil, err
}
su.logger.Warn("SRV records filtered", zap.Error(err))
}
upstreams := make([]*Upstream, len(records))
for i, rec := range records {
su.logger.Debug("discovered SRV record",
zap.String("target", rec.Target),
zap.Uint16("port", rec.Port),
zap.Uint16("priority", rec.Priority),
zap.Uint16("weight", rec.Weight))
addr := net.JoinHostPort(rec.Target, strconv.Itoa(int(rec.Port)))
upstreams[i] = &Upstream{Dial: addr}
}
// before adding a new one to the cache (as opposed to replacing a stale one), make room if cache is full
if cached.freshness.IsZero() && len(srvs) >= 100 {
for randomKey := range srvs {
delete(srvs, randomKey)
break
}
}
srvs[suStr] = srvLookup{
srvUpstreams: su,
freshness: time.Now(),
upstreams: upstreams,
}
return upstreams, nil
}
type srvLookup struct {
srvUpstreams SRVUpstreams
freshness time.Time
upstreams []*Upstream
}
func (sl srvLookup) isFresh() bool {
return time.Since(sl.freshness) < time.Duration(sl.srvUpstreams.Refresh)
}
var (
srvs = make(map[string]srvLookup)
srvsMu sync.RWMutex
)
// AUpstreams provides upstreams from A/AAAA lookups.
// Results are cached and refreshed at the configured
// refresh interval.
type AUpstreams struct {
// The domain name to look up.
Name string `json:"name,omitempty"`
// The port to use with the upstreams. Default: 80
Port string `json:"port,omitempty"`
// The interval at which to refresh the A lookup.
// Results are cached between lookups. Default: 1m
Refresh caddy.Duration `json:"refresh,omitempty"`
// Configures the DNS resolver used to resolve the
// domain name to A records.
Resolver *UpstreamResolver `json:"resolver,omitempty"`
// If Resolver is configured, how long to wait before
// timing out trying to connect to the DNS server.
DialTimeout caddy.Duration `json:"dial_timeout,omitempty"`
// If Resolver is configured, how long to wait before
// spawning an RFC 6555 Fast Fallback connection.
// A negative value disables this.
FallbackDelay caddy.Duration `json:"dial_fallback_delay,omitempty"`
resolver *net.Resolver
}
// CaddyModule returns the Caddy module information.
func (AUpstreams) CaddyModule() caddy.ModuleInfo {
return caddy.ModuleInfo{
ID: "http.reverse_proxy.upstreams.a",
New: func() caddy.Module { return new(AUpstreams) },
}
}
func (au AUpstreams) String() string { return au.Name }
func (au *AUpstreams) Provision(_ caddy.Context) error {
if au.Refresh == 0 {
au.Refresh = caddy.Duration(time.Minute)
}
if au.Port == "" {
au.Port = "80"
}
if au.Resolver != nil {
err := au.Resolver.ParseAddresses()
if err != nil {
return err
}
d := &net.Dialer{
Timeout: time.Duration(au.DialTimeout),
FallbackDelay: time.Duration(au.FallbackDelay),
}
au.resolver = &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
//nolint:gosec
addr := au.Resolver.netAddrs[weakrand.Intn(len(au.Resolver.netAddrs))]
return d.DialContext(ctx, addr.Network, addr.JoinHostPort(0))
},
}
}
if au.resolver == nil {
au.resolver = net.DefaultResolver
}
return nil
}
func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
auStr := au.String()
// first, use a cheap read-lock to return a cached result quickly
aAaaaMu.RLock()
cached := aAaaa[auStr]
aAaaaMu.RUnlock()
if cached.isFresh() {
return cached.upstreams, nil
}
// otherwise, obtain a write-lock to update the cached value
aAaaaMu.Lock()
defer aAaaaMu.Unlock()
// check to see if it's still stale, since we're now in a different
// lock from when we first checked freshness; another goroutine might
// have refreshed it in the meantime before we re-obtained our lock
cached = aAaaa[auStr]
if cached.isFresh() {
return cached.upstreams, nil
}
repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
name := repl.ReplaceAll(au.Name, "")
port := repl.ReplaceAll(au.Port, "")
ips, err := au.resolver.LookupIPAddr(r.Context(), name)
if err != nil {
return nil, err
}
upstreams := make([]*Upstream, len(ips))
for i, ip := range ips {
upstreams[i] = &Upstream{
Dial: net.JoinHostPort(ip.String(), port),
}
}
// before adding a new one to the cache (as opposed to replacing a stale one), make room if cache is full
if cached.freshness.IsZero() && len(aAaaa) >= 100 {
for randomKey := range aAaaa {
delete(aAaaa, randomKey)
break
}
}
aAaaa[auStr] = aLookup{
aUpstreams: au,
freshness: time.Now(),
upstreams: upstreams,
}
return upstreams, nil
}
type aLookup struct {
aUpstreams AUpstreams
freshness time.Time
upstreams []*Upstream
}
func (al aLookup) isFresh() bool {
return time.Since(al.freshness) < time.Duration(al.aUpstreams.Refresh)
}
// UpstreamResolver holds the set of addresses of DNS resolvers of
// upstream addresses
type UpstreamResolver struct {
// The addresses of DNS resolvers to use when looking up the addresses of proxy upstreams.
// It accepts [network addresses](/docs/conventions#network-addresses)
// with port range of only 1. If the host is an IP address, it will be dialed directly to resolve the upstream server.
// If the host is not an IP address, the addresses are resolved using the [name resolution convention](https://golang.org/pkg/net/#hdr-Name_Resolution) of the Go standard library.
// If the array contains more than 1 resolver address, one is chosen at random.
Addresses []string `json:"addresses,omitempty"`
netAddrs []caddy.NetworkAddress
}
// ParseAddresses parses all the configured network addresses
// and ensures they're ready to be used.
func (u *UpstreamResolver) ParseAddresses() error {
for _, v := range u.Addresses {
addr, err := caddy.ParseNetworkAddress(v)
if err != nil {
return err
}
if addr.PortRangeSize() != 1 {
return fmt.Errorf("resolver address must have exactly one address; cannot call %v", addr)
}
u.netAddrs = append(u.netAddrs, addr)
}
return nil
}
var (
aAaaa = make(map[string]aLookup)
aAaaaMu sync.RWMutex
)
// Interface guards
var (
_ caddy.Provisioner = (*SRVUpstreams)(nil)
_ UpstreamSource = (*SRVUpstreams)(nil)
_ caddy.Provisioner = (*AUpstreams)(nil)
_ UpstreamSource = (*AUpstreams)(nil)
)
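Both GetUpstreams implementations above use the same caching discipline: a cheap read-lock fast path, then a write lock with a second freshness check so concurrent requests don't all refresh at once, plus a size cap that evicts an arbitrary entry when the cache is full. A generic sketch of the read/recheck/refresh portion, with illustrative names:

package dnscachelike

import (
	"sync"
	"time"
)

type entry struct {
	values    []string
	freshness time.Time
}

func (e entry) isFresh(ttl time.Duration) bool {
	return time.Since(e.freshness) < ttl
}

var (
	cache   = make(map[string]entry)
	cacheMu sync.RWMutex
)

// lookup returns the cached values for key if still fresh; otherwise it
// refreshes them while holding the write lock and stores the result.
func lookup(key string, ttl time.Duration, refresh func() []string) []string {
	// fast path: cheap read lock
	cacheMu.RLock()
	cached := cache[key]
	cacheMu.RUnlock()
	if cached.isFresh(ttl) {
		return cached.values
	}

	// slow path: re-check freshness under the write lock, since another
	// goroutine may have refreshed the entry between the two lock acquisitions
	cacheMu.Lock()
	defer cacheMu.Unlock()
	if cached = cache[key]; cached.isFresh(ttl) {
		return cached.values
	}
	vals := refresh()
	cache[key] = entry{values: vals, freshness: time.Now()}
	return vals
}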