feat: support binding abstract unix socket (#468)

This commit is contained in:
sigoden 2024-10-23 06:57:45 +08:00 committed by GitHub
parent bb5a5564b4
commit 881a67e1a4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 226 additions and 323 deletions

469
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -461,28 +461,30 @@ impl Args {
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum BindAddr { pub enum BindAddr {
Address(IpAddr), IpAddr(IpAddr),
Path(PathBuf), #[cfg(unix)]
SocketPath(String),
} }
impl BindAddr { impl BindAddr {
fn parse_addrs(addrs: &[&str]) -> Result<Vec<Self>> { fn parse_addrs(addrs: &[&str]) -> Result<Vec<Self>> {
let mut bind_addrs = vec![]; let mut bind_addrs = vec![];
#[cfg(not(unix))]
let mut invalid_addrs = vec![]; let mut invalid_addrs = vec![];
for addr in addrs { for addr in addrs {
match addr.parse::<IpAddr>() { match addr.parse::<IpAddr>() {
Ok(v) => { Ok(v) => {
bind_addrs.push(BindAddr::Address(v)); bind_addrs.push(BindAddr::IpAddr(v));
} }
Err(_) => { Err(_) => {
if cfg!(unix) { #[cfg(unix)]
bind_addrs.push(BindAddr::Path(PathBuf::from(addr))); bind_addrs.push(BindAddr::SocketPath(addr.to_string()));
} else { #[cfg(not(unix))]
invalid_addrs.push(*addr); invalid_addrs.push(*addr);
} }
} }
} }
} #[cfg(not(unix))]
if !invalid_addrs.is_empty() { if !invalid_addrs.is_empty() {
bail!("Invalid bind address `{}`", invalid_addrs.join(",")); bail!("Invalid bind address `{}`", invalid_addrs.join(","));
} }
@ -710,7 +712,7 @@ hidden: tmp,*.log,*.lock
assert_eq!(args.serve_path, Args::sanitize_path(&tmpdir).unwrap()); assert_eq!(args.serve_path, Args::sanitize_path(&tmpdir).unwrap());
assert_eq!( assert_eq!(
args.addrs, args.addrs,
vec![BindAddr::Address("0.0.0.0".parse().unwrap())] vec![BindAddr::IpAddr("0.0.0.0".parse().unwrap())]
); );
assert_eq!(args.hidden, ["tmp", "*.log", "*.lock"]); assert_eq!(args.hidden, ["tmp", "*.log", "*.lock"]);
assert_eq!(args.port, 3000); assert_eq!(args.port, 3000);
@ -740,8 +742,8 @@ hidden:
assert_eq!( assert_eq!(
args.addrs, args.addrs,
vec![ vec![
BindAddr::Address("127.0.0.1".parse().unwrap()), BindAddr::IpAddr("127.0.0.1".parse().unwrap()),
BindAddr::Address("192.168.8.10".parse().unwrap()) BindAddr::IpAddr("192.168.8.10".parse().unwrap())
] ]
); );
assert_eq!(args.hidden, ["tmp", "*.log", "*.lock"]); assert_eq!(args.hidden, ["tmp", "*.log", "*.lock"]);

View file

@ -78,7 +78,7 @@ fn serve(args: Args, running: Arc<AtomicBool>) -> Result<Vec<JoinHandle<()>>> {
for bind_addr in addrs.iter() { for bind_addr in addrs.iter() {
let server_handle = server_handle.clone(); let server_handle = server_handle.clone();
match bind_addr { match bind_addr {
BindAddr::Address(ip) => { BindAddr::IpAddr(ip) => {
let listener = create_listener(SocketAddr::new(*ip, port)) let listener = create_listener(SocketAddr::new(*ip, port))
.with_context(|| format!("Failed to bind `{ip}:{port}`"))?; .with_context(|| format!("Failed to bind `{ip}:{port}`"))?;
@ -140,14 +140,21 @@ fn serve(args: Args, running: Arc<AtomicBool>) -> Result<Vec<JoinHandle<()>>> {
} }
}; };
} }
BindAddr::Path(path) => {
if path.exists() {
std::fs::remove_file(path)?;
}
#[cfg(unix)] #[cfg(unix)]
BindAddr::SocketPath(path) => {
let socket_path = if path.starts_with("@")
&& cfg!(any(target_os = "linux", target_os = "android"))
{ {
let listener = tokio::net::UnixListener::bind(path) let mut path_buf = path.as_bytes().to_vec();
.with_context(|| format!("Failed to bind `{}`", path.display()))?; path_buf[0] = b'\0';
unsafe { std::ffi::OsStr::from_encoded_bytes_unchecked(&path_buf) }
.to_os_string()
} else {
let _ = std::fs::remove_file(path);
path.into()
};
let listener = tokio::net::UnixListener::bind(socket_path)
.with_context(|| format!("Failed to bind `{}`", path))?;
let handle = tokio::spawn(async move { let handle = tokio::spawn(async move {
loop { loop {
let Ok((stream, _addr)) = listener.accept().await else { let Ok((stream, _addr)) = listener.accept().await else {
@ -162,7 +169,6 @@ fn serve(args: Args, running: Arc<AtomicBool>) -> Result<Vec<JoinHandle<()>>> {
} }
} }
} }
}
Ok(handles) Ok(handles)
} }
@ -207,7 +213,7 @@ fn check_addrs(args: &Args) -> Result<(Vec<BindAddr>, Vec<BindAddr>)> {
let (ipv4_addrs, ipv6_addrs) = interface_addrs()?; let (ipv4_addrs, ipv6_addrs) = interface_addrs()?;
for bind_addr in args.addrs.iter() { for bind_addr in args.addrs.iter() {
match bind_addr { match bind_addr {
BindAddr::Address(ip) => match &ip { BindAddr::IpAddr(ip) => match &ip {
IpAddr::V4(_) => { IpAddr::V4(_) => {
if !ipv4_addrs.is_empty() { if !ipv4_addrs.is_empty() {
new_addrs.push(bind_addr.clone()); new_addrs.push(bind_addr.clone());
@ -229,6 +235,7 @@ fn check_addrs(args: &Args) -> Result<(Vec<BindAddr>, Vec<BindAddr>)> {
} }
} }
}, },
#[cfg(unix)]
_ => { _ => {
new_addrs.push(bind_addr.clone()); new_addrs.push(bind_addr.clone());
print_addrs.push(bind_addr.clone()) print_addrs.push(bind_addr.clone())
@ -246,10 +253,10 @@ fn interface_addrs() -> Result<(Vec<BindAddr>, Vec<BindAddr>)> {
for iface in ifaces.into_iter() { for iface in ifaces.into_iter() {
let ip = iface.ip(); let ip = iface.ip();
if ip.is_ipv4() { if ip.is_ipv4() {
ipv4_addrs.push(BindAddr::Address(ip)) ipv4_addrs.push(BindAddr::IpAddr(ip))
} }
if ip.is_ipv6() { if ip.is_ipv6() {
ipv6_addrs.push(BindAddr::Address(ip)) ipv6_addrs.push(BindAddr::IpAddr(ip))
} }
} }
Ok((ipv4_addrs, ipv6_addrs)) Ok((ipv4_addrs, ipv6_addrs))
@ -260,7 +267,7 @@ fn print_listening(args: &Args, print_addrs: &[BindAddr]) -> Result<String> {
let urls = print_addrs let urls = print_addrs
.iter() .iter()
.map(|bind_addr| match bind_addr { .map(|bind_addr| match bind_addr {
BindAddr::Address(addr) => { BindAddr::IpAddr(addr) => {
let addr = match addr { let addr = match addr {
IpAddr::V4(_) => format!("{}:{}", addr, args.port), IpAddr::V4(_) => format!("{}:{}", addr, args.port),
IpAddr::V6(_) => format!("[{}]:{}", addr, args.port), IpAddr::V6(_) => format!("[{}]:{}", addr, args.port),
@ -272,7 +279,8 @@ fn print_listening(args: &Args, print_addrs: &[BindAddr]) -> Result<String> {
}; };
format!("{}://{}{}", protocol, addr, args.uri_prefix) format!("{}://{}{}", protocol, addr, args.uri_prefix)
} }
BindAddr::Path(path) => path.display().to_string(), #[cfg(unix)]
BindAddr::SocketPath(path) => path.to_string(),
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();