Use axum-server for direct TLS support

This commit is contained in:
Jonas Platte 2022-01-22 18:38:39 +01:00
parent 5fa9190117
commit c8951a1d9c
No known key found for this signature in database
GPG key ID: 7D261D771D915378
4 changed files with 57 additions and 10 deletions

28
Cargo.lock generated
View file

@ -58,6 +58,12 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "arc-swap"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c5d78ce20460b82d3fa150275ed9d55e21064fc7951177baacf86a145c4a4b1f"
[[package]] [[package]]
name = "arrayref" name = "arrayref"
version = "0.3.6" version = "0.3.6"
@ -162,6 +168,26 @@ dependencies = [
"mime", "mime",
] ]
[[package]]
name = "axum-server"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9cfd9dbe28ebde5c0460067ea27c6f3b1d514b699c4e0a5aab0fb63e452a8a8"
dependencies = [
"arc-swap",
"bytes",
"futures-util",
"http",
"http-body",
"hyper",
"pin-project-lite",
"rustls",
"rustls-pemfile",
"tokio",
"tokio-rustls",
"tower-service",
]
[[package]] [[package]]
name = "base64" name = "base64"
version = "0.12.3" version = "0.12.3"
@ -365,6 +391,7 @@ name = "conduit"
version = "0.3.0" version = "0.3.0"
dependencies = [ dependencies = [
"axum", "axum",
"axum-server",
"base64 0.13.0", "base64 0.13.0",
"bytes", "bytes",
"clap", "clap",
@ -375,7 +402,6 @@ dependencies = [
"heed", "heed",
"hmac", "hmac",
"http", "http",
"hyper",
"image", "image",
"jsonwebtoken", "jsonwebtoken",
"lru-cache", "lru-cache",

View file

@ -15,7 +15,7 @@ edition = "2021"
[dependencies] [dependencies]
# Web framework # Web framework
axum = { version = "0.4.4", features = ["headers"], optional = true } axum = { version = "0.4.4", features = ["headers"], optional = true }
hyper = "0.14.16" axum-server = { version = "0.3.3", features = ["tls-rustls"] }
tower = { version = "0.4.11", features = ["util"] } tower = { version = "0.4.11", features = ["util"] }
tower-http = { version = "0.2.1", features = ["add-extension", "cors", "compression-full", "sensitive-headers", "trace", "util"] } tower-http = { version = "0.2.1", features = ["add-extension", "cors", "compression-full", "sensitive-headers", "trace", "util"] }

View file

@ -17,6 +17,8 @@ pub struct Config {
pub address: IpAddr, pub address: IpAddr,
#[serde(default = "default_port")] #[serde(default = "default_port")]
pub port: u16, pub port: u16,
pub tls: Option<TlsConfig>,
pub server_name: Box<ServerName>, pub server_name: Box<ServerName>,
#[serde(default = "default_database_backend")] #[serde(default = "default_database_backend")]
pub database_backend: String, pub database_backend: String,
@ -69,6 +71,12 @@ pub struct Config {
pub catchall: BTreeMap<String, IgnoredAny>, pub catchall: BTreeMap<String, IgnoredAny>,
} }
#[derive(Clone, Debug, Deserialize)]
pub struct TlsConfig {
pub certs: String,
pub key: String,
}
const DEPRECATED_KEYS: &[&str] = &["cache_capacity"]; const DEPRECATED_KEYS: &[&str] = &["cache_capacity"];
impl Config { impl Config {

View file

@ -7,7 +7,7 @@
#![allow(clippy::suspicious_else_formatting)] #![allow(clippy::suspicious_else_formatting)]
#![deny(clippy::dbg_macro)] #![deny(clippy::dbg_macro)]
use std::{future::Future, net::SocketAddr, sync::Arc, time::Duration}; use std::{future::Future, io, net::SocketAddr, sync::Arc, time::Duration};
use axum::{ use axum::{
extract::{FromRequest, MatchedPath}, extract::{FromRequest, MatchedPath},
@ -15,6 +15,7 @@ use axum::{
routing::{get, on, MethodFilter}, routing::{get, on, MethodFilter},
Router, Router,
}; };
use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle};
use figment::{ use figment::{
providers::{Env, Format, Toml}, providers::{Env, Format, Toml},
Figment, Figment,
@ -117,8 +118,8 @@ async fn main() {
} }
} }
async fn run_server(config: &Config, db: Arc<RwLock<Database>>) -> hyper::Result<()> { async fn run_server(config: &Config, db: Arc<RwLock<Database>>) -> io::Result<()> {
let listen_addr = SocketAddr::from((config.address, config.port)); let addr = SocketAddr::from((config.address, config.port));
let x_requested_with = HeaderName::from_static("x-requested-with"); let x_requested_with = HeaderName::from_static("x-requested-with");
@ -157,10 +158,20 @@ async fn run_server(config: &Config, db: Arc<RwLock<Database>>) -> hyper::Result
) )
.add_extension(db.clone()); .add_extension(db.clone());
axum::Server::bind(&listen_addr) let app = routes().layer(middlewares).into_make_service();
.serve(routes().layer(middlewares).into_make_service()) let handle = ServerHandle::new();
.with_graceful_shutdown(shutdown_signal())
.await?; tokio::spawn(shutdown_signal(handle.clone()));
match &config.tls {
Some(tls) => {
let conf = RustlsConfig::from_pem_file(&tls.certs, &tls.key).await?;
bind_rustls(addr, conf).handle(handle).serve(app).await?;
}
None => {
bind(addr).handle(handle).serve(app).await?;
}
}
// After serve exits and before exiting, shutdown the DB // After serve exits and before exiting, shutdown the DB
Database::on_shutdown(db).await; Database::on_shutdown(db).await;
@ -312,7 +323,7 @@ fn routes() -> Router {
.ruma_route(server_server::claim_keys_route) .ruma_route(server_server::claim_keys_route)
} }
async fn shutdown_signal() { async fn shutdown_signal(handle: ServerHandle) {
let ctrl_c = async { let ctrl_c = async {
signal::ctrl_c() signal::ctrl_c()
.await .await
@ -334,6 +345,8 @@ async fn shutdown_signal() {
_ = ctrl_c => {}, _ = ctrl_c => {},
_ = terminate => {}, _ = terminate => {},
} }
handle.graceful_shutdown(Some(Duration::from_secs(30)));
} }
trait RouterExt { trait RouterExt {