158 lines
5.2 KiB
Rust
158 lines
5.2 KiB
Rust
use core::task::{Context, Poll};
|
|
use futures::ready;
|
|
use hyper::server::accept::Accept;
|
|
use hyper::server::conn::{AddrIncoming, AddrStream};
|
|
use rustls::{Certificate, PrivateKey};
|
|
use std::future::Future;
|
|
use std::net::SocketAddr;
|
|
use std::pin::Pin;
|
|
use std::sync::Arc;
|
|
use std::{fs, io};
|
|
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
|
use tokio_rustls::rustls::ServerConfig;
|
|
|
|
enum State {
|
|
Handshaking(tokio_rustls::Accept<AddrStream>),
|
|
Streaming(tokio_rustls::server::TlsStream<AddrStream>),
|
|
}
|
|
|
|
// tokio_rustls::server::TlsStream doesn't expose constructor methods,
|
|
// so we have to TlsAcceptor::accept and handshake to have access to it
|
|
// TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first
|
|
pub struct TlsStream {
|
|
state: State,
|
|
remote_addr: SocketAddr,
|
|
}
|
|
|
|
impl TlsStream {
|
|
fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream {
|
|
let remote_addr = stream.remote_addr();
|
|
let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
|
|
TlsStream {
|
|
state: State::Handshaking(accept),
|
|
remote_addr,
|
|
}
|
|
}
|
|
pub fn remote_addr(&self) -> SocketAddr {
|
|
self.remote_addr
|
|
}
|
|
}
|
|
|
|
impl AsyncRead for TlsStream {
|
|
fn poll_read(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context,
|
|
buf: &mut ReadBuf,
|
|
) -> Poll<io::Result<()>> {
|
|
let pin = self.get_mut();
|
|
match pin.state {
|
|
State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
|
|
Ok(mut stream) => {
|
|
let result = Pin::new(&mut stream).poll_read(cx, buf);
|
|
pin.state = State::Streaming(stream);
|
|
result
|
|
}
|
|
Err(err) => Poll::Ready(Err(err)),
|
|
},
|
|
State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl AsyncWrite for TlsStream {
|
|
fn poll_write(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &[u8],
|
|
) -> Poll<io::Result<usize>> {
|
|
let pin = self.get_mut();
|
|
match pin.state {
|
|
State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
|
|
Ok(mut stream) => {
|
|
let result = Pin::new(&mut stream).poll_write(cx, buf);
|
|
pin.state = State::Streaming(stream);
|
|
result
|
|
}
|
|
Err(err) => Poll::Ready(Err(err)),
|
|
},
|
|
State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
|
|
}
|
|
}
|
|
|
|
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
match self.state {
|
|
State::Handshaking(_) => Poll::Ready(Ok(())),
|
|
State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
|
|
}
|
|
}
|
|
|
|
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
match self.state {
|
|
State::Handshaking(_) => Poll::Ready(Ok(())),
|
|
State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
|
|
}
|
|
}
|
|
}
|
|
|
|
pub struct TlsAcceptor {
|
|
config: Arc<ServerConfig>,
|
|
incoming: AddrIncoming,
|
|
}
|
|
|
|
impl TlsAcceptor {
|
|
pub fn new(config: Arc<ServerConfig>, incoming: AddrIncoming) -> TlsAcceptor {
|
|
TlsAcceptor { config, incoming }
|
|
}
|
|
}
|
|
|
|
impl Accept for TlsAcceptor {
|
|
type Conn = TlsStream;
|
|
type Error = io::Error;
|
|
|
|
fn poll_accept(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
|
|
let pin = self.get_mut();
|
|
match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
|
|
Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))),
|
|
Some(Err(e)) => Poll::Ready(Some(Err(e))),
|
|
None => Poll::Ready(None),
|
|
}
|
|
}
|
|
}
|
|
|
|
// Load public certificate from file.
|
|
pub fn load_certs(filename: &str) -> Result<Vec<Certificate>, Box<dyn std::error::Error>> {
|
|
// Open certificate file.
|
|
let certfile = fs::File::open(&filename)
|
|
.map_err(|e| format!("Failed to access `{}`, {}", &filename, e))?;
|
|
let mut reader = io::BufReader::new(certfile);
|
|
|
|
// Load and return certificate.
|
|
let certs = rustls_pemfile::certs(&mut reader).map_err(|_| "Failed to load certificate")?;
|
|
if certs.is_empty() {
|
|
return Err("No supported certificate in file".into());
|
|
}
|
|
Ok(certs.into_iter().map(Certificate).collect())
|
|
}
|
|
|
|
// Load private key from file.
|
|
pub fn load_private_key(filename: &str) -> Result<PrivateKey, Box<dyn std::error::Error>> {
|
|
// Open keyfile.
|
|
let keyfile = fs::File::open(&filename)
|
|
.map_err(|e| format!("Failed to access `{}`, {}", &filename, e))?;
|
|
let mut reader = io::BufReader::new(keyfile);
|
|
|
|
// Load and return a single private key.
|
|
let keys = rustls_pemfile::read_all(&mut reader)
|
|
.map_err(|e| format!("There was a problem with reading private key: {:?}", e))?
|
|
.into_iter()
|
|
.find_map(|item| match item {
|
|
rustls_pemfile::Item::RSAKey(key) | rustls_pemfile::Item::PKCS8Key(key) => Some(key),
|
|
_ => None,
|
|
})
|
|
.ok_or("No supported private key in file")?;
|
|
|
|
Ok(PrivateKey(keys))
|
|
}
|