admin: Always enforce Host header checks

With a simple heuristic for loopback addresses, we can enable this by
default without adding unnecessary inconvenience.
This commit is contained in:
Matthew Holt 2020-04-10 17:31:38 -06:00
parent d3383ced2a
commit a3bdc22234
No known key found for this signature in database
GPG key ID: 2A349DD577D586A5
3 changed files with 70 additions and 27 deletions

View file

@ -21,6 +21,7 @@ import (
"expvar" "expvar"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"net/http/pprof" "net/http/pprof"
"net/url" "net/url"
@ -78,27 +79,27 @@ type ConfigSettings struct {
// listenAddr extracts a singular listen address from ac.Listen, // listenAddr extracts a singular listen address from ac.Listen,
// returning the network and the address of the listener. // returning the network and the address of the listener.
func (admin AdminConfig) listenAddr() (string, string, error) { func (admin AdminConfig) listenAddr() (NetworkAddress, error) {
input := admin.Listen input := admin.Listen
if input == "" { if input == "" {
input = DefaultAdminListen input = DefaultAdminListen
} }
listenAddr, err := ParseNetworkAddress(input) listenAddr, err := ParseNetworkAddress(input)
if err != nil { if err != nil {
return "", "", fmt.Errorf("parsing admin listener address: %v", err) return NetworkAddress{}, fmt.Errorf("parsing admin listener address: %v", err)
} }
if listenAddr.PortRangeSize() != 1 { if listenAddr.PortRangeSize() != 1 {
return "", "", fmt.Errorf("admin endpoint must have exactly one address; cannot listen on %v", listenAddr) return NetworkAddress{}, fmt.Errorf("admin endpoint must have exactly one address; cannot listen on %v", listenAddr)
} }
return listenAddr.Network, listenAddr.JoinHostPort(0), nil return listenAddr, nil
} }
// newAdminHandler reads admin's config and returns an http.Handler suitable // newAdminHandler reads admin's config and returns an http.Handler suitable
// for use in an admin endpoint server, which will be listening on listenAddr. // for use in an admin endpoint server, which will be listening on listenAddr.
func (admin AdminConfig) newAdminHandler(listenAddr string) adminHandler { func (admin AdminConfig) newAdminHandler(addr NetworkAddress) adminHandler {
muxWrap := adminHandler{ muxWrap := adminHandler{
enforceOrigin: admin.EnforceOrigin, enforceOrigin: admin.EnforceOrigin,
allowedOrigins: admin.allowedOrigins(listenAddr), allowedOrigins: admin.allowedOrigins(addr),
mux: http.NewServeMux(), mux: http.NewServeMux(),
} }
@ -140,14 +141,30 @@ func (admin AdminConfig) newAdminHandler(listenAddr string) adminHandler {
// If admin.Origins is nil (null), the provided listen address // If admin.Origins is nil (null), the provided listen address
// will be used as the default origin. If admin.Origins is // will be used as the default origin. If admin.Origins is
// empty, no origins will be allowed, effectively bricking the // empty, no origins will be allowed, effectively bricking the
// endpoint, but whatever. // endpoint for non-unix-socket endpoints, but whatever.
func (admin AdminConfig) allowedOrigins(listen string) []string { func (admin AdminConfig) allowedOrigins(addr NetworkAddress) []string {
uniqueOrigins := make(map[string]struct{}) uniqueOrigins := make(map[string]struct{})
for _, o := range admin.Origins { for _, o := range admin.Origins {
uniqueOrigins[o] = struct{}{} uniqueOrigins[o] = struct{}{}
} }
if admin.Origins == nil { if admin.Origins == nil {
uniqueOrigins[listen] = struct{}{} if addr.isLoopback() {
if addr.IsUnixNetwork() {
// RFC 2616, Section 14.26:
// "A client MUST include a Host header field in all HTTP/1.1 request
// messages. If the requested URI does not include an Internet host
// name for the service being requested, then the Host header field MUST
// be given with an empty value."
uniqueOrigins[""] = struct{}{}
} else {
uniqueOrigins[net.JoinHostPort("localhost", addr.port())] = struct{}{}
uniqueOrigins[net.JoinHostPort("::1", addr.port())] = struct{}{}
uniqueOrigins[net.JoinHostPort("127.0.0.1", addr.port())] = struct{}{}
}
}
if !addr.IsUnixNetwork() {
uniqueOrigins[addr.JoinHostPort(0)] = struct{}{}
}
} }
allowed := make([]string, 0, len(uniqueOrigins)) allowed := make([]string, 0, len(uniqueOrigins))
for origin := range uniqueOrigins { for origin := range uniqueOrigins {
@ -195,14 +212,14 @@ func replaceAdmin(cfg *Config) error {
} }
// extract a singular listener address // extract a singular listener address
netw, addr, err := adminConfig.listenAddr() addr, err := adminConfig.listenAddr()
if err != nil { if err != nil {
return err return err
} }
handler := adminConfig.newAdminHandler(addr) handler := adminConfig.newAdminHandler(addr)
ln, err := Listen(netw, addr) ln, err := Listen(addr.Network, addr.JoinHostPort(0))
if err != nil { if err != nil {
return err return err
} }
@ -219,7 +236,7 @@ func replaceAdmin(cfg *Config) error {
Log().Named("admin").Info( Log().Named("admin").Info(
"admin endpoint started", "admin endpoint started",
zap.String("address", addr), zap.String("address", addr.String()),
zap.Bool("enforce_origin", adminConfig.EnforceOrigin), zap.Bool("enforce_origin", adminConfig.EnforceOrigin),
zap.Strings("origins", handler.allowedOrigins), zap.Strings("origins", handler.allowedOrigins),
) )
@ -263,6 +280,7 @@ type adminHandler struct {
func (h adminHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h adminHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
Log().Named("admin.api").Info("received request", Log().Named("admin.api").Info("received request",
zap.String("method", r.Method), zap.String("method", r.Method),
zap.String("host", r.Host),
zap.String("uri", r.RequestURI), zap.String("uri", r.RequestURI),
zap.String("remote_addr", r.RemoteAddr), zap.String("remote_addr", r.RemoteAddr),
zap.Reflect("headers", r.Header), zap.Reflect("headers", r.Header),
@ -274,14 +292,14 @@ func (h adminHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// be called more than once per request, for example if a request // be called more than once per request, for example if a request
// is rewritten (i.e. internal redirect). // is rewritten (i.e. internal redirect).
func (h adminHandler) serveHTTP(w http.ResponseWriter, r *http.Request) { func (h adminHandler) serveHTTP(w http.ResponseWriter, r *http.Request) {
if h.enforceOrigin { // DNS rebinding mitigation
// DNS rebinding mitigation err := h.checkHost(r)
err := h.checkHost(r) if err != nil {
if err != nil { h.handleError(w, r, err)
h.handleError(w, r, err) return
return }
}
if h.enforceOrigin {
// cross-site mitigation // cross-site mitigation
origin, err := h.checkOrigin(r) origin, err := h.checkOrigin(r)
if err != nil { if err != nil {

View file

@ -289,14 +289,31 @@ func (na NetworkAddress) PortRangeSize() uint {
return (na.EndPort - na.StartPort) + 1 return (na.EndPort - na.StartPort) + 1
} }
// String reconstructs the address string to the form expected func (na NetworkAddress) isLoopback() bool {
// by ParseNetworkAddress(). if na.IsUnixNetwork() {
func (na NetworkAddress) String() string { return true
port := strconv.FormatUint(uint64(na.StartPort), 10)
if na.StartPort != na.EndPort {
port += "-" + strconv.FormatUint(uint64(na.EndPort), 10)
} }
return JoinNetworkAddress(na.Network, na.Host, port) if na.Host == "localhost" {
return true
}
if ip := net.ParseIP(na.Host); ip != nil {
return ip.IsLoopback()
}
return false
}
func (na NetworkAddress) port() string {
if na.StartPort == na.EndPort {
return strconv.FormatUint(uint64(na.StartPort), 10)
}
return fmt.Sprintf("%d-%d", na.StartPort, na.EndPort)
}
// String reconstructs the address string to the form expected
// by ParseNetworkAddress(). If the address is a unix socket,
// any non-zero port will be dropped.
func (na NetworkAddress) String() string {
return JoinNetworkAddress(na.Network, na.Host, na.port())
} }
func isUnixNetwork(netw string) bool { func isUnixNetwork(netw string) bool {
@ -378,7 +395,7 @@ func JoinNetworkAddress(network, host, port string) string {
if network != "" { if network != "" {
a = network + "/" a = network + "/"
} }
if host != "" && port == "" { if (host != "" && port == "") || isUnixNetwork(network) {
a += host a += host
} else if port != "" { } else if port != "" {
a += net.JoinHostPort(host, port) a += net.JoinHostPort(host, port)

View file

@ -138,6 +138,14 @@ func TestJoinNetworkAddress(t *testing.T) {
network: "unix", host: "/foo/bar", port: "", network: "unix", host: "/foo/bar", port: "",
expect: "unix//foo/bar", expect: "unix//foo/bar",
}, },
{
network: "unix", host: "/foo/bar", port: "0",
expect: "unix//foo/bar",
},
{
network: "unix", host: "/foo/bar", port: "1234",
expect: "unix//foo/bar",
},
{ {
network: "", host: "::1", port: "1234", network: "", host: "::1", port: "1234",
expect: "[::1]:1234", expect: "[::1]:1234",