From 6ebf619430dff7d214cd2b72add134851629b67e Mon Sep 17 00:00:00 2001 From: sigoden Date: Fri, 11 Nov 2022 08:57:44 +0800 Subject: [PATCH] feat: support unix sockets (#145) --- README.md | 11 +++-- src/args.rs | 26 +++++++--- src/main.rs | 134 ++++++++++++++++++++++++++++++-------------------- src/server.rs | 6 ++- src/unix.rs | 31 ++++++++++++ 5 files changed, 143 insertions(+), 65 deletions(-) create mode 100644 src/unix.rs diff --git a/README.md b/README.md index 4aef7a4..1dfabaf 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,7 @@ ARGS: Specific path to serve [default: .] OPTIONS: - -b, --bind ... Specify bind address + -b, --bind ... Specify bind address or unix socket -p, --port Specify port to listen on [default: 5000] --path-prefix Specify a path prefix --hidden Hide paths from directory listings, separated by `,` @@ -123,10 +123,15 @@ Require username/password dufs -a /@admin:123 ``` -Listen on a specific port +Listen on specific host:ip ``` -dufs -p 80 +dufs -b 127.0.0.1 -p 80 +``` + +Listen on unix socket +``` +dufs -b /tmp/dufs.socket ``` Use https diff --git a/src/args.rs b/src/args.rs index 358a23e..9799d09 100644 --- a/src/args.rs +++ b/src/args.rs @@ -28,7 +28,7 @@ pub fn build_cli() -> Command<'static> { Arg::new("bind") .short('b') .long("bind") - .help("Specify bind address") + .help("Specify bind address or unix socket") .multiple_values(true) .value_delimiter(',') .action(ArgAction::Append) @@ -168,7 +168,7 @@ pub fn print_completions(gen: G, cmd: &mut Command) { #[derive(Debug)] pub struct Args { - pub addrs: Vec, + pub addrs: Vec, pub port: u16, pub path: PathBuf, pub path_is_file: bool, @@ -204,7 +204,7 @@ impl Args { .values_of("bind") .map(|v| v.collect()) .unwrap_or_else(|| vec!["0.0.0.0", "::"]); - let addrs: Vec = Args::parse_addrs(&addrs)?; + let addrs: Vec = Args::parse_addrs(&addrs)?; let path = Args::parse_path(matches.value_of_os("root").unwrap_or_default())?; let path_is_file = path.metadata()?.is_file(); let path_prefix = matches @@ -281,23 +281,27 @@ impl Args { }) } - fn parse_addrs(addrs: &[&str]) -> BoxResult> { - let mut ip_addrs = vec![]; + fn parse_addrs(addrs: &[&str]) -> BoxResult> { + let mut bind_addrs = vec![]; let mut invalid_addrs = vec![]; for addr in addrs { match addr.parse::() { Ok(v) => { - ip_addrs.push(v); + bind_addrs.push(BindAddr::Address(v)); } Err(_) => { - invalid_addrs.push(*addr); + if cfg!(unix) { + bind_addrs.push(BindAddr::Path(PathBuf::from(addr))); + } else { + invalid_addrs.push(*addr); + } } } } if !invalid_addrs.is_empty() { return Err(format!("Invalid bind address `{}`", invalid_addrs.join(",")).into()); } - Ok(ip_addrs) + Ok(bind_addrs) } fn parse_path>(path: P) -> BoxResult { @@ -322,3 +326,9 @@ impl Args { Ok(path) } } + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub enum BindAddr { + Address(IpAddr), + Path(PathBuf), +} diff --git a/src/main.rs b/src/main.rs index ed85e8a..5df3415 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,6 +6,8 @@ mod server; mod streamer; #[cfg(feature = "tls")] mod tls; +#[cfg(unix)] +mod unix; mod utils; #[macro_use] @@ -20,6 +22,7 @@ use std::net::{IpAddr, SocketAddr, TcpListener as StdTcpListener}; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; +use args::BindAddr; use clap_complete::Shell; use futures::future::join_all; use tokio::net::TcpListener; @@ -75,11 +78,9 @@ fn serve( let inner = Arc::new(Server::new(args.clone(), running)); let mut handles = vec![]; let port = args.port; - for ip in args.addrs.iter() { + for bind_addr in args.addrs.iter() { let inner = inner.clone(); - let incoming = create_addr_incoming(SocketAddr::new(*ip, port)) - .map_err(|e| format!("Failed to bind `{}:{}`, {}", ip, port, e))?; - let serve_func = move |remote_addr: SocketAddr| { + let serve_func = move |remote_addr: Option| { let inner = inner.clone(); async move { Ok::<_, hyper::Error>(service_fn(move |req: Request| { @@ -88,35 +89,57 @@ fn serve( })) } }; - match args.tls.as_ref() { - #[cfg(feature = "tls")] - Some((certs, key)) => { - let config = ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_single_cert(certs.clone(), key.clone())?; - let config = Arc::new(config); - let accepter = TlsAcceptor::new(config.clone(), incoming); - let new_service = make_service_fn(move |socket: &TlsStream| { - let remote_addr = socket.remote_addr(); - serve_func(remote_addr) - }); - let server = tokio::spawn(hyper::Server::builder(accepter).serve(new_service)); - handles.push(server); + match bind_addr { + BindAddr::Address(ip) => { + let incoming = create_addr_incoming(SocketAddr::new(*ip, port)) + .map_err(|e| format!("Failed to bind `{}:{}`, {}", ip, port, e))?; + match args.tls.as_ref() { + #[cfg(feature = "tls")] + Some((certs, key)) => { + let config = ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(certs.clone(), key.clone())?; + let config = Arc::new(config); + let accepter = TlsAcceptor::new(config.clone(), incoming); + let new_service = make_service_fn(move |socket: &TlsStream| { + let remote_addr = socket.remote_addr(); + serve_func(Some(remote_addr)) + }); + let server = + tokio::spawn(hyper::Server::builder(accepter).serve(new_service)); + handles.push(server); + } + #[cfg(not(feature = "tls"))] + Some(_) => { + unreachable!() + } + None => { + let new_service = make_service_fn(move |socket: &AddrStream| { + let remote_addr = socket.remote_addr(); + serve_func(Some(remote_addr)) + }); + let server = + tokio::spawn(hyper::Server::builder(incoming).serve(new_service)); + handles.push(server); + } + }; } - #[cfg(not(feature = "tls"))] - Some(_) => { - unreachable!() + BindAddr::Path(path) => { + if path.exists() { + std::fs::remove_file(path)?; + } + #[cfg(unix)] + { + let listener = tokio::net::UnixListener::bind(path) + .map_err(|e| format!("Failed to bind `{}`, {}", path.display(), e))?; + let acceptor = unix::UnixAcceptor::from_listener(listener); + let new_service = make_service_fn(move |_| serve_func(None)); + let server = tokio::spawn(hyper::Server::builder(acceptor).serve(new_service)); + handles.push(server); + } } - None => { - let new_service = make_service_fn(move |socket: &AddrStream| { - let remote_addr = socket.remote_addr(); - serve_func(remote_addr) - }); - let server = tokio::spawn(hyper::Server::builder(incoming).serve(new_service)); - handles.push(server); - } - }; + } } Ok(handles) } @@ -137,17 +160,22 @@ fn create_addr_incoming(addr: SocketAddr) -> BoxResult { } fn print_listening(args: Arc) -> BoxResult<()> { - let mut addrs = vec![]; + let mut bind_addrs = vec![]; let (mut ipv4, mut ipv6) = (false, false); - for ip in args.addrs.iter() { - if ip.is_unspecified() { - if ip.is_ipv6() { - ipv6 = true; - } else { - ipv4 = true; + for bind_addr in args.addrs.iter() { + match bind_addr { + BindAddr::Address(ip) => { + if ip.is_unspecified() { + if ip.is_ipv6() { + ipv6 = true; + } else { + ipv4 = true; + } + } else { + bind_addrs.push(bind_addr.clone()); + } } - } else { - addrs.push(*ip); + _ => bind_addrs.push(bind_addr.clone()), } } if ipv4 || ipv6 { @@ -156,25 +184,27 @@ fn print_listening(args: Arc) -> BoxResult<()> { for iface in ifaces.into_iter() { let local_ip = iface.ip(); if ipv4 && local_ip.is_ipv4() { - addrs.push(local_ip) + bind_addrs.push(BindAddr::Address(local_ip)) } if ipv6 && local_ip.is_ipv6() { - addrs.push(local_ip) + bind_addrs.push(BindAddr::Address(local_ip)) } } } - addrs.sort_unstable(); - let urls = addrs + bind_addrs.sort_unstable(); + let urls = bind_addrs .into_iter() - .map(|addr| match addr { - IpAddr::V4(_) => format!("{}:{}", addr, args.port), - IpAddr::V6(_) => format!("[{}]:{}", addr, args.port), + .map(|bind_addr| match bind_addr { + BindAddr::Address(addr) => { + let addr = match addr { + IpAddr::V4(_) => format!("{}:{}", addr, args.port), + IpAddr::V6(_) => format!("[{}]:{}", addr, args.port), + }; + let protocol = if args.tls.is_some() { "https" } else { "http" }; + format!("{}://{}{}", protocol, addr, args.uri_prefix) + } + BindAddr::Path(path) => path.display().to_string(), }) - .map(|addr| match &args.tls { - Some(_) => format!("https://{}", addr), - None => format!("http://{}", addr), - }) - .map(|url| format!("{}{}", url, args.uri_prefix)) .collect::>(); if urls.len() == 1 { diff --git a/src/server.rs b/src/server.rs index a32b219..291ec0a 100644 --- a/src/server.rs +++ b/src/server.rs @@ -84,13 +84,15 @@ impl Server { pub async fn call( self: Arc, req: Request, - addr: SocketAddr, + addr: Option, ) -> Result { let uri = req.uri().clone(); let assets_prefix = self.assets_prefix.clone(); let enable_cors = self.args.enable_cors; let mut http_log_data = self.args.log_http.data(&req, &self.args); - http_log_data.insert("remote_addr".to_string(), addr.ip().to_string()); + if let Some(addr) = addr { + http_log_data.insert("remote_addr".to_string(), addr.ip().to_string()); + } let mut res = match self.clone().handle(req).await { Ok(res) => { diff --git a/src/unix.rs b/src/unix.rs new file mode 100644 index 0000000..b8b1710 --- /dev/null +++ b/src/unix.rs @@ -0,0 +1,31 @@ +use hyper::server::accept::Accept; +use tokio::net::UnixListener; + +use std::pin::Pin; +use std::task::{Context, Poll}; + +pub struct UnixAcceptor { + inner: UnixListener, +} + +impl UnixAcceptor { + pub fn from_listener(listener: UnixListener) -> Self { + Self { inner: listener } + } +} + +impl Accept for UnixAcceptor { + type Conn = tokio::net::UnixStream; + type Error = std::io::Error; + + fn poll_accept( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + match self.inner.poll_accept(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok((socket, _addr))) => Poll::Ready(Some(Ok(socket))), + Poll::Ready(Err(err)) => Poll::Ready(Some(Err(err))), + } + } +}