Reconcile upstream dial addresses and request host/URL information

My goodness that was complicated

Blessed be request.Context

Sort of
This commit is contained in:
Matthew Holt 2019-09-05 13:14:39 -06:00
parent a60d54dbfd
commit 0830fbad03
No known key found for this signature in database
GPG key ID: 2A349DD577D586A5
9 changed files with 237 additions and 183 deletions

View file

@ -165,19 +165,19 @@ var (
listenersMu sync.Mutex listenersMu sync.Mutex
) )
// ParseListenAddr parses addr, a string of the form "network/host:port" // ParseNetworkAddress parses addr, a string of the form "network/host:port"
// (with any part optional) into its component parts. Because a port can // (with any part optional) into its component parts. Because a port can
// also be a port range, there may be multiple addresses returned. // also be a port range, there may be multiple addresses returned.
func ParseListenAddr(addr string) (network string, addrs []string, err error) { func ParseNetworkAddress(addr string) (network string, addrs []string, err error) {
var host, port string var host, port string
network, host, port, err = SplitListenAddr(addr) network, host, port, err = SplitNetworkAddress(addr)
if network == "" { if network == "" {
network = "tcp" network = "tcp"
} }
if err != nil { if err != nil {
return return
} }
if network == "unix" { if network == "unix" || network == "unixgram" || network == "unixpacket" {
addrs = []string{host} addrs = []string{host}
return return
} }
@ -204,14 +204,14 @@ func ParseListenAddr(addr string) (network string, addrs []string, err error) {
return return
} }
// SplitListenAddr splits a into its network, host, and port components. // SplitNetworkAddress splits a into its network, host, and port components.
// Note that port may be a port range, or omitted for unix sockets. // Note that port may be a port range, or omitted for unix sockets.
func SplitListenAddr(a string) (network, host, port string, err error) { func SplitNetworkAddress(a string) (network, host, port string, err error) {
if idx := strings.Index(a, "/"); idx >= 0 { if idx := strings.Index(a, "/"); idx >= 0 {
network = strings.ToLower(strings.TrimSpace(a[:idx])) network = strings.ToLower(strings.TrimSpace(a[:idx]))
a = a[idx+1:] a = a[idx+1:]
} }
if network == "unix" { if network == "unix" || network == "unixgram" || network == "unixpacket" {
host = a host = a
return return
} }
@ -219,11 +219,11 @@ func SplitListenAddr(a string) (network, host, port string, err error) {
return return
} }
// JoinListenAddr combines network, host, and port into a single // JoinNetworkAddress combines network, host, and port into a single
// address string of the form "network/host:port". Port may be a // address string of the form "network/host:port". Port may be a
// port range. For unix sockets, the network should be "unix" and // port range. For unix sockets, the network should be "unix" and
// the path to the socket should be given in the host argument. // the path to the socket should be given in the host argument.
func JoinListenAddr(network, host, port string) string { func JoinNetworkAddress(network, host, port string) string {
var a string var a string
if network != "" { if network != "" {
a = network + "/" a = network + "/"

View file

@ -19,7 +19,7 @@ import (
"testing" "testing"
) )
func TestSplitListenerAddr(t *testing.T) { func TestSplitNetworkAddress(t *testing.T) {
for i, tc := range []struct { for i, tc := range []struct {
input string input string
expectNetwork string expectNetwork string
@ -67,8 +67,18 @@ func TestSplitListenerAddr(t *testing.T) {
expectNetwork: "unix", expectNetwork: "unix",
expectHost: "/foo/bar", expectHost: "/foo/bar",
}, },
{
input: "unixgram//foo/bar",
expectNetwork: "unixgram",
expectHost: "/foo/bar",
},
{
input: "unixpacket//foo/bar",
expectNetwork: "unixpacket",
expectHost: "/foo/bar",
},
} { } {
actualNetwork, actualHost, actualPort, err := SplitListenAddr(tc.input) actualNetwork, actualHost, actualPort, err := SplitNetworkAddress(tc.input)
if tc.expectErr && err == nil { if tc.expectErr && err == nil {
t.Errorf("Test %d: Expected error but got: %v", i, err) t.Errorf("Test %d: Expected error but got: %v", i, err)
} }
@ -87,7 +97,7 @@ func TestSplitListenerAddr(t *testing.T) {
} }
} }
func TestJoinListenerAddr(t *testing.T) { func TestJoinNetworkAddress(t *testing.T) {
for i, tc := range []struct { for i, tc := range []struct {
network, host, port string network, host, port string
expect string expect string
@ -129,14 +139,14 @@ func TestJoinListenerAddr(t *testing.T) {
expect: "unix//foo/bar", expect: "unix//foo/bar",
}, },
} { } {
actual := JoinListenAddr(tc.network, tc.host, tc.port) actual := JoinNetworkAddress(tc.network, tc.host, tc.port)
if actual != tc.expect { if actual != tc.expect {
t.Errorf("Test %d: Expected '%s' but got '%s'", i, tc.expect, actual) t.Errorf("Test %d: Expected '%s' but got '%s'", i, tc.expect, actual)
} }
} }
} }
func TestParseListenerAddr(t *testing.T) { func TestParseNetworkAddress(t *testing.T) {
for i, tc := range []struct { for i, tc := range []struct {
input string input string
expectNetwork string expectNetwork string
@ -194,7 +204,7 @@ func TestParseListenerAddr(t *testing.T) {
expectAddrs: []string{"localhost:0"}, expectAddrs: []string{"localhost:0"},
}, },
} { } {
actualNetwork, actualAddrs, err := ParseListenAddr(tc.input) actualNetwork, actualAddrs, err := ParseNetworkAddress(tc.input)
if tc.expectErr && err == nil { if tc.expectErr && err == nil {
t.Errorf("Test %d: Expected error but got: %v", i, err) t.Errorf("Test %d: Expected error but got: %v", i, err)
} }

View file

@ -108,7 +108,7 @@ func (app *App) Validate() error {
lnAddrs := make(map[string]string) lnAddrs := make(map[string]string)
for srvName, srv := range app.Servers { for srvName, srv := range app.Servers {
for _, addr := range srv.Listen { for _, addr := range srv.Listen {
netw, expanded, err := caddy.ParseListenAddr(addr) netw, expanded, err := caddy.ParseNetworkAddress(addr)
if err != nil { if err != nil {
return fmt.Errorf("invalid listener address '%s': %v", addr, err) return fmt.Errorf("invalid listener address '%s': %v", addr, err)
} }
@ -149,7 +149,7 @@ func (app *App) Start() error {
} }
for _, lnAddr := range srv.Listen { for _, lnAddr := range srv.Listen {
network, addrs, err := caddy.ParseListenAddr(lnAddr) network, addrs, err := caddy.ParseNetworkAddress(lnAddr)
if err != nil { if err != nil {
return fmt.Errorf("%s: parsing listen address '%s': %v", srvName, lnAddr, err) return fmt.Errorf("%s: parsing listen address '%s': %v", srvName, lnAddr, err)
} }
@ -309,7 +309,7 @@ func (app *App) automaticHTTPS() error {
// create HTTP->HTTPS redirects // create HTTP->HTTPS redirects
for _, addr := range srv.Listen { for _, addr := range srv.Listen {
netw, host, port, err := caddy.SplitListenAddr(addr) netw, host, port, err := caddy.SplitNetworkAddress(addr)
if err != nil { if err != nil {
return fmt.Errorf("%s: invalid listener address: %v", srvName, addr) return fmt.Errorf("%s: invalid listener address: %v", srvName, addr)
} }
@ -318,7 +318,7 @@ func (app *App) automaticHTTPS() error {
if httpPort == 0 { if httpPort == 0 {
httpPort = DefaultHTTPPort httpPort = DefaultHTTPPort
} }
httpRedirLnAddr := caddy.JoinListenAddr(netw, host, strconv.Itoa(httpPort)) httpRedirLnAddr := caddy.JoinNetworkAddress(netw, host, strconv.Itoa(httpPort))
lnAddrMap[httpRedirLnAddr] = struct{}{} lnAddrMap[httpRedirLnAddr] = struct{}{}
if parts := strings.SplitN(port, "-", 2); len(parts) == 2 { if parts := strings.SplitN(port, "-", 2); len(parts) == 2 {
@ -361,7 +361,7 @@ func (app *App) automaticHTTPS() error {
var lnAddrs []string var lnAddrs []string
mapLoop: mapLoop:
for addr := range lnAddrMap { for addr := range lnAddrMap {
netw, addrs, err := caddy.ParseListenAddr(addr) netw, addrs, err := caddy.ParseNetworkAddress(addr)
if err != nil { if err != nil {
continue continue
} }
@ -386,7 +386,7 @@ func (app *App) automaticHTTPS() error {
func (app *App) listenerTaken(network, address string) bool { func (app *App) listenerTaken(network, address string) bool {
for _, srv := range app.Servers { for _, srv := range app.Servers {
for _, addr := range srv.Listen { for _, addr := range srv.Listen {
netw, addrs, err := caddy.ParseListenAddr(addr) netw, addrs, err := caddy.ParseNetworkAddress(addr)
if err != nil || netw != network { if err != nil || netw != network {
continue continue
} }

View file

@ -25,6 +25,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/caddyserver/caddy/v2/modules/caddyhttp/reverseproxy"
"github.com/caddyserver/caddy/v2/modules/caddytls" "github.com/caddyserver/caddy/v2/modules/caddytls"
"github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2"
@ -34,6 +35,7 @@ func init() {
caddy.RegisterModule(Transport{}) caddy.RegisterModule(Transport{})
} }
// Transport facilitates FastCGI communication.
type Transport struct { type Transport struct {
////////////////////////////// //////////////////////////////
// TODO: taken from v1 Handler type // TODO: taken from v1 Handler type
@ -57,32 +59,32 @@ type Transport struct {
// Use this directory as the fastcgi root directory. Defaults to the root // Use this directory as the fastcgi root directory. Defaults to the root
// directory of the parent virtual host. // directory of the parent virtual host.
Root string Root string `json:"root,omitempty"`
// The path in the URL will be split into two, with the first piece ending // The path in the URL will be split into two, with the first piece ending
// with the value of SplitPath. The first piece will be assumed as the // with the value of SplitPath. The first piece will be assumed as the
// actual resource (CGI script) name, and the second piece will be set to // actual resource (CGI script) name, and the second piece will be set to
// PATH_INFO for the CGI script to use. // PATH_INFO for the CGI script to use.
SplitPath string SplitPath string `json:"split_path,omitempty"`
// If the URL ends with '/' (which indicates a directory), these index // If the URL ends with '/' (which indicates a directory), these index
// files will be tried instead. // files will be tried instead.
IndexFiles []string // IndexFiles []string
// Environment Variables // Environment Variables
EnvVars [][2]string EnvVars [][2]string `json:"env,omitempty"`
// Ignored paths // Ignored paths
IgnoredSubPaths []string // IgnoredSubPaths []string
// The duration used to set a deadline when connecting to an upstream. // The duration used to set a deadline when connecting to an upstream.
DialTimeout time.Duration DialTimeout caddy.Duration `json:"dial_timeout,omitempty"`
// The duration used to set a deadline when reading from the FastCGI server. // The duration used to set a deadline when reading from the FastCGI server.
ReadTimeout time.Duration ReadTimeout caddy.Duration `json:"read_timeout,omitempty"`
// The duration used to set a deadline when sending to the FastCGI server. // The duration used to set a deadline when sending to the FastCGI server.
WriteTimeout time.Duration WriteTimeout caddy.Duration `json:"write_timeout,omitempty"`
} }
// CaddyModule returns the Caddy module information. // CaddyModule returns the Caddy module information.
@ -93,102 +95,62 @@ func (Transport) CaddyModule() caddy.ModuleInfo {
} }
} }
// RoundTrip implements http.RoundTripper.
func (t Transport) RoundTrip(r *http.Request) (*http.Response, error) { func (t Transport) RoundTrip(r *http.Request) (*http.Response, error) {
// Create environment for CGI script
env, err := t.buildEnv(r) env, err := t.buildEnv(r)
if err != nil { if err != nil {
return nil, fmt.Errorf("building environment: %v", err) return nil, fmt.Errorf("building environment: %v", err)
} }
// TODO: // TODO: doesn't dialer have a Timeout field?
// Connect to FastCGI gateway
// address, err := f.Address()
// if err != nil {
// return http.StatusBadGateway, err
// }
// network, address := parseAddress(address)
network, address := "tcp", r.URL.Host // TODO:
ctx := context.Background() ctx := context.Background()
if t.DialTimeout > 0 { if t.DialTimeout > 0 {
var cancel context.CancelFunc var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, t.DialTimeout) ctx, cancel = context.WithTimeout(ctx, time.Duration(t.DialTimeout))
defer cancel() defer cancel()
} }
// extract dial information from request (this
// should embedded by the reverse proxy)
network, address := "tcp", r.URL.Host
if dialInfoVal := ctx.Value(reverseproxy.DialInfoCtxKey); dialInfoVal != nil {
dialInfo := dialInfoVal.(reverseproxy.DialInfo)
network = dialInfo.Network
address = dialInfo.Address
}
fcgiBackend, err := DialContext(ctx, network, address) fcgiBackend, err := DialContext(ctx, network, address)
if err != nil { if err != nil {
return nil, fmt.Errorf("dialing backend: %v", err) return nil, fmt.Errorf("dialing backend: %v", err)
} }
// fcgiBackend is closed when response body is closed (see clientCloser) // fcgiBackend gets closed when response body is closed (see clientCloser)
// read/write timeouts // read/write timeouts
if err := fcgiBackend.SetReadTimeout(t.ReadTimeout); err != nil { if err := fcgiBackend.SetReadTimeout(time.Duration(t.ReadTimeout)); err != nil {
return nil, fmt.Errorf("setting read timeout: %v", err) return nil, fmt.Errorf("setting read timeout: %v", err)
} }
if err := fcgiBackend.SetWriteTimeout(t.WriteTimeout); err != nil { if err := fcgiBackend.SetWriteTimeout(time.Duration(t.WriteTimeout)); err != nil {
return nil, fmt.Errorf("setting write timeout: %v", err) return nil, fmt.Errorf("setting write timeout: %v", err)
} }
var resp *http.Response contentLength := r.ContentLength
if contentLength == 0 {
var contentLength int64
// if ContentLength is already set
if r.ContentLength > 0 {
contentLength = r.ContentLength
} else {
contentLength, _ = strconv.ParseInt(r.Header.Get("Content-Length"), 10, 64) contentLength, _ = strconv.ParseInt(r.Header.Get("Content-Length"), 10, 64)
} }
var resp *http.Response
switch r.Method { switch r.Method {
case "HEAD": case http.MethodHead:
resp, err = fcgiBackend.Head(env) resp, err = fcgiBackend.Head(env)
case "GET": case http.MethodGet:
resp, err = fcgiBackend.Get(env, r.Body, contentLength) resp, err = fcgiBackend.Get(env, r.Body, contentLength)
case "OPTIONS": case http.MethodOptions:
resp, err = fcgiBackend.Options(env) resp, err = fcgiBackend.Options(env)
default: default:
resp, err = fcgiBackend.Post(env, r.Method, r.Header.Get("Content-Type"), r.Body, contentLength) resp, err = fcgiBackend.Post(env, r.Method, r.Header.Get("Content-Type"), r.Body, contentLength)
} }
// TODO:
return resp, err return resp, err
// Stuff brought over from v1 that might not be necessary here:
// if resp != nil && resp.Body != nil {
// defer resp.Body.Close()
// }
// if err != nil {
// if err, ok := err.(net.Error); ok && err.Timeout() {
// return http.StatusGatewayTimeout, err
// } else if err != io.EOF {
// return http.StatusBadGateway, err
// }
// }
// // Write response header
// writeHeader(w, resp)
// // Write the response body
// _, err = io.Copy(w, resp.Body)
// if err != nil {
// return http.StatusBadGateway, err
// }
// // Log any stderr output from upstream
// if fcgiBackend.stderr.Len() != 0 {
// // Remove trailing newline, error logger already does this.
// err = LogError(strings.TrimSuffix(fcgiBackend.stderr.String(), "\n"))
// }
// // Normally we would return the status code if it is an error status (>= 400),
// // however, upstream FastCGI apps don't know about our contract and have
// // probably already written an error page. So we just return 0, indicating
// // that the response body is already written. However, we do return any
// // error value so it can be logged.
// // Note that the proxy middleware works the same way, returning status=0.
// return 0, err
} }
// buildEnv returns a set of CGI environment variables for the request. // buildEnv returns a set of CGI environment variables for the request.

View file

@ -15,6 +15,7 @@
package reverseproxy package reverseproxy
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -93,15 +94,31 @@ func (h *Handler) activeHealthChecker() {
// health checks for all hosts in the global repository. // health checks for all hosts in the global repository.
func (h *Handler) doActiveHealthChecksForAllHosts() { func (h *Handler) doActiveHealthChecksForAllHosts() {
hosts.Range(func(key, value interface{}) bool { hosts.Range(func(key, value interface{}) bool {
addr := key.(string) networkAddr := key.(string)
host := value.(Host) host := value.(Host)
go func(addr string, host Host) { go func(networkAddr string, host Host) {
err := h.doActiveHealthCheck(addr, host) network, addrs, err := caddy.ParseNetworkAddress(networkAddr)
if err != nil { if err != nil {
log.Printf("[ERROR] reverse_proxy: active health check for host %s: %v", addr, err) log.Printf("[ERROR] reverse_proxy: active health check for host %s: bad network address: %v", networkAddr, err)
return
} }
}(addr, host) if len(addrs) != 1 {
log.Printf("[ERROR] reverse_proxy: active health check for host %s: multiple addresses (upstream must map to only one address)", networkAddr)
return
}
hostAddr := addrs[0]
if network == "unix" || network == "unixgram" || network == "unixpacket" {
// this will be used as the Host portion of a http.Request URL, and
// paths to socket files would produce an error when creating URL,
// so use a fake Host value instead
hostAddr = network
}
err = h.doActiveHealthCheck(DialInfo{network, addrs[0]}, hostAddr, host)
if err != nil {
log.Printf("[ERROR] reverse_proxy: active health check for host %s: %v", networkAddr, err)
}
}(networkAddr, host)
// continue to iterate all hosts // continue to iterate all hosts
return true return true
@ -115,26 +132,39 @@ func (h *Handler) doActiveHealthChecksForAllHosts() {
// according to whether it passes the health check. An error is // according to whether it passes the health check. An error is
// returned only if the health check fails to occur or if marking // returned only if the health check fails to occur or if marking
// the host's health status fails. // the host's health status fails.
func (h *Handler) doActiveHealthCheck(hostAddr string, host Host) error { func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, host Host) error {
// create the URL for the health check // create the URL for the request that acts as a health check
u, err := url.Parse(hostAddr) scheme := "http"
if err != nil { if ht, ok := h.Transport.(*http.Transport); ok && ht.TLSClientConfig != nil {
return err // this is kind of a hacky way to know if we should use HTTPS, but whatever
scheme = "https"
} }
if h.HealthChecks.Active.Path != "" { u := &url.URL{
u.Path = h.HealthChecks.Active.Path Scheme: scheme,
Host: hostAddr,
Path: h.HealthChecks.Active.Path,
} }
// adjust the port, if configured to be different
if h.HealthChecks.Active.Port != 0 { if h.HealthChecks.Active.Port != 0 {
portStr := strconv.Itoa(h.HealthChecks.Active.Port) portStr := strconv.Itoa(h.HealthChecks.Active.Port)
u.Host = net.JoinHostPort(u.Hostname(), portStr) host, _, err := net.SplitHostPort(hostAddr)
}
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
if err != nil { if err != nil {
return err host = hostAddr
}
u.Host = net.JoinHostPort(host, portStr)
} }
// do the request, careful to tame the response body // attach dialing information to this request
ctx := context.Background()
ctx = context.WithValue(ctx, caddy.ReplacerCtxKey, caddy.NewReplacer())
ctx = context.WithValue(ctx, DialInfoCtxKey, dialInfo)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
if err != nil {
return fmt.Errorf("making request: %v", err)
}
// do the request, being careful to tame the response body
resp, err := h.HealthChecks.Active.httpClient.Do(req) resp, err := h.HealthChecks.Active.httpClient.Do(req)
if err != nil { if err != nil {
log.Printf("[INFO] reverse_proxy: active health check: %s is down (HTTP request failed: %v)", hostAddr, err) log.Printf("[INFO] reverse_proxy: active health check: %s is down (HTTP request failed: %v)", hostAddr, err)
@ -149,7 +179,7 @@ func (h *Handler) doActiveHealthCheck(hostAddr string, host Host) error {
body = io.LimitReader(body, h.HealthChecks.Active.MaxSize) body = io.LimitReader(body, h.HealthChecks.Active.MaxSize)
} }
defer func() { defer func() {
// drain any remaining body so connection can be re-used // drain any remaining body so connection could be re-used
io.Copy(ioutil.Discard, body) io.Copy(ioutil.Discard, body)
resp.Body.Close() resp.Body.Close()
}() }()
@ -225,7 +255,7 @@ func (h *Handler) countFailure(upstream *Upstream) {
err := upstream.Host.CountFail(1) err := upstream.Host.CountFail(1)
if err != nil { if err != nil {
log.Printf("[ERROR] proxy: upstream %s: counting failure: %v", log.Printf("[ERROR] proxy: upstream %s: counting failure: %v",
upstream.hostURL, err) upstream.dialInfo, err)
} }
// forget it later // forget it later
@ -234,7 +264,7 @@ func (h *Handler) countFailure(upstream *Upstream) {
err := host.CountFail(-1) err := host.CountFail(-1)
if err != nil { if err != nil {
log.Printf("[ERROR] proxy: upstream %s: expiring failure: %v", log.Printf("[ERROR] proxy: upstream %s: expiring failure: %v",
upstream.hostURL, err) upstream.dialInfo, err)
} }
}(upstream.Host, failDuration) }(upstream.Host, failDuration)
} }

View file

@ -16,7 +16,6 @@ package reverseproxy
import ( import (
"fmt" "fmt"
"net/url"
"sync/atomic" "sync/atomic"
"github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2"
@ -59,7 +58,7 @@ type UpstreamPool []*Upstream
type Upstream struct { type Upstream struct {
Host `json:"-"` Host `json:"-"`
Address string `json:"address,omitempty"` Dial string `json:"dial,omitempty"`
MaxRequests int `json:"max_requests,omitempty"` MaxRequests int `json:"max_requests,omitempty"`
// TODO: This could be really useful, to bind requests // TODO: This could be really useful, to bind requests
@ -68,8 +67,8 @@ type Upstream struct {
// IPAffinity string // IPAffinity string
healthCheckPolicy *PassiveHealthChecks healthCheckPolicy *PassiveHealthChecks
hostURL *url.URL
cb CircuitBreaker cb CircuitBreaker
dialInfo DialInfo
} }
// Available returns true if the remote host // Available returns true if the remote host
@ -101,11 +100,6 @@ func (u *Upstream) Full() bool {
return u.MaxRequests > 0 && u.Host.NumRequests() >= u.MaxRequests return u.MaxRequests > 0 && u.Host.NumRequests() >= u.MaxRequests
} }
// URL returns the upstream host's endpoint URL.
func (u *Upstream) URL() *url.URL {
return u.hostURL
}
// upstreamHost is the basic, in-memory representation // upstreamHost is the basic, in-memory representation
// of the state of a remote host. It implements the // of the state of a remote host. It implements the
// Host interface. // Host interface.
@ -162,6 +156,34 @@ func (uh *upstreamHost) SetHealthy(healthy bool) (bool, error) {
return swapped, nil return swapped, nil
} }
// DialInfo contains information needed to dial a
// connection to an upstream host. This information
// may be different than that which is represented
// in a URL (for example, unix sockets don't have
// a host that can be represented in a URL, but
// they certainly have a network name and address).
type DialInfo struct {
// The network to use. This should be one of the
// values that is accepted by net.Dial:
// https://golang.org/pkg/net/#Dial
Network string
// The address to dial. Follows the same
// semantics and rules as net.Dial.
Address string
}
// String returns the Caddy network address form
// by joining the network and address with a
// forward slash.
func (di DialInfo) String() string {
return di.Network + "/" + di.Address
}
// DialInfoCtxKey is used to store a DialInfo
// in a context.Context.
const DialInfoCtxKey = caddy.CtxKey("dial_info")
// hosts is the global repository for hosts that are // hosts is the global repository for hosts that are
// currently in use by active configuration(s). This // currently in use by active configuration(s). This
// allows the state of remote hosts to be preserved // allows the state of remote hosts to be preserved

View file

@ -15,6 +15,7 @@
package reverseproxy package reverseproxy
import ( import (
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
@ -63,14 +64,23 @@ func (HTTPTransport) CaddyModule() caddy.ModuleInfo {
// Provision sets up h.RoundTripper with a http.Transport // Provision sets up h.RoundTripper with a http.Transport
// that is ready to use. // that is ready to use.
func (h *HTTPTransport) Provision(ctx caddy.Context) error { func (h *HTTPTransport) Provision(_ caddy.Context) error {
dialer := &net.Dialer{ dialer := &net.Dialer{
Timeout: time.Duration(h.DialTimeout), Timeout: time.Duration(h.DialTimeout),
FallbackDelay: time.Duration(h.FallbackDelay), FallbackDelay: time.Duration(h.FallbackDelay),
// TODO: Resolver // TODO: Resolver
} }
rt := &http.Transport{ rt := &http.Transport{
DialContext: dialer.DialContext, DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
// the proper dialing information should be embedded into the request's context
if dialInfoVal := ctx.Value(DialInfoCtxKey); dialInfoVal != nil {
dialInfo := dialInfoVal.(DialInfo)
network = dialInfo.Network
address = dialInfo.Address
}
return dialer.DialContext(ctx, network, address)
},
MaxConnsPerHost: h.MaxConnsPerHost, MaxConnsPerHost: h.MaxConnsPerHost,
ResponseHeaderTimeout: time.Duration(h.ResponseHeaderTimeout), ResponseHeaderTimeout: time.Duration(h.ResponseHeaderTimeout),
ExpectContinueTimeout: time.Duration(h.ExpectContinueTimeout), ExpectContinueTimeout: time.Duration(h.ExpectContinueTimeout),
@ -91,7 +101,6 @@ func (h *HTTPTransport) Provision(ctx caddy.Context) error {
if h.KeepAlive != nil { if h.KeepAlive != nil {
dialer.KeepAlive = time.Duration(h.KeepAlive.ProbeInterval) dialer.KeepAlive = time.Duration(h.KeepAlive.ProbeInterval)
if enabled := h.KeepAlive.Enabled; enabled != nil { if enabled := h.KeepAlive.Enabled; enabled != nil {
rt.DisableKeepAlives = !*enabled rt.DisableKeepAlives = !*enabled
} }
@ -191,16 +200,3 @@ type KeepAlive struct {
MaxIdleConnsPerHost int `json:"max_idle_conns_per_host,omitempty"` MaxIdleConnsPerHost int `json:"max_idle_conns_per_host,omitempty"`
IdleConnTimeout caddy.Duration `json:"idle_timeout,omitempty"` // how long should connections be kept alive when idle IdleConnTimeout caddy.Duration `json:"idle_timeout,omitempty"` // how long should connections be kept alive when idle
} }
var (
defaultDialer = net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
}
defaultTransport = &http.Transport{
DialContext: defaultDialer.DialContext,
TLSHandshakeTimeout: 5 * time.Second,
IdleConnTimeout: 2 * time.Minute,
}
)

View file

@ -20,7 +20,6 @@ import (
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"net/url"
"regexp" "regexp"
"strings" "strings"
"time" "time"
@ -86,7 +85,18 @@ func (h *Handler) Provision(ctx caddy.Context) error {
} }
if h.Transport == nil { if h.Transport == nil {
h.Transport = defaultTransport t := &HTTPTransport{
KeepAlive: &KeepAlive{
ProbeInterval: caddy.Duration(30 * time.Second),
IdleConnTimeout: caddy.Duration(2 * time.Minute),
},
DialTimeout: caddy.Duration(10 * time.Second),
}
err := t.Provision(ctx)
if err != nil {
return fmt.Errorf("provisioning default transport: %v", err)
}
h.Transport = t
} }
if h.LoadBalancing == nil { if h.LoadBalancing == nil {
@ -133,51 +143,65 @@ func (h *Handler) Provision(ctx caddy.Context) error {
go h.activeHealthChecker() go h.activeHealthChecker()
} }
var allUpstreams []*Upstream
for _, upstream := range h.Upstreams { for _, upstream := range h.Upstreams {
upstream.cb = h.CB // upstreams are allowed to map to only a single host,
// but an upstream's address may semantically represent
// url parser requires a scheme // multiple addresses, so make sure to handle each
if !strings.Contains(upstream.Address, "://") { // one in turn based on this one upstream config
upstream.Address = "http://" + upstream.Address network, addresses, err := caddy.ParseNetworkAddress(upstream.Dial)
}
u, err := url.Parse(upstream.Address)
if err != nil { if err != nil {
return fmt.Errorf("invalid upstream address %s: %v", upstream.Address, err) return fmt.Errorf("parsing dial address: %v", err)
} }
upstream.hostURL = u
for _, addr := range addresses {
// make a new upstream based on the original
// that has a singular dial address
upstreamCopy := *upstream
upstreamCopy.dialInfo = DialInfo{network, addr}
upstreamCopy.Dial = upstreamCopy.dialInfo.String()
upstreamCopy.cb = h.CB
// if host already exists from a current config, // if host already exists from a current config,
// use that instead; otherwise, add it // use that instead; otherwise, add it
// TODO: make hosts modular, so that their state can be distributed in enterprise for example // TODO: make hosts modular, so that their state can be distributed in enterprise for example
// TODO: If distributed, the pool should be stored in storage... // TODO: If distributed, the pool should be stored in storage...
var host Host = new(upstreamHost) var host Host = new(upstreamHost)
activeHost, loaded := hosts.LoadOrStore(u.String(), host) activeHost, loaded := hosts.LoadOrStore(upstreamCopy.Dial, host)
if loaded { if loaded {
host = activeHost.(Host) host = activeHost.(Host)
} }
upstream.Host = host upstreamCopy.Host = host
// if the passive health checker has a non-zero "unhealthy // if the passive health checker has a non-zero "unhealthy
// request count" but the upstream has no MaxRequests set // request count" but the upstream has no MaxRequests set
// (they are the same thing, but one is a default value for // (they are the same thing, but one is a default value for
// for upstreams with a zero MaxRequests), copy the default // for upstreams with a zero MaxRequests), copy the default
// value into this upstream, since the value in the upstream // value into this upstream, since the value in the upstream
// is what is used during availability checks // (MaxRequests) is what is used during availability checks
if h.HealthChecks != nil && if h.HealthChecks != nil &&
h.HealthChecks.Passive != nil && h.HealthChecks.Passive != nil &&
h.HealthChecks.Passive.UnhealthyRequestCount > 0 && h.HealthChecks.Passive.UnhealthyRequestCount > 0 &&
upstream.MaxRequests == 0 { upstreamCopy.MaxRequests == 0 {
upstream.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount upstreamCopy.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount
} }
if h.HealthChecks != nil {
// upstreams need independent access to the passive // upstreams need independent access to the passive
// health check policy so they can, you know, passively // health check policy because they run outside of the
// do health checks // scope of a request handler
upstream.healthCheckPolicy = h.HealthChecks.Passive if h.HealthChecks != nil {
upstreamCopy.healthCheckPolicy = h.HealthChecks.Passive
}
allUpstreams = append(allUpstreams, &upstreamCopy)
} }
} }
// replace the unmarshaled upstreams (possible 1:many
// address mapping) with our list, which is mapped 1:1,
// thus may have expanded the original list
h.Upstreams = allUpstreams
return nil return nil
} }
@ -192,7 +216,7 @@ func (h *Handler) Cleanup() error {
// remove hosts from our config from the pool // remove hosts from our config from the pool
for _, upstream := range h.Upstreams { for _, upstream := range h.Upstreams {
hosts.Delete(upstream.hostURL.String()) hosts.Delete(upstream.dialInfo.String())
} }
return nil return nil
@ -222,6 +246,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
continue continue
} }
// attach to the request information about how to dial the upstream;
// this is necessary because the information cannot be sufficiently
// or satisfactorily represented in a URL
ctx := context.WithValue(r.Context(), DialInfoCtxKey, upstream.dialInfo)
r = r.WithContext(ctx)
// proxy the request to that upstream // proxy the request to that upstream
proxyErr = h.reverseProxy(w, r, upstream) proxyErr = h.reverseProxy(w, r, upstream)
if proxyErr == nil || proxyErr == context.Canceled { if proxyErr == nil || proxyErr == context.Canceled {
@ -249,6 +279,16 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
// This assumes that no mutations of the request are performed // This assumes that no mutations of the request are performed
// by h during or after proxying. // by h during or after proxying.
func (h Handler) prepareRequest(req *http.Request) error { func (h Handler) prepareRequest(req *http.Request) error {
// as a special (but very common) case, if the transport
// is HTTP, then ensure the request has the proper scheme
// because incoming requests by default are lacking it
if req.URL.Scheme == "" {
req.URL.Scheme = "http"
if ht, ok := h.Transport.(*HTTPTransport); ok && ht.TLS != nil {
req.URL.Scheme = "https"
}
}
if req.ContentLength == 0 { if req.ContentLength == 0 {
req.Body = nil // Issue golang/go#16036: nil Body for http.Transport retries req.Body = nil // Issue golang/go#16036: nil Body for http.Transport retries
} }
@ -433,14 +473,8 @@ func (h Handler) tryAgain(start time.Time, proxyErr error) bool {
// directRequest modifies only req.URL so that it points to the // directRequest modifies only req.URL so that it points to the
// given upstream host. It must modify ONLY the request URL. // given upstream host. It must modify ONLY the request URL.
func (h Handler) directRequest(req *http.Request, upstream *Upstream) { func (h Handler) directRequest(req *http.Request, upstream *Upstream) {
target := upstream.hostURL if req.URL.Host == "" {
req.URL.Scheme = target.Scheme req.URL.Host = upstream.dialInfo.Address
req.URL.Host = target.Host
req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) // TODO: This might be a bug (if any part of the path was augmented from a previously-tried upstream; need to start from clean original path of request, same for query string!)
if target.RawQuery == "" || req.URL.RawQuery == "" {
req.URL.RawQuery = target.RawQuery + req.URL.RawQuery
} else {
req.URL.RawQuery = target.RawQuery + "&" + req.URL.RawQuery
} }
} }

View file

@ -168,7 +168,7 @@ func (s *Server) enforcementHandler(w http.ResponseWriter, r *http.Request, next
// listeners in s that use a port which is not otherPort. // listeners in s that use a port which is not otherPort.
func (s *Server) listenersUseAnyPortOtherThan(otherPort int) bool { func (s *Server) listenersUseAnyPortOtherThan(otherPort int) bool {
for _, lnAddr := range s.Listen { for _, lnAddr := range s.Listen {
_, addrs, err := caddy.ParseListenAddr(lnAddr) _, addrs, err := caddy.ParseNetworkAddress(lnAddr)
if err == nil { if err == nil {
for _, a := range addrs { for _, a := range addrs {
_, port, err := net.SplitHostPort(a) _, port, err := net.SplitHostPort(a)