mirror of
https://github.com/caddyserver/caddy.git
synced 2025-01-13 22:36:27 +03:00
core: Use port ranges to avoid OOM with bad inputs (#2859)
* fix OOM issue caught by fuzzing * use ParsedAddress as the struct name for the result of ParseNetworkAddress * simplify code using the ParsedAddress type * minor cleanups
This commit is contained in:
parent
a19da07b72
commit
93bc1b72e3
8 changed files with 201 additions and 130 deletions
16
admin.go
16
admin.go
|
@ -48,23 +48,19 @@ type AdminConfig struct {
|
||||||
|
|
||||||
// listenAddr extracts a singular listen address from ac.Listen,
|
// listenAddr extracts a singular listen address from ac.Listen,
|
||||||
// returning the network and the address of the listener.
|
// returning the network and the address of the listener.
|
||||||
func (admin AdminConfig) listenAddr() (netw string, addr string, err error) {
|
func (admin AdminConfig) listenAddr() (string, string, error) {
|
||||||
var listenAddrs []string
|
|
||||||
input := admin.Listen
|
input := admin.Listen
|
||||||
if input == "" {
|
if input == "" {
|
||||||
input = DefaultAdminListen
|
input = DefaultAdminListen
|
||||||
}
|
}
|
||||||
netw, listenAddrs, err = ParseNetworkAddress(input)
|
listenAddr, err := ParseNetworkAddress(input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("parsing admin listener address: %v", err)
|
return "", "", fmt.Errorf("parsing admin listener address: %v", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
if len(listenAddrs) != 1 {
|
if listenAddr.PortRangeSize() != 1 {
|
||||||
err = fmt.Errorf("admin endpoint must have exactly one address; cannot listen on %v", listenAddrs)
|
return "", "", fmt.Errorf("admin endpoint must have exactly one address; cannot listen on %v", listenAddr)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
addr = listenAddrs[0]
|
return listenAddr.Network, listenAddr.JoinHostPort(0), nil
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// newAdminHandler reads admin's config and returns an http.Handler suitable
|
// newAdminHandler reads admin's config and returns an http.Handler suitable
|
||||||
|
|
101
listeners.go
101
listeners.go
|
@ -257,52 +257,94 @@ type globalListener struct {
|
||||||
pc net.PacketConn
|
pc net.PacketConn
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
// ParsedAddress contains the individual components
|
||||||
listeners = make(map[string]*globalListener)
|
// for a parsed network address of the form accepted
|
||||||
listenersMu sync.Mutex
|
// by ParseNetworkAddress(). Network should be a
|
||||||
)
|
// network value accepted by Go's net package. Port
|
||||||
|
// ranges are given by [StartPort, EndPort].
|
||||||
|
type ParsedAddress struct {
|
||||||
|
Network string
|
||||||
|
Host string
|
||||||
|
StartPort uint
|
||||||
|
EndPort uint
|
||||||
|
}
|
||||||
|
|
||||||
// ParseNetworkAddress parses addr, a string of the form "network/host:port"
|
// JoinHostPort is like net.JoinHostPort, but where the port
|
||||||
// (with any part optional) into its component parts. Because a port can
|
// is StartPort + offset.
|
||||||
// also be a port range, there may be multiple addresses returned.
|
func (l ParsedAddress) JoinHostPort(offset uint) string {
|
||||||
func ParseNetworkAddress(addr string) (network string, addrs []string, err error) {
|
return net.JoinHostPort(l.Host, strconv.Itoa(int(l.StartPort+offset)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// PortRangeSize returns how many ports are in
|
||||||
|
// pa's port range. Port ranges are inclusive,
|
||||||
|
// so the size is the difference of start and
|
||||||
|
// end ports plus one.
|
||||||
|
func (pa ParsedAddress) PortRangeSize() uint {
|
||||||
|
return (pa.EndPort - pa.StartPort) + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// String reconstructs the address string to the form expected
|
||||||
|
// by ParseNetworkAddress().
|
||||||
|
func (pa ParsedAddress) String() string {
|
||||||
|
port := strconv.FormatUint(uint64(pa.StartPort), 10)
|
||||||
|
if pa.StartPort != pa.EndPort {
|
||||||
|
port += "-" + strconv.FormatUint(uint64(pa.EndPort), 10)
|
||||||
|
}
|
||||||
|
return JoinNetworkAddress(pa.Network, pa.Host, port)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseNetworkAddress parses addr into its individual
|
||||||
|
// components. The input string is expected to be of
|
||||||
|
// the form "network/host:port-range" where any part is
|
||||||
|
// optional. The default network, if unspecified, is tcp.
|
||||||
|
// Port ranges are inclusive.
|
||||||
|
//
|
||||||
|
// Network addresses are distinct from URLs and do not
|
||||||
|
// use URL syntax.
|
||||||
|
func ParseNetworkAddress(addr string) (ParsedAddress, error) {
|
||||||
var host, port string
|
var host, port string
|
||||||
network, host, port, err = SplitNetworkAddress(addr)
|
network, host, port, err := SplitNetworkAddress(addr)
|
||||||
if network == "" {
|
if network == "" {
|
||||||
network = "tcp"
|
network = "tcp"
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return ParsedAddress{}, err
|
||||||
}
|
}
|
||||||
if network == "unix" || network == "unixgram" || network == "unixpacket" {
|
if network == "unix" || network == "unixgram" || network == "unixpacket" {
|
||||||
addrs = []string{host}
|
return ParsedAddress{
|
||||||
return
|
Network: network,
|
||||||
|
Host: host,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
ports := strings.SplitN(port, "-", 2)
|
ports := strings.SplitN(port, "-", 2)
|
||||||
if len(ports) == 1 {
|
if len(ports) == 1 {
|
||||||
ports = append(ports, ports[0])
|
ports = append(ports, ports[0])
|
||||||
}
|
}
|
||||||
var start, end int
|
var start, end uint64
|
||||||
start, err = strconv.Atoi(ports[0])
|
start, err = strconv.ParseUint(ports[0], 10, 16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return ParsedAddress{}, fmt.Errorf("invalid start port: %v", err)
|
||||||
}
|
}
|
||||||
end, err = strconv.Atoi(ports[1])
|
end, err = strconv.ParseUint(ports[1], 10, 16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return ParsedAddress{}, fmt.Errorf("invalid end port: %v", err)
|
||||||
}
|
}
|
||||||
if end < start {
|
if end < start {
|
||||||
err = fmt.Errorf("end port must be greater than start port")
|
return ParsedAddress{}, fmt.Errorf("end port must not be less than start port")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
for p := start; p <= end; p++ {
|
if (end - start) > maxPortSpan {
|
||||||
addrs = append(addrs, net.JoinHostPort(host, fmt.Sprintf("%d", p)))
|
return ParsedAddress{}, fmt.Errorf("port range exceeds %d ports", maxPortSpan)
|
||||||
}
|
}
|
||||||
return
|
return ParsedAddress{
|
||||||
|
Network: network,
|
||||||
|
Host: host,
|
||||||
|
StartPort: uint(start),
|
||||||
|
EndPort: uint(end),
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SplitNetworkAddress splits a into its network, host, and port components.
|
// SplitNetworkAddress splits a into its network, host, and port components.
|
||||||
// Note that port may be a port range, or omitted for unix sockets.
|
// Note that port may be a port range (:X-Y), or omitted for unix sockets.
|
||||||
func SplitNetworkAddress(a string) (network, host, port string, err error) {
|
func SplitNetworkAddress(a string) (network, host, port string, err error) {
|
||||||
if idx := strings.Index(a, "/"); idx >= 0 {
|
if idx := strings.Index(a, "/"); idx >= 0 {
|
||||||
network = strings.ToLower(strings.TrimSpace(a[:idx]))
|
network = strings.ToLower(strings.TrimSpace(a[:idx]))
|
||||||
|
@ -317,9 +359,9 @@ func SplitNetworkAddress(a string) (network, host, port string, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// JoinNetworkAddress combines network, host, and port into a single
|
// JoinNetworkAddress combines network, host, and port into a single
|
||||||
// address string of the form "network/host:port". Port may be a
|
// address string of the form accepted by ParseNetworkAddress(). For unix sockets, the network
|
||||||
// port range. For unix sockets, the network should be "unix" and
|
// should be "unix" and the path to the socket should be given as the
|
||||||
// the path to the socket should be given in the host argument.
|
// host parameter.
|
||||||
func JoinNetworkAddress(network, host, port string) string {
|
func JoinNetworkAddress(network, host, port string) string {
|
||||||
var a string
|
var a string
|
||||||
if network != "" {
|
if network != "" {
|
||||||
|
@ -332,3 +374,10 @@ func JoinNetworkAddress(network, host, port string) string {
|
||||||
}
|
}
|
||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
listeners = make(map[string]*globalListener)
|
||||||
|
listenersMu sync.Mutex
|
||||||
|
)
|
||||||
|
|
||||||
|
const maxPortSpan = 65535
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
package caddy
|
package caddy
|
||||||
|
|
||||||
func FuzzParseNetworkAddress(data []byte) int {
|
func FuzzParseNetworkAddress(data []byte) int {
|
||||||
_, _, err := ParseNetworkAddress(string(data))
|
_, err := ParseNetworkAddress(string(data))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
|
@ -152,74 +152,101 @@ func TestJoinNetworkAddress(t *testing.T) {
|
||||||
|
|
||||||
func TestParseNetworkAddress(t *testing.T) {
|
func TestParseNetworkAddress(t *testing.T) {
|
||||||
for i, tc := range []struct {
|
for i, tc := range []struct {
|
||||||
input string
|
input string
|
||||||
expectNetwork string
|
expectAddr ParsedAddress
|
||||||
expectAddrs []string
|
expectErr bool
|
||||||
expectErr bool
|
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
input: "",
|
input: "",
|
||||||
expectNetwork: "tcp",
|
expectErr: true,
|
||||||
expectErr: true,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
input: ":",
|
input: ":",
|
||||||
expectNetwork: "tcp",
|
expectErr: true,
|
||||||
expectErr: true,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
input: ":1234",
|
input: ":1234",
|
||||||
expectNetwork: "tcp",
|
expectAddr: ParsedAddress{
|
||||||
expectAddrs: []string{":1234"},
|
Network: "tcp",
|
||||||
|
Host: "",
|
||||||
|
StartPort: 1234,
|
||||||
|
EndPort: 1234,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
input: "tcp/:1234",
|
input: "tcp/:1234",
|
||||||
expectNetwork: "tcp",
|
expectAddr: ParsedAddress{
|
||||||
expectAddrs: []string{":1234"},
|
Network: "tcp",
|
||||||
|
Host: "",
|
||||||
|
StartPort: 1234,
|
||||||
|
EndPort: 1234,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
input: "tcp6/:1234",
|
input: "tcp6/:1234",
|
||||||
expectNetwork: "tcp6",
|
expectAddr: ParsedAddress{
|
||||||
expectAddrs: []string{":1234"},
|
Network: "tcp6",
|
||||||
|
Host: "",
|
||||||
|
StartPort: 1234,
|
||||||
|
EndPort: 1234,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
input: "tcp4/localhost:1234",
|
input: "tcp4/localhost:1234",
|
||||||
expectNetwork: "tcp4",
|
expectAddr: ParsedAddress{
|
||||||
expectAddrs: []string{"localhost:1234"},
|
Network: "tcp4",
|
||||||
|
Host: "localhost",
|
||||||
|
StartPort: 1234,
|
||||||
|
EndPort: 1234,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
input: "unix//foo/bar",
|
input: "unix//foo/bar",
|
||||||
expectNetwork: "unix",
|
expectAddr: ParsedAddress{
|
||||||
expectAddrs: []string{"/foo/bar"},
|
Network: "unix",
|
||||||
|
Host: "/foo/bar",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
input: "localhost:1234-1234",
|
input: "localhost:1234-1234",
|
||||||
expectNetwork: "tcp",
|
expectAddr: ParsedAddress{
|
||||||
expectAddrs: []string{"localhost:1234"},
|
Network: "tcp",
|
||||||
|
Host: "localhost",
|
||||||
|
StartPort: 1234,
|
||||||
|
EndPort: 1234,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
input: "localhost:2-1",
|
input: "localhost:2-1",
|
||||||
expectNetwork: "tcp",
|
expectErr: true,
|
||||||
expectErr: true,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
input: "localhost:0",
|
input: "localhost:0",
|
||||||
expectNetwork: "tcp",
|
expectAddr: ParsedAddress{
|
||||||
expectAddrs: []string{"localhost:0"},
|
Network: "tcp",
|
||||||
|
Host: "localhost",
|
||||||
|
StartPort: 0,
|
||||||
|
EndPort: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "localhost:1-999999999999",
|
||||||
|
expectErr: true,
|
||||||
},
|
},
|
||||||
} {
|
} {
|
||||||
actualNetwork, actualAddrs, err := ParseNetworkAddress(tc.input)
|
actualAddr, err := ParseNetworkAddress(tc.input)
|
||||||
if tc.expectErr && err == nil {
|
if tc.expectErr && err == nil {
|
||||||
t.Errorf("Test %d: Expected error but got: %v", i, err)
|
t.Errorf("Test %d: Expected error but got: %v", i, err)
|
||||||
}
|
}
|
||||||
if !tc.expectErr && err != nil {
|
if !tc.expectErr && err != nil {
|
||||||
t.Errorf("Test %d: Expected no error but got: %v", i, err)
|
t.Errorf("Test %d: Expected no error but got: %v", i, err)
|
||||||
}
|
}
|
||||||
if actualNetwork != tc.expectNetwork {
|
|
||||||
t.Errorf("Test %d: Expected network '%s' but got '%s'", i, tc.expectNetwork, actualNetwork)
|
if actualAddr.Network != tc.expectAddr.Network {
|
||||||
|
t.Errorf("Test %d: Expected network '%v' but got '%v'", i, tc.expectAddr, actualAddr)
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(tc.expectAddrs, actualAddrs) {
|
if !reflect.DeepEqual(tc.expectAddr, actualAddr) {
|
||||||
t.Errorf("Test %d: Expected addresses %v but got %v", i, tc.expectAddrs, actualAddrs)
|
t.Errorf("Test %d: Expected addresses %v but got %v", i, tc.expectAddr, actualAddr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -135,15 +135,18 @@ func (app *App) Validate() error {
|
||||||
lnAddrs := make(map[string]string)
|
lnAddrs := make(map[string]string)
|
||||||
for srvName, srv := range app.Servers {
|
for srvName, srv := range app.Servers {
|
||||||
for _, addr := range srv.Listen {
|
for _, addr := range srv.Listen {
|
||||||
netw, expanded, err := caddy.ParseNetworkAddress(addr)
|
listenAddr, err := caddy.ParseNetworkAddress(addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid listener address '%s': %v", addr, err)
|
return fmt.Errorf("invalid listener address '%s': %v", addr, err)
|
||||||
}
|
}
|
||||||
for _, a := range expanded {
|
// check that every address in the port range is unique to this server;
|
||||||
if sn, ok := lnAddrs[netw+a]; ok {
|
// we do not use <= here because PortRangeSize() adds 1 to EndPort for us
|
||||||
return fmt.Errorf("server %s: listener address repeated: %s (already claimed by server '%s')", srvName, a, sn)
|
for i := uint(0); i < listenAddr.PortRangeSize(); i++ {
|
||||||
|
addr := caddy.JoinNetworkAddress(listenAddr.Network, listenAddr.Host, strconv.Itoa(int(listenAddr.StartPort+i)))
|
||||||
|
if sn, ok := lnAddrs[addr]; ok {
|
||||||
|
return fmt.Errorf("server %s: listener address repeated: %s (already claimed by server '%s')", srvName, addr, sn)
|
||||||
}
|
}
|
||||||
lnAddrs[netw+a] = srvName
|
lnAddrs[addr] = srvName
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -176,14 +179,15 @@ func (app *App) Start() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, lnAddr := range srv.Listen {
|
for _, lnAddr := range srv.Listen {
|
||||||
network, addrs, err := caddy.ParseNetworkAddress(lnAddr)
|
listenAddr, err := caddy.ParseNetworkAddress(lnAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%s: parsing listen address '%s': %v", srvName, lnAddr, err)
|
return fmt.Errorf("%s: parsing listen address '%s': %v", srvName, lnAddr, err)
|
||||||
}
|
}
|
||||||
for _, addr := range addrs {
|
for i := uint(0); i <= listenAddr.PortRangeSize(); i++ {
|
||||||
ln, err := caddy.Listen(network, addr)
|
hostport := listenAddr.JoinHostPort(i)
|
||||||
|
ln, err := caddy.Listen(listenAddr.Network, hostport)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%s: listening on %s: %v", network, addr, err)
|
return fmt.Errorf("%s: listening on %s: %v", listenAddr.Network, hostport, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// enable HTTP/2 by default
|
// enable HTTP/2 by default
|
||||||
|
@ -194,11 +198,10 @@ func (app *App) Start() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// enable TLS
|
// enable TLS
|
||||||
_, port, _ := net.SplitHostPort(addr)
|
if len(srv.TLSConnPolicies) > 0 && int(i) != app.httpPort() {
|
||||||
if len(srv.TLSConnPolicies) > 0 && port != strconv.Itoa(app.httpPort()) {
|
|
||||||
tlsCfg, err := srv.TLSConnPolicies.TLSConfig(app.ctx)
|
tlsCfg, err := srv.TLSConnPolicies.TLSConfig(app.ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%s/%s: making TLS configuration: %v", network, addr, err)
|
return fmt.Errorf("%s/%s: making TLS configuration: %v", listenAddr.Network, hostport, err)
|
||||||
}
|
}
|
||||||
ln = tls.NewListener(ln, tlsCfg)
|
ln = tls.NewListener(ln, tlsCfg)
|
||||||
|
|
||||||
|
@ -206,15 +209,15 @@ func (app *App) Start() error {
|
||||||
// TODO: HTTP/3 support is experimental for now
|
// TODO: HTTP/3 support is experimental for now
|
||||||
if srv.ExperimentalHTTP3 {
|
if srv.ExperimentalHTTP3 {
|
||||||
app.logger.Info("enabling experimental HTTP/3 listener",
|
app.logger.Info("enabling experimental HTTP/3 listener",
|
||||||
zap.String("addr", addr),
|
zap.String("addr", hostport),
|
||||||
)
|
)
|
||||||
h3ln, err := caddy.ListenPacket("udp", addr)
|
h3ln, err := caddy.ListenPacket("udp", hostport)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("getting HTTP/3 UDP listener: %v", err)
|
return fmt.Errorf("getting HTTP/3 UDP listener: %v", err)
|
||||||
}
|
}
|
||||||
h3srv := &http3.Server{
|
h3srv := &http3.Server{
|
||||||
Server: &http.Server{
|
Server: &http.Server{
|
||||||
Addr: addr,
|
Addr: hostport,
|
||||||
Handler: srv,
|
Handler: srv,
|
||||||
TLSConfig: tlsCfg,
|
TLSConfig: tlsCfg,
|
||||||
},
|
},
|
||||||
|
|
|
@ -102,7 +102,7 @@ func (h *Handler) doActiveHealthChecksForAllHosts() {
|
||||||
host := value.(Host)
|
host := value.(Host)
|
||||||
|
|
||||||
go func(networkAddr string, host Host) {
|
go func(networkAddr string, host Host) {
|
||||||
network, addrs, err := caddy.ParseNetworkAddress(networkAddr)
|
addr, err := caddy.ParseNetworkAddress(networkAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.HealthChecks.Active.logger.Error("bad network address",
|
h.HealthChecks.Active.logger.Error("bad network address",
|
||||||
zap.String("address", networkAddr),
|
zap.String("address", networkAddr),
|
||||||
|
@ -110,20 +110,20 @@ func (h *Handler) doActiveHealthChecksForAllHosts() {
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(addrs) != 1 {
|
if addr.PortRangeSize() != 1 {
|
||||||
h.HealthChecks.Active.logger.Error("multiple addresses (upstream must map to only one address)",
|
h.HealthChecks.Active.logger.Error("multiple addresses (upstream must map to only one address)",
|
||||||
zap.String("address", networkAddr),
|
zap.String("address", networkAddr),
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
hostAddr := addrs[0]
|
hostAddr := addr.JoinHostPort(0)
|
||||||
if network == "unix" || network == "unixgram" || network == "unixpacket" {
|
if addr.Network == "unix" || addr.Network == "unixgram" || addr.Network == "unixpacket" {
|
||||||
// this will be used as the Host portion of a http.Request URL, and
|
// this will be used as the Host portion of a http.Request URL, and
|
||||||
// paths to socket files would produce an error when creating URL,
|
// paths to socket files would produce an error when creating URL,
|
||||||
// so use a fake Host value instead; unix sockets are usually local
|
// so use a fake Host value instead; unix sockets are usually local
|
||||||
hostAddr = "localhost"
|
hostAddr = "localhost"
|
||||||
}
|
}
|
||||||
err = h.doActiveHealthCheck(DialInfo{Network: network, Address: addrs[0]}, hostAddr, host)
|
err = h.doActiveHealthCheck(DialInfo{Network: addr.Network, Address: hostAddr}, hostAddr, host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.HealthChecks.Active.logger.Error("active health check failed",
|
h.HealthChecks.Active.logger.Error("active health check failed",
|
||||||
zap.String("address", networkAddr),
|
zap.String("address", networkAddr),
|
||||||
|
|
|
@ -16,8 +16,7 @@ package reverseproxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"strconv"
|
||||||
"strings"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/caddyserver/caddy/v2"
|
"github.com/caddyserver/caddy/v2"
|
||||||
|
@ -193,27 +192,20 @@ func (di DialInfo) String() string {
|
||||||
// the given Replacer. Note that the returned value is not a pointer.
|
// the given Replacer. Note that the returned value is not a pointer.
|
||||||
func fillDialInfo(upstream *Upstream, repl caddy.Replacer) (DialInfo, error) {
|
func fillDialInfo(upstream *Upstream, repl caddy.Replacer) (DialInfo, error) {
|
||||||
dial := repl.ReplaceAll(upstream.Dial, "")
|
dial := repl.ReplaceAll(upstream.Dial, "")
|
||||||
netw, addrs, err := caddy.ParseNetworkAddress(dial)
|
addr, err := caddy.ParseNetworkAddress(dial)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return DialInfo{}, fmt.Errorf("upstream %s: invalid dial address %s: %v", upstream.Dial, dial, err)
|
return DialInfo{}, fmt.Errorf("upstream %s: invalid dial address %s: %v", upstream.Dial, dial, err)
|
||||||
}
|
}
|
||||||
if len(addrs) != 1 {
|
if numPorts := addr.PortRangeSize(); numPorts != 1 {
|
||||||
return DialInfo{}, fmt.Errorf("upstream %s: dial address must represent precisely one socket: %s represents %d",
|
return DialInfo{}, fmt.Errorf("upstream %s: dial address must represent precisely one socket: %s represents %d",
|
||||||
upstream.Dial, dial, len(addrs))
|
upstream.Dial, dial, numPorts)
|
||||||
}
|
|
||||||
var dialHost, dialPort string
|
|
||||||
if !strings.Contains(netw, "unix") {
|
|
||||||
dialHost, dialPort, err = net.SplitHostPort(addrs[0])
|
|
||||||
if err != nil {
|
|
||||||
dialHost = addrs[0] // assume there was no port
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return DialInfo{
|
return DialInfo{
|
||||||
Upstream: upstream,
|
Upstream: upstream,
|
||||||
Network: netw,
|
Network: addr.Network,
|
||||||
Address: addrs[0],
|
Address: addr.JoinHostPort(0),
|
||||||
Host: dialHost,
|
Host: addr.Host,
|
||||||
Port: dialPort,
|
Port: strconv.Itoa(int(addr.StartPort)),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -242,40 +242,44 @@ func (s *Server) enforcementHandler(w http.ResponseWriter, r *http.Request, next
|
||||||
// listeners in s that use a port which is not otherPort.
|
// listeners in s that use a port which is not otherPort.
|
||||||
func (s *Server) listenersUseAnyPortOtherThan(otherPort int) bool {
|
func (s *Server) listenersUseAnyPortOtherThan(otherPort int) bool {
|
||||||
for _, lnAddr := range s.Listen {
|
for _, lnAddr := range s.Listen {
|
||||||
_, addrs, err := caddy.ParseNetworkAddress(lnAddr)
|
laddrs, err := caddy.ParseNetworkAddress(lnAddr)
|
||||||
if err == nil {
|
if err != nil {
|
||||||
for _, a := range addrs {
|
continue
|
||||||
_, port, err := net.SplitHostPort(a)
|
}
|
||||||
if err == nil && port != strconv.Itoa(otherPort) {
|
if uint(otherPort) > laddrs.EndPort || uint(otherPort) < laddrs.StartPort {
|
||||||
return true
|
return true
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// hasListenerAddress returns true if s has a listener
|
||||||
|
// at the given address fullAddr. Currently, fullAddr
|
||||||
|
// must represent exactly one socket address (port
|
||||||
|
// ranges are not supported)
|
||||||
func (s *Server) hasListenerAddress(fullAddr string) bool {
|
func (s *Server) hasListenerAddress(fullAddr string) bool {
|
||||||
netw, addrs, err := caddy.ParseNetworkAddress(fullAddr)
|
laddrs, err := caddy.ParseNetworkAddress(fullAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if len(addrs) != 1 {
|
if laddrs.PortRangeSize() != 1 {
|
||||||
return false
|
return false // TODO: support port ranges
|
||||||
}
|
}
|
||||||
addr := addrs[0]
|
|
||||||
for _, lnAddr := range s.Listen {
|
for _, lnAddr := range s.Listen {
|
||||||
thisNetw, thisAddrs, err := caddy.ParseNetworkAddress(lnAddr)
|
thisAddrs, err := caddy.ParseNetworkAddress(lnAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if thisNetw != netw {
|
if thisAddrs.Network != laddrs.Network {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
for _, a := range thisAddrs {
|
|
||||||
if a == addr {
|
// host must be the same and port must fall within port range
|
||||||
return true
|
if (thisAddrs.Host == laddrs.Host) &&
|
||||||
}
|
(laddrs.StartPort <= thisAddrs.EndPort) &&
|
||||||
|
(laddrs.StartPort >= thisAddrs.StartPort) {
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
|
|
Loading…
Reference in a new issue