feat: support unix sockets (#145)

This commit is contained in:
sigoden 2022-11-11 08:57:44 +08:00 committed by GitHub
parent 8b4727c3a4
commit 6ebf619430
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 143 additions and 65 deletions

View file

@ -49,7 +49,7 @@ ARGS:
<root> Specific path to serve [default: .]
OPTIONS:
-b, --bind <addr>... Specify bind address
-b, --bind <addr>... Specify bind address or unix socket
-p, --port <port> Specify port to listen on [default: 5000]
--path-prefix <path> Specify a path prefix
--hidden <value> Hide paths from directory listings, separated by `,`
@ -123,10 +123,15 @@ Require username/password
dufs -a /@admin:123
```
Listen on a specific port
Listen on specific host:ip
```
dufs -p 80
dufs -b 127.0.0.1 -p 80
```
Listen on unix socket
```
dufs -b /tmp/dufs.socket
```
Use https

View file

@ -28,7 +28,7 @@ pub fn build_cli() -> Command<'static> {
Arg::new("bind")
.short('b')
.long("bind")
.help("Specify bind address")
.help("Specify bind address or unix socket")
.multiple_values(true)
.value_delimiter(',')
.action(ArgAction::Append)
@ -168,7 +168,7 @@ pub fn print_completions<G: Generator>(gen: G, cmd: &mut Command) {
#[derive(Debug)]
pub struct Args {
pub addrs: Vec<IpAddr>,
pub addrs: Vec<BindAddr>,
pub port: u16,
pub path: PathBuf,
pub path_is_file: bool,
@ -204,7 +204,7 @@ impl Args {
.values_of("bind")
.map(|v| v.collect())
.unwrap_or_else(|| vec!["0.0.0.0", "::"]);
let addrs: Vec<IpAddr> = Args::parse_addrs(&addrs)?;
let addrs: Vec<BindAddr> = Args::parse_addrs(&addrs)?;
let path = Args::parse_path(matches.value_of_os("root").unwrap_or_default())?;
let path_is_file = path.metadata()?.is_file();
let path_prefix = matches
@ -281,23 +281,27 @@ impl Args {
})
}
fn parse_addrs(addrs: &[&str]) -> BoxResult<Vec<IpAddr>> {
let mut ip_addrs = vec![];
fn parse_addrs(addrs: &[&str]) -> BoxResult<Vec<BindAddr>> {
let mut bind_addrs = vec![];
let mut invalid_addrs = vec![];
for addr in addrs {
match addr.parse::<IpAddr>() {
Ok(v) => {
ip_addrs.push(v);
bind_addrs.push(BindAddr::Address(v));
}
Err(_) => {
invalid_addrs.push(*addr);
if cfg!(unix) {
bind_addrs.push(BindAddr::Path(PathBuf::from(addr)));
} else {
invalid_addrs.push(*addr);
}
}
}
}
if !invalid_addrs.is_empty() {
return Err(format!("Invalid bind address `{}`", invalid_addrs.join(",")).into());
}
Ok(ip_addrs)
Ok(bind_addrs)
}
fn parse_path<P: AsRef<Path>>(path: P) -> BoxResult<PathBuf> {
@ -322,3 +326,9 @@ impl Args {
Ok(path)
}
}
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum BindAddr {
Address(IpAddr),
Path(PathBuf),
}

View file

@ -6,6 +6,8 @@ mod server;
mod streamer;
#[cfg(feature = "tls")]
mod tls;
#[cfg(unix)]
mod unix;
mod utils;
#[macro_use]
@ -20,6 +22,7 @@ use std::net::{IpAddr, SocketAddr, TcpListener as StdTcpListener};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use args::BindAddr;
use clap_complete::Shell;
use futures::future::join_all;
use tokio::net::TcpListener;
@ -75,11 +78,9 @@ fn serve(
let inner = Arc::new(Server::new(args.clone(), running));
let mut handles = vec![];
let port = args.port;
for ip in args.addrs.iter() {
for bind_addr in args.addrs.iter() {
let inner = inner.clone();
let incoming = create_addr_incoming(SocketAddr::new(*ip, port))
.map_err(|e| format!("Failed to bind `{}:{}`, {}", ip, port, e))?;
let serve_func = move |remote_addr: SocketAddr| {
let serve_func = move |remote_addr: Option<SocketAddr>| {
let inner = inner.clone();
async move {
Ok::<_, hyper::Error>(service_fn(move |req: Request| {
@ -88,35 +89,57 @@ fn serve(
}))
}
};
match args.tls.as_ref() {
#[cfg(feature = "tls")]
Some((certs, key)) => {
let config = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certs.clone(), key.clone())?;
let config = Arc::new(config);
let accepter = TlsAcceptor::new(config.clone(), incoming);
let new_service = make_service_fn(move |socket: &TlsStream| {
let remote_addr = socket.remote_addr();
serve_func(remote_addr)
});
let server = tokio::spawn(hyper::Server::builder(accepter).serve(new_service));
handles.push(server);
match bind_addr {
BindAddr::Address(ip) => {
let incoming = create_addr_incoming(SocketAddr::new(*ip, port))
.map_err(|e| format!("Failed to bind `{}:{}`, {}", ip, port, e))?;
match args.tls.as_ref() {
#[cfg(feature = "tls")]
Some((certs, key)) => {
let config = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certs.clone(), key.clone())?;
let config = Arc::new(config);
let accepter = TlsAcceptor::new(config.clone(), incoming);
let new_service = make_service_fn(move |socket: &TlsStream| {
let remote_addr = socket.remote_addr();
serve_func(Some(remote_addr))
});
let server =
tokio::spawn(hyper::Server::builder(accepter).serve(new_service));
handles.push(server);
}
#[cfg(not(feature = "tls"))]
Some(_) => {
unreachable!()
}
None => {
let new_service = make_service_fn(move |socket: &AddrStream| {
let remote_addr = socket.remote_addr();
serve_func(Some(remote_addr))
});
let server =
tokio::spawn(hyper::Server::builder(incoming).serve(new_service));
handles.push(server);
}
};
}
#[cfg(not(feature = "tls"))]
Some(_) => {
unreachable!()
BindAddr::Path(path) => {
if path.exists() {
std::fs::remove_file(path)?;
}
#[cfg(unix)]
{
let listener = tokio::net::UnixListener::bind(path)
.map_err(|e| format!("Failed to bind `{}`, {}", path.display(), e))?;
let acceptor = unix::UnixAcceptor::from_listener(listener);
let new_service = make_service_fn(move |_| serve_func(None));
let server = tokio::spawn(hyper::Server::builder(acceptor).serve(new_service));
handles.push(server);
}
}
None => {
let new_service = make_service_fn(move |socket: &AddrStream| {
let remote_addr = socket.remote_addr();
serve_func(remote_addr)
});
let server = tokio::spawn(hyper::Server::builder(incoming).serve(new_service));
handles.push(server);
}
};
}
}
Ok(handles)
}
@ -137,17 +160,22 @@ fn create_addr_incoming(addr: SocketAddr) -> BoxResult<AddrIncoming> {
}
fn print_listening(args: Arc<Args>) -> BoxResult<()> {
let mut addrs = vec![];
let mut bind_addrs = vec![];
let (mut ipv4, mut ipv6) = (false, false);
for ip in args.addrs.iter() {
if ip.is_unspecified() {
if ip.is_ipv6() {
ipv6 = true;
} else {
ipv4 = true;
for bind_addr in args.addrs.iter() {
match bind_addr {
BindAddr::Address(ip) => {
if ip.is_unspecified() {
if ip.is_ipv6() {
ipv6 = true;
} else {
ipv4 = true;
}
} else {
bind_addrs.push(bind_addr.clone());
}
}
} else {
addrs.push(*ip);
_ => bind_addrs.push(bind_addr.clone()),
}
}
if ipv4 || ipv6 {
@ -156,25 +184,27 @@ fn print_listening(args: Arc<Args>) -> BoxResult<()> {
for iface in ifaces.into_iter() {
let local_ip = iface.ip();
if ipv4 && local_ip.is_ipv4() {
addrs.push(local_ip)
bind_addrs.push(BindAddr::Address(local_ip))
}
if ipv6 && local_ip.is_ipv6() {
addrs.push(local_ip)
bind_addrs.push(BindAddr::Address(local_ip))
}
}
}
addrs.sort_unstable();
let urls = addrs
bind_addrs.sort_unstable();
let urls = bind_addrs
.into_iter()
.map(|addr| match addr {
IpAddr::V4(_) => format!("{}:{}", addr, args.port),
IpAddr::V6(_) => format!("[{}]:{}", addr, args.port),
.map(|bind_addr| match bind_addr {
BindAddr::Address(addr) => {
let addr = match addr {
IpAddr::V4(_) => format!("{}:{}", addr, args.port),
IpAddr::V6(_) => format!("[{}]:{}", addr, args.port),
};
let protocol = if args.tls.is_some() { "https" } else { "http" };
format!("{}://{}{}", protocol, addr, args.uri_prefix)
}
BindAddr::Path(path) => path.display().to_string(),
})
.map(|addr| match &args.tls {
Some(_) => format!("https://{}", addr),
None => format!("http://{}", addr),
})
.map(|url| format!("{}{}", url, args.uri_prefix))
.collect::<Vec<_>>();
if urls.len() == 1 {

View file

@ -84,13 +84,15 @@ impl Server {
pub async fn call(
self: Arc<Self>,
req: Request,
addr: SocketAddr,
addr: Option<SocketAddr>,
) -> Result<Response, hyper::Error> {
let uri = req.uri().clone();
let assets_prefix = self.assets_prefix.clone();
let enable_cors = self.args.enable_cors;
let mut http_log_data = self.args.log_http.data(&req, &self.args);
http_log_data.insert("remote_addr".to_string(), addr.ip().to_string());
if let Some(addr) = addr {
http_log_data.insert("remote_addr".to_string(), addr.ip().to_string());
}
let mut res = match self.clone().handle(req).await {
Ok(res) => {

31
src/unix.rs Normal file
View file

@ -0,0 +1,31 @@
use hyper::server::accept::Accept;
use tokio::net::UnixListener;
use std::pin::Pin;
use std::task::{Context, Poll};
pub struct UnixAcceptor {
inner: UnixListener,
}
impl UnixAcceptor {
pub fn from_listener(listener: UnixListener) -> Self {
Self { inner: listener }
}
}
impl Accept for UnixAcceptor {
type Conn = tokio::net::UnixStream;
type Error = std::io::Error;
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
match self.inner.poll_accept(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok((socket, _addr))) => Poll::Ready(Some(Ok(socket))),
Poll::Ready(Err(err)) => Poll::Ready(Some(Err(err))),
}
}
}