diff --git a/Cargo.lock b/Cargo.lock index f84c9829..41105b37 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -58,6 +58,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "arc-swap" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5d78ce20460b82d3fa150275ed9d55e21064fc7951177baacf86a145c4a4b1f" + [[package]] name = "arrayref" version = "0.3.6" @@ -162,6 +168,26 @@ dependencies = [ "mime", ] +[[package]] +name = "axum-server" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9cfd9dbe28ebde5c0460067ea27c6f3b1d514b699c4e0a5aab0fb63e452a8a8" +dependencies = [ + "arc-swap", + "bytes", + "futures-util", + "http", + "http-body", + "hyper", + "pin-project-lite", + "rustls", + "rustls-pemfile", + "tokio", + "tokio-rustls", + "tower-service", +] + [[package]] name = "base64" version = "0.12.3" @@ -365,6 +391,7 @@ name = "conduit" version = "0.3.0" dependencies = [ "axum", + "axum-server", "base64 0.13.0", "bytes", "clap", @@ -375,7 +402,6 @@ dependencies = [ "heed", "hmac", "http", - "hyper", "image", "jsonwebtoken", "lru-cache", diff --git a/Cargo.toml b/Cargo.toml index 5fb75dcb..6dedfa8d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,7 @@ edition = "2021" [dependencies] # Web framework axum = { version = "0.4.4", features = ["headers"], optional = true } -hyper = "0.14.16" +axum-server = { version = "0.3.3", features = ["tls-rustls"] } tower = { version = "0.4.11", features = ["util"] } tower-http = { version = "0.2.1", features = ["add-extension", "cors", "compression-full", "sensitive-headers", "trace", "util"] } diff --git a/src/config.rs b/src/config.rs index 48ac9816..155704b7 100644 --- a/src/config.rs +++ b/src/config.rs @@ -17,6 +17,8 @@ pub struct Config { pub address: IpAddr, #[serde(default = "default_port")] pub port: u16, + pub tls: Option, + pub server_name: Box, #[serde(default = "default_database_backend")] pub database_backend: String, @@ -69,6 +71,12 @@ pub struct Config { pub catchall: BTreeMap, } +#[derive(Clone, Debug, Deserialize)] +pub struct TlsConfig { + pub certs: String, + pub key: String, +} + const DEPRECATED_KEYS: &[&str] = &["cache_capacity"]; impl Config { diff --git a/src/main.rs b/src/main.rs index 40122cf8..22ddf3e0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,7 +7,7 @@ #![allow(clippy::suspicious_else_formatting)] #![deny(clippy::dbg_macro)] -use std::{future::Future, net::SocketAddr, sync::Arc, time::Duration}; +use std::{future::Future, io, net::SocketAddr, sync::Arc, time::Duration}; use axum::{ extract::{FromRequest, MatchedPath}, @@ -15,6 +15,7 @@ use axum::{ routing::{get, on, MethodFilter}, Router, }; +use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle}; use figment::{ providers::{Env, Format, Toml}, Figment, @@ -117,8 +118,8 @@ async fn main() { } } -async fn run_server(config: &Config, db: Arc>) -> hyper::Result<()> { - let listen_addr = SocketAddr::from((config.address, config.port)); +async fn run_server(config: &Config, db: Arc>) -> io::Result<()> { + let addr = SocketAddr::from((config.address, config.port)); let x_requested_with = HeaderName::from_static("x-requested-with"); @@ -157,10 +158,20 @@ async fn run_server(config: &Config, db: Arc>) -> hyper::Result ) .add_extension(db.clone()); - axum::Server::bind(&listen_addr) - .serve(routes().layer(middlewares).into_make_service()) - .with_graceful_shutdown(shutdown_signal()) - .await?; + let app = routes().layer(middlewares).into_make_service(); + let handle = ServerHandle::new(); + + tokio::spawn(shutdown_signal(handle.clone())); + + match &config.tls { + Some(tls) => { + let conf = RustlsConfig::from_pem_file(&tls.certs, &tls.key).await?; + bind_rustls(addr, conf).handle(handle).serve(app).await?; + } + None => { + bind(addr).handle(handle).serve(app).await?; + } + } // After serve exits and before exiting, shutdown the DB Database::on_shutdown(db).await; @@ -312,7 +323,7 @@ fn routes() -> Router { .ruma_route(server_server::claim_keys_route) } -async fn shutdown_signal() { +async fn shutdown_signal(handle: ServerHandle) { let ctrl_c = async { signal::ctrl_c() .await @@ -334,6 +345,8 @@ async fn shutdown_signal() { _ = ctrl_c => {}, _ = terminate => {}, } + + handle.graceful_shutdown(Some(Duration::from_secs(30))); } trait RouterExt {