From d66c9de8c838cc1d7b03012a3798d8950affabf2 Mon Sep 17 00:00:00 2001 From: sigoden Date: Fri, 8 Mar 2024 10:29:12 +0800 Subject: [PATCH] feat: tls handshake timeout (#368) --- src/main.rs | 46 ++++++++++++++++++++++++++++------------------ 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/src/main.rs b/src/main.rs index 6d669f6..298de04 100644 --- a/src/main.rs +++ b/src/main.rs @@ -29,6 +29,8 @@ use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, }; +use std::time::Duration; +use tokio::time::timeout; use tokio::{net::TcpListener, task::JoinHandle}; #[cfg(feature = "tls")] use tokio_rustls::{rustls::ServerConfig, TlsAcceptor}; @@ -91,12 +93,19 @@ fn serve(args: Args, running: Arc) -> Result>> { config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; let config = Arc::new(config); let tls_accepter = TlsAcceptor::from(config); + let handshake_timeout = Duration::from_secs(10); let handle = tokio::spawn(async move { loop { - let (cnx, addr) = listener.accept().await.unwrap(); - let Ok(stream) = tls_accepter.accept(cnx).await else { - warn!("During tls handshake connection from {}", addr); + let Ok((stream, addr)) = listener.accept().await else { + continue; + }; + let Some(stream) = + timeout(handshake_timeout, tls_accepter.accept(stream)) + .await + .ok() + .and_then(|v| v.ok()) + else { continue; }; let stream = TokioIo::new(stream); @@ -113,8 +122,10 @@ fn serve(args: Args, running: Arc) -> Result>> { (None, None) => { let handle = tokio::spawn(async move { loop { - let (cnx, addr) = listener.accept().await.unwrap(); - let stream = TokioIo::new(cnx); + let Ok((stream, addr)) = listener.accept().await else { + continue; + }; + let stream = TokioIo::new(stream); tokio::spawn(handle_stream( server_handle.clone(), stream, @@ -139,8 +150,10 @@ fn serve(args: Args, running: Arc) -> Result>> { .with_context(|| format!("Failed to bind `{}`", path.display()))?; let handle = tokio::spawn(async move { loop { - let (cnx, _) = listener.accept().await.unwrap(); - let stream = TokioIo::new(cnx); + let Ok((stream, _addr)) = listener.accept().await else { + continue; + }; + let stream = TokioIo::new(stream); tokio::spawn(handle_stream(server_handle.clone(), stream, None)); } }); @@ -160,18 +173,15 @@ where let hyper_service = service_fn(move |request: Request| handle.clone().call(request, addr)); - let ret = Builder::new(TokioExecutor::new()) + match Builder::new(TokioExecutor::new()) .serve_connection_with_upgrades(stream, hyper_service) - .await; - - if let Err(err) = ret { - let scope = match addr { - Some(addr) => format!(" from {}", addr), - None => String::new(), - }; - match err.downcast_ref::() { - Some(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => {} - _ => warn!("Serving connection{}: {}", scope, err), + .await + { + Ok(()) => {} + Err(_err) => { + // This error only appears when the client doesn't send a request and terminate the connection. + // + // If client sends one request then terminate connection whenever, it doesn't appear. } } }