improvement: cache actual destination

This commit is contained in:
Timo Kösters 2020-12-06 11:05:51 +01:00
parent 45086b54b3
commit d62f17a91a
No known key found for this signature in database
GPG key ID: 24DA7517711A2BA4
4 changed files with 74 additions and 35 deletions

View file

@ -76,7 +76,7 @@ impl Database {
} }
/// Load an existing database or create a new one. /// Load an existing database or create a new one.
pub fn load_or_create(config: Config) -> Result<Self> { pub async fn load_or_create(config: Config) -> Result<Self> {
let path = config let path = config
.database_path .database_path
.clone() .clone()
@ -106,7 +106,7 @@ impl Database {
let (admin_sender, admin_receiver) = mpsc::unbounded(); let (admin_sender, admin_receiver) = mpsc::unbounded();
let db = Self { let db = Self {
globals: globals::Globals::load(db.open_tree("global")?, config)?, globals: globals::Globals::load(db.open_tree("global")?, config).await?,
users: users::Users { users: users::Users {
userid_password: db.open_tree("userid_password")?, userid_password: db.open_tree("userid_password")?,
userid_displayname: db.open_tree("userid_displayname")?, userid_displayname: db.open_tree("userid_displayname")?,

View file

@ -1,20 +1,25 @@
use crate::{database::Config, utils, Error, Result}; use crate::{database::Config, utils, Error, Result};
use trust_dns_resolver::TokioAsyncResolver;
use std::collections::HashMap;
use log::error; use log::error;
use ruma::ServerName; use ruma::ServerName;
use std::sync::Arc; use std::sync::Arc;
use std::sync::RwLock;
pub const COUNTER: &str = "c"; pub const COUNTER: &str = "c";
#[derive(Clone)] #[derive(Clone)]
pub struct Globals { pub struct Globals {
pub(super) globals: sled::Tree, pub(super) globals: sled::Tree,
config: Config,
keypair: Arc<ruma::signatures::Ed25519KeyPair>, keypair: Arc<ruma::signatures::Ed25519KeyPair>,
reqwest_client: reqwest::Client, reqwest_client: reqwest::Client,
config: Config, pub actual_destination_cache: Arc<RwLock<HashMap<Box<ServerName>, (String, Option<String>)>>>, // actual_destination, host
dns_resolver: TokioAsyncResolver,
} }
impl Globals { impl Globals {
pub fn load(globals: sled::Tree, config: Config) -> Result<Self> { pub async fn load(globals: sled::Tree, config: Config) -> Result<Self> {
let bytes = &*globals let bytes = &*globals
.update_and_fetch("keypair", utils::generate_keypair)? .update_and_fetch("keypair", utils::generate_keypair)?
.expect("utils::generate_keypair always returns Some"); .expect("utils::generate_keypair always returns Some");
@ -51,9 +56,13 @@ impl Globals {
Ok(Self { Ok(Self {
globals, globals,
config,
keypair: Arc::new(keypair), keypair: Arc::new(keypair),
reqwest_client: reqwest::Client::new(), reqwest_client: reqwest::Client::new(),
config, dns_resolver: TokioAsyncResolver::tokio_from_system_conf().await.map_err(|_| {
Error::bad_config("Failed to set up trust dns resolver with system config.")
})?,
actual_destination_cache: Arc::new(RwLock::new(HashMap::new())),
}) })
} }
@ -103,4 +112,8 @@ impl Globals {
pub fn federation_enabled(&self) -> bool { pub fn federation_enabled(&self) -> bool {
self.config.federation_enabled self.config.federation_enabled
} }
pub fn dns_resolver(&self) -> &TokioAsyncResolver {
&self.dns_resolver
}
} }

View file

@ -136,6 +136,7 @@ fn setup_rocket() -> rocket::Rocket {
.attach(AdHoc::on_attach("Config", |rocket| async { .attach(AdHoc::on_attach("Config", |rocket| async {
let data = let data =
Database::load_or_create(rocket.figment().extract().expect("config is valid")) Database::load_or_create(rocket.figment().extract().expect("config is valid"))
.await
.expect("config is valid"); .expect("config is valid");
data.sending.start_handler(&data.globals, &data.rooms); data.sending.start_handler(&data.globals, &data.rooms);

View file

@ -30,7 +30,6 @@ use std::{
sync::Arc, sync::Arc,
time::{Duration, SystemTime}, time::{Duration, SystemTime},
}; };
use trust_dns_resolver::AsyncResolver;
pub async fn request_well_known( pub async fn request_well_known(
globals: &crate::database::globals::Globals, globals: &crate::database::globals::Globals,
@ -66,35 +65,25 @@ where
return Err(Error::bad_config("Federation is disabled.")); return Err(Error::bad_config("Federation is disabled."));
} }
let resolver = AsyncResolver::tokio_from_system_conf().await.map_err(|_| { let maybe_result = globals
Error::bad_config("Failed to set up trust dns resolver with system config.") .actual_destination_cache
})?; .read()
.unwrap()
.get(&destination)
.cloned();
let mut host = None; let (actual_destination, host) = if let Some(result) = maybe_result {
println!("Loaded {} -> {:?}", destination, result);
let actual_destination = "https://".to_owned() result
+ &if let Some(mut delegated_hostname) =
request_well_known(globals, &destination.as_str()).await
{
if let Ok(Some(srv)) = resolver
.srv_lookup(format!("_matrix._tcp.{}", delegated_hostname))
.await
.map(|srv| srv.iter().next().map(|result| result.target().to_string()))
{
host = Some(delegated_hostname);
srv.trim_end_matches('.').to_owned()
} else { } else {
if delegated_hostname.find(':').is_none() { let result = find_actual_destination(globals, &destination).await;
delegated_hostname += ":8448"; globals
} .actual_destination_cache
delegated_hostname .write()
} .unwrap()
} else { .insert(destination.clone(), result.clone());
let mut destination = destination.as_str().to_owned(); println!("Saving {} -> {:?}", destination, result);
if destination.find(':').is_none() { result
destination += ":8448";
}
destination
}; };
let mut http_request = request let mut http_request = request
@ -232,6 +221,42 @@ where
} }
} }
/// Returns: actual_destination, host header
async fn find_actual_destination(
globals: &crate::database::globals::Globals,
destination: &Box<ServerName>,
) -> (String, Option<String>) {
let mut host = None;
let actual_destination = "https://".to_owned()
+ &if let Some(mut delegated_hostname) =
request_well_known(globals, destination.as_str()).await
{
if let Ok(Some(srv)) = globals
.dns_resolver()
.srv_lookup(format!("_matrix._tcp.{}", delegated_hostname))
.await
.map(|srv| srv.iter().next().map(|result| result.target().to_string()))
{
host = Some(delegated_hostname);
srv.trim_end_matches('.').to_owned()
} else {
if delegated_hostname.find(':').is_none() {
delegated_hostname += ":8448";
}
delegated_hostname
}
} else {
let mut destination = destination.as_str().to_owned();
if destination.find(':').is_none() {
destination += ":8448";
}
destination
};
(actual_destination, host)
}
#[cfg_attr(feature = "conduit_bin", get("/_matrix/federation/v1/version"))] #[cfg_attr(feature = "conduit_bin", get("/_matrix/federation/v1/version"))]
pub fn get_server_version_route( pub fn get_server_version_route(
db: State<'_, Database>, db: State<'_, Database>,