diff --git a/src/main.rs b/src/main.rs index 6f2b9e2..30eb929 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,10 +14,23 @@ async fn main() { async fn run() -> BoxResult<()> { let args = Args::parse(matches())?; - serve(args).await + tokio::select! { + ret = serve(args) => { + ret + }, + _ = shutdown_signal() => { + Ok(()) + }, + } } fn handle_err(err: Box) -> T { eprintln!("error: {}", err); std::process::exit(1); } + +async fn shutdown_signal() { + tokio::signal::ctrl_c() + .await + .expect("Failed to install CTRL+C signal handler") +} diff --git a/src/server.rs b/src/server.rs index 983a8d5..2bbfd5c 100644 --- a/src/server.rs +++ b/src/server.rs @@ -54,65 +54,55 @@ macro_rules! status { } pub async fn serve(args: Args) -> BoxResult<()> { - match args.tls.as_ref() { - Some(_) => serve_https(args).await, - None => serve_http(args).await, + let args = Arc::new(args); + let socket_addr = args.address()?; + let inner = Arc::new(InnerService::new(args.clone())); + match args.tls.clone() { + Some((certs, key)) => { + let config = ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(certs, key)?; + let tls_acceptor = TlsAcceptor::from(Arc::new(config)); + let arc_acceptor = Arc::new(tls_acceptor); + let listener = TcpListener::bind(&socket_addr).await?; + let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener); + let incoming = + hyper::server::accept::from_stream(incoming.filter_map(|socket| async { + match socket { + Ok(stream) => match arc_acceptor.clone().accept(stream).await { + Ok(val) => Some(Ok::<_, Infallible>(val)), + Err(_) => None, + }, + Err(_) => None, + } + })); + let server = hyper::Server::builder(incoming).serve(make_service_fn(move |_| { + let inner = inner.clone(); + async move { + Ok::<_, Infallible>(service_fn(move |req| { + let inner = inner.clone(); + inner.call(req) + })) + } + })); + print_listening(args.address.as_str(), args.port, &args.uri_prefix, true); + server.await?; + } + None => { + let server = hyper::Server::try_bind(&socket_addr)?.serve(make_service_fn(move |_| { + let inner = inner.clone(); + async move { + Ok::<_, Infallible>(service_fn(move |req| { + let inner = inner.clone(); + inner.call(req) + })) + } + })); + print_listening(args.address.as_str(), args.port, &args.uri_prefix, false); + server.await?; + } } -} - -pub async fn serve_https(args: Args) -> BoxResult<()> { - let args = Arc::new(args); - let socket_addr = args.address()?; - let (certs, key) = args.tls.clone().unwrap(); - let inner = Arc::new(InnerService::new(args.clone())); - let config = ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_single_cert(certs, key)?; - let tls_acceptor = TlsAcceptor::from(Arc::new(config)); - let arc_acceptor = Arc::new(tls_acceptor); - let listener = TcpListener::bind(&socket_addr).await?; - let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener); - let incoming = hyper::server::accept::from_stream(incoming.filter_map(|socket| async { - match socket { - Ok(stream) => match arc_acceptor.clone().accept(stream).await { - Ok(val) => Some(Ok::<_, Infallible>(val)), - Err(_) => None, - }, - Err(_) => None, - } - })); - let server = hyper::Server::builder(incoming).serve(make_service_fn(move |_| { - let inner = inner.clone(); - async move { - Ok::<_, Infallible>(service_fn(move |req| { - let inner = inner.clone(); - inner.call(req) - })) - } - })); - print_listening(args.address.as_str(), args.port, &args.uri_prefix, true); - let graceful = server.with_graceful_shutdown(shutdown_signal()); - graceful.await?; - Ok(()) -} - -pub async fn serve_http(args: Args) -> BoxResult<()> { - let args = Arc::new(args); - let socket_addr = args.address()?; - let inner = Arc::new(InnerService::new(args.clone())); - let server = hyper::Server::try_bind(&socket_addr)?.serve(make_service_fn(move |_| { - let inner = inner.clone(); - async move { - Ok::<_, Infallible>(service_fn(move |req| { - let inner = inner.clone(); - inner.call(req) - })) - } - })); - print_listening(args.address.as_str(), args.port, &args.uri_prefix, false); - let graceful = server.with_graceful_shutdown(shutdown_signal()); - graceful.await?; Ok(()) } @@ -1012,9 +1002,3 @@ fn retrieve_listening_addrs(address: &str) -> Vec { } vec![address.to_owned()] } - -async fn shutdown_signal() { - tokio::signal::ctrl_c() - .await - .expect("Failed to install CTRL+C signal handler") -}