diff --git a/admin.go b/admin.go index 0b509f28..7217eb71 100644 --- a/admin.go +++ b/admin.go @@ -21,6 +21,7 @@ import ( "expvar" "fmt" "io" + "net" "net/http" "net/http/pprof" "net/url" @@ -78,27 +79,27 @@ type ConfigSettings struct { // listenAddr extracts a singular listen address from ac.Listen, // returning the network and the address of the listener. -func (admin AdminConfig) listenAddr() (string, string, error) { +func (admin AdminConfig) listenAddr() (NetworkAddress, error) { input := admin.Listen if input == "" { input = DefaultAdminListen } listenAddr, err := ParseNetworkAddress(input) if err != nil { - return "", "", fmt.Errorf("parsing admin listener address: %v", err) + return NetworkAddress{}, fmt.Errorf("parsing admin listener address: %v", err) } if listenAddr.PortRangeSize() != 1 { - return "", "", fmt.Errorf("admin endpoint must have exactly one address; cannot listen on %v", listenAddr) + return NetworkAddress{}, fmt.Errorf("admin endpoint must have exactly one address; cannot listen on %v", listenAddr) } - return listenAddr.Network, listenAddr.JoinHostPort(0), nil + return listenAddr, nil } // newAdminHandler reads admin's config and returns an http.Handler suitable // for use in an admin endpoint server, which will be listening on listenAddr. -func (admin AdminConfig) newAdminHandler(listenAddr string) adminHandler { +func (admin AdminConfig) newAdminHandler(addr NetworkAddress) adminHandler { muxWrap := adminHandler{ enforceOrigin: admin.EnforceOrigin, - allowedOrigins: admin.allowedOrigins(listenAddr), + allowedOrigins: admin.allowedOrigins(addr), mux: http.NewServeMux(), } @@ -140,14 +141,30 @@ func (admin AdminConfig) newAdminHandler(listenAddr string) adminHandler { // If admin.Origins is nil (null), the provided listen address // will be used as the default origin. If admin.Origins is // empty, no origins will be allowed, effectively bricking the -// endpoint, but whatever. -func (admin AdminConfig) allowedOrigins(listen string) []string { +// endpoint for non-unix-socket endpoints, but whatever. +func (admin AdminConfig) allowedOrigins(addr NetworkAddress) []string { uniqueOrigins := make(map[string]struct{}) for _, o := range admin.Origins { uniqueOrigins[o] = struct{}{} } if admin.Origins == nil { - uniqueOrigins[listen] = struct{}{} + if addr.isLoopback() { + if addr.IsUnixNetwork() { + // RFC 2616, Section 14.26: + // "A client MUST include a Host header field in all HTTP/1.1 request + // messages. If the requested URI does not include an Internet host + // name for the service being requested, then the Host header field MUST + // be given with an empty value." + uniqueOrigins[""] = struct{}{} + } else { + uniqueOrigins[net.JoinHostPort("localhost", addr.port())] = struct{}{} + uniqueOrigins[net.JoinHostPort("::1", addr.port())] = struct{}{} + uniqueOrigins[net.JoinHostPort("127.0.0.1", addr.port())] = struct{}{} + } + } + if !addr.IsUnixNetwork() { + uniqueOrigins[addr.JoinHostPort(0)] = struct{}{} + } } allowed := make([]string, 0, len(uniqueOrigins)) for origin := range uniqueOrigins { @@ -195,14 +212,14 @@ func replaceAdmin(cfg *Config) error { } // extract a singular listener address - netw, addr, err := adminConfig.listenAddr() + addr, err := adminConfig.listenAddr() if err != nil { return err } handler := adminConfig.newAdminHandler(addr) - ln, err := Listen(netw, addr) + ln, err := Listen(addr.Network, addr.JoinHostPort(0)) if err != nil { return err } @@ -219,7 +236,7 @@ func replaceAdmin(cfg *Config) error { Log().Named("admin").Info( "admin endpoint started", - zap.String("address", addr), + zap.String("address", addr.String()), zap.Bool("enforce_origin", adminConfig.EnforceOrigin), zap.Strings("origins", handler.allowedOrigins), ) @@ -263,6 +280,7 @@ type adminHandler struct { func (h adminHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { Log().Named("admin.api").Info("received request", zap.String("method", r.Method), + zap.String("host", r.Host), zap.String("uri", r.RequestURI), zap.String("remote_addr", r.RemoteAddr), zap.Reflect("headers", r.Header), @@ -274,14 +292,14 @@ func (h adminHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // be called more than once per request, for example if a request // is rewritten (i.e. internal redirect). func (h adminHandler) serveHTTP(w http.ResponseWriter, r *http.Request) { - if h.enforceOrigin { - // DNS rebinding mitigation - err := h.checkHost(r) - if err != nil { - h.handleError(w, r, err) - return - } + // DNS rebinding mitigation + err := h.checkHost(r) + if err != nil { + h.handleError(w, r, err) + return + } + if h.enforceOrigin { // cross-site mitigation origin, err := h.checkOrigin(r) if err != nil { diff --git a/listeners.go b/listeners.go index e1fd48c9..bfbe6dd1 100644 --- a/listeners.go +++ b/listeners.go @@ -289,14 +289,31 @@ func (na NetworkAddress) PortRangeSize() uint { return (na.EndPort - na.StartPort) + 1 } -// String reconstructs the address string to the form expected -// by ParseNetworkAddress(). -func (na NetworkAddress) String() string { - port := strconv.FormatUint(uint64(na.StartPort), 10) - if na.StartPort != na.EndPort { - port += "-" + strconv.FormatUint(uint64(na.EndPort), 10) +func (na NetworkAddress) isLoopback() bool { + if na.IsUnixNetwork() { + return true } - return JoinNetworkAddress(na.Network, na.Host, port) + if na.Host == "localhost" { + return true + } + if ip := net.ParseIP(na.Host); ip != nil { + return ip.IsLoopback() + } + return false +} + +func (na NetworkAddress) port() string { + if na.StartPort == na.EndPort { + return strconv.FormatUint(uint64(na.StartPort), 10) + } + return fmt.Sprintf("%d-%d", na.StartPort, na.EndPort) +} + +// String reconstructs the address string to the form expected +// by ParseNetworkAddress(). If the address is a unix socket, +// any non-zero port will be dropped. +func (na NetworkAddress) String() string { + return JoinNetworkAddress(na.Network, na.Host, na.port()) } func isUnixNetwork(netw string) bool { @@ -378,7 +395,7 @@ func JoinNetworkAddress(network, host, port string) string { if network != "" { a = network + "/" } - if host != "" && port == "" { + if (host != "" && port == "") || isUnixNetwork(network) { a += host } else if port != "" { a += net.JoinHostPort(host, port) diff --git a/listeners_test.go b/listeners_test.go index 2828a3b4..b75e2dce 100644 --- a/listeners_test.go +++ b/listeners_test.go @@ -138,6 +138,14 @@ func TestJoinNetworkAddress(t *testing.T) { network: "unix", host: "/foo/bar", port: "", expect: "unix//foo/bar", }, + { + network: "unix", host: "/foo/bar", port: "0", + expect: "unix//foo/bar", + }, + { + network: "unix", host: "/foo/bar", port: "1234", + expect: "unix//foo/bar", + }, { network: "", host: "::1", port: "1234", expect: "[::1]:1234",