From 63a7b530bb99bd291567dbf919256f8c35ba96da Mon Sep 17 00:00:00 2001 From: sigoden Date: Mon, 6 Jun 2022 10:52:12 +0800 Subject: [PATCH] feat: support ipv6 (#25) --- src/args.rs | 25 +++++++++++-------------- src/server.rs | 43 +++++++++++++++++++++++++------------------ 2 files changed, 36 insertions(+), 32 deletions(-) diff --git a/src/args.rs b/src/args.rs index 55e70e8..6466d65 100644 --- a/src/args.rs +++ b/src/args.rs @@ -1,7 +1,7 @@ use clap::crate_description; use clap::{Arg, ArgMatches}; use rustls::{Certificate, PrivateKey}; -use std::net::SocketAddr; +use std::net::{IpAddr, SocketAddr}; use std::path::{Path, PathBuf}; use std::{env, fs, io}; @@ -111,8 +111,7 @@ pub fn matches() -> ArgMatches { #[derive(Debug, Clone, Eq, PartialEq)] pub struct Args { - pub address: String, - pub port: u16, + pub addr: SocketAddr, pub path: PathBuf, pub path_prefix: String, pub uri_prefix: String, @@ -133,8 +132,9 @@ impl Args { /// If a parsing error ocurred, exit the process and print out informative /// error message to user. pub fn parse(matches: ArgMatches) -> BoxResult { - let address = matches.value_of("address").unwrap_or_default().to_owned(); + let ip = matches.value_of("address").unwrap_or_default(); let port = matches.value_of_t::("port")?; + let addr = to_addr(ip, port)?; let path = Args::parse_path(matches.value_of_os("path").unwrap_or_default())?; let path_prefix = matches .value_of("path-prefix") @@ -166,8 +166,7 @@ impl Args { }; Ok(Args { - address, - port, + addr, path, path_prefix, uri_prefix, @@ -197,17 +196,15 @@ impl Args { }) .map_err(|err| format!("Failed to access path `{}`: {}", path.display(), err,).into()) } +} - /// Construct socket address from arguments. - pub fn address(&self) -> BoxResult { - format!("{}:{}", self.address, self.port) - .parse() - .map_err(|_| format!("Invalid bind address `{}:{}`", self.address, self.port).into()) - } +fn to_addr(ip: &str, port: u16) -> BoxResult { + let ip: IpAddr = ip.parse()?; + Ok(SocketAddr::new(ip, port)) } // Load public certificate from file. -pub fn load_certs(filename: &str) -> BoxResult> { +fn load_certs(filename: &str) -> BoxResult> { // Open certificate file. let certfile = fs::File::open(&filename).map_err(|e| format!("Failed to open {}: {}", &filename, e))?; @@ -222,7 +219,7 @@ pub fn load_certs(filename: &str) -> BoxResult> { } // Load private key from file. -pub fn load_private_key(filename: &str) -> BoxResult { +fn load_private_key(filename: &str) -> BoxResult { // Open keyfile. let keyfile = fs::File::open(&filename).map_err(|e| format!("Failed to open {}: {}", &filename, e))?; diff --git a/src/server.rs b/src/server.rs index 28dfd92..5faff80 100644 --- a/src/server.rs +++ b/src/server.rs @@ -25,7 +25,7 @@ use rustls::ServerConfig; use serde::Serialize; use std::convert::Infallible; use std::fs::Metadata; -use std::net::IpAddr; +use std::net::{IpAddr, SocketAddr}; use std::path::{Path, PathBuf}; use std::sync::Arc; use std::time::SystemTime; @@ -56,7 +56,6 @@ macro_rules! status { pub async fn serve(args: Args) -> BoxResult<()> { let args = Arc::new(args); - let socket_addr = args.address()?; let inner = Arc::new(InnerService::new(args.clone())); match args.tls.clone() { Some((certs, key)) => { @@ -66,7 +65,7 @@ pub async fn serve(args: Args) -> BoxResult<()> { .with_single_cert(certs, key)?; let tls_acceptor = TlsAcceptor::from(Arc::new(config)); let arc_acceptor = Arc::new(tls_acceptor); - let listener = TcpListener::bind(&socket_addr).await?; + let listener = TcpListener::bind(&args.addr).await?; let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener); let incoming = hyper::server::accept::from_stream(incoming.filter_map(|socket| async { @@ -87,11 +86,11 @@ pub async fn serve(args: Args) -> BoxResult<()> { })) } })); - print_listening(args.address.as_str(), args.port, &args.uri_prefix, true); + print_listening(&args.addr, &args.uri_prefix, true); server.await?; } None => { - let server = hyper::Server::try_bind(&socket_addr)?.serve(make_service_fn(move |_| { + let server = hyper::Server::try_bind(&args.addr)?.serve(make_service_fn(move |_| { let inner = inner.clone(); async move { Ok::<_, Infallible>(service_fn(move |req| { @@ -100,7 +99,7 @@ pub async fn serve(args: Args) -> BoxResult<()> { })) } })); - print_listening(args.address.as_str(), args.port, &args.uri_prefix, false); + print_listening(&args.addr, &args.uri_prefix, false); server.await?; } } @@ -974,37 +973,45 @@ fn to_content_range(range: &Range, complete_length: u64) -> Option }) } -fn print_listening(address: &str, port: u16, prefix: &str, tls: bool) { +fn print_listening(addr: &SocketAddr, prefix: &str, tls: bool) { let prefix = encode_uri(prefix.trim_end_matches('/')); - let addrs = retrieve_listening_addrs(address); + let addrs = retrieve_listening_addrs(addr); let protocol = if tls { "https" } else { "http" }; if addrs.len() == 1 { - eprintln!( - "Listening on {}://{}:{}{}", - protocol, addrs[0], port, prefix - ); + eprintln!("Listening on {}://{}{}", protocol, addr, prefix); } else { eprintln!("Listening on:"); for addr in addrs { - eprintln!(" {}://{}:{}{}", protocol, addr, port, prefix); + eprintln!(" {}://{}{}", protocol, addr, prefix); } eprintln!(); } } -fn retrieve_listening_addrs(address: &str) -> Vec { - if address == "0.0.0.0" { +fn retrieve_listening_addrs(addr: &SocketAddr) -> Vec { + let ip = addr.ip(); + let port = addr.port(); + if ip.is_unspecified() { if let Ok(interfaces) = get_if_addrs() { let mut ifaces: Vec = interfaces .into_iter() .map(|v| v.ip()) - .filter(|v| v.is_ipv4()) + .filter(|v| { + if ip.is_ipv4() { + v.is_ipv4() + } else { + v.is_ipv6() + } + }) .collect(); ifaces.sort(); - return ifaces.into_iter().map(|v| v.to_string()).collect(); + return ifaces + .into_iter() + .map(|v| SocketAddr::new(v, port)) + .collect(); } } - vec![address.to_owned()] + vec![addr.to_owned()] } fn encode_uri(v: &str) -> String {