From d4ccfa16dc700e60d8ca9b4b19915b9721059e71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Sun, 3 Apr 2022 11:48:25 +0200 Subject: [PATCH] async --- Cargo.lock | 22 ++++++ Cargo.toml | 3 +- src/database/abstraction/rocksdb.rs | 4 +- src/database/rooms.rs | 53 ++++++++++++-- src/server_server.rs | 104 +++++++++++++++------------- 5 files changed, 129 insertions(+), 57 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 66daf5e6..948e1507 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -96,6 +96,27 @@ dependencies = [ "tokio", ] +[[package]] +name = "async-stream" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dad5c83079eae9969be7fadefe640a1c566901f05ff91ab221de4b6f68d9507e" +dependencies = [ + "async-stream-impl", + "futures-core", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10f203db73a71dfa2fb6dd22763990fa26f3d2625a6da2da900d23b87d26be27" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "async-trait" version = "0.1.52" @@ -389,6 +410,7 @@ checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" name = "conduit" version = "0.3.0-next" dependencies = [ + "async-stream", "axum", "axum-server", "base64 0.13.0", diff --git a/Cargo.toml b/Cargo.toml index 627829f0..d08f0be1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ ruma = { git = "https://github.com/ruma/ruma", rev = "fa2e3662a456bd8957b3e1293c # Async runtime and utilities tokio = { version = "1.11.0", features = ["fs", "macros", "signal", "sync"] } +async-stream = "0.3.2" # Used for storing data permanently sled = { version = "0.34.6", features = ["compression", "no_metrics"], optional = true } #sled = { git = "https://github.com/spacejam/sled.git", rev = "e4640e0773595229f398438886f19bca6f7326a2", features = ["compression"] } @@ -76,7 +77,7 @@ crossbeam = { version = "0.8.1", optional = true } num_cpus = "1.13.0" threadpool = "1.8.1" heed = { git = "https://github.com/timokoesters/heed.git", rev = "f6f825da7fb2c758867e05ad973ef800a6fe1d5d", optional = true } -rocksdb = { version = "0.17.0", default-features = false, features = ["multi-threaded-cf", "zstd"], optional = true } +rocksdb = { version = "0.17.0", default-features = true, features = ["multi-threaded-cf", "zstd"], optional = true } thread_local = "1.1.3" # used for TURN server authentication diff --git a/src/database/abstraction/rocksdb.rs b/src/database/abstraction/rocksdb.rs index 2cf9d5ee..b8baa513 100644 --- a/src/database/abstraction/rocksdb.rs +++ b/src/database/abstraction/rocksdb.rs @@ -36,8 +36,8 @@ fn db_options(max_open_files: i32, rocksdb_cache: &rocksdb::Cache) -> rocksdb::O db_opts.set_level_compaction_dynamic_level_bytes(true); db_opts.set_target_file_size_base(256 * 1024 * 1024); //db_opts.set_compaction_readahead_size(2 * 1024 * 1024); - //db_opts.set_use_direct_reads(true); - //db_opts.set_use_direct_io_for_flush_and_compaction(true); + db_opts.set_use_direct_reads(true); + db_opts.set_use_direct_io_for_flush_and_compaction(true); db_opts.create_if_missing(true); db_opts.increase_parallelism(num_cpus::get() as i32); db_opts.set_max_open_files(max_open_files); diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 3a71a3b5..82017432 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -1,6 +1,7 @@ mod edus; pub use edus::RoomEdus; +use futures_util::Stream; use crate::{ pdu::{EventHash, PduBuilder}, @@ -39,6 +40,7 @@ use std::{ sync::{Arc, Mutex, RwLock}, }; use tokio::sync::MutexGuard; +use async_stream::try_stream; use tracing::{error, warn}; use super::{abstraction::Tree, pusher}; @@ -1083,6 +1085,38 @@ impl Rooms { .transpose() } + pub async fn get_pdu_async(&self, event_id: &EventId) -> Result>> { + if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(event_id) { + return Ok(Some(Arc::clone(p))); + } + let eventid_pduid = Arc::clone(&self.eventid_pduid); + let event_id_bytes = event_id.as_bytes().to_vec(); + if let Some(pdu) = tokio::task::spawn_blocking(move || { eventid_pduid .get(&event_id_bytes)}).await.unwrap()? + .map_or_else( + || self.eventid_outlierpdu.get(event_id.as_bytes()), + |pduid| { + Ok(Some(self.pduid_pdu.get(&pduid)?.ok_or_else(|| { + Error::bad_database("Invalid pduid in eventid_pduid.") + })?)) + }, + )? + .map(|pdu| { + serde_json::from_slice(&pdu) + .map_err(|_| Error::bad_database("Invalid PDU in db.")) + .map(Arc::new) + }) + .transpose()? + { + self.pdu_cache + .lock() + .unwrap() + .insert(event_id.to_owned(), Arc::clone(&pdu)); + Ok(Some(pdu)) + } else { + Ok(None) + } + } + /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. @@ -2109,7 +2143,7 @@ impl Rooms { user_id: &UserId, room_id: &RoomId, from: u64, - ) -> Result, PduEvent)>> + 'a> { + ) -> Result, PduEvent)>>> { // Create the first part of the full pdu id let prefix = self .get_shortroomid(room_id)? @@ -2124,18 +2158,23 @@ impl Rooms { let user_id = user_id.to_owned(); - Ok(self + let iter = self .pduid_pdu - .iter_from(current, false) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pdu_id, v)| { + .iter_from(current, false); + + Ok(try_stream! { + while let Some((k, v)) = tokio::task::spawn_blocking(|| { iter.next() }).await.unwrap() { + if !k.starts_with(&prefix) { + return; + } let mut pdu = serde_json::from_slice::(&v) .map_err(|_| Error::bad_database("PDU in db is invalid."))?; if pdu.sender != user_id { pdu.remove_transaction_id()?; } - Ok((pdu_id, pdu)) - })) + yield (k, pdu) + } + }) } /// Replace a PDU with the redacted form. diff --git a/src/server_server.rs b/src/server_server.rs index 9dc26170..044a45c8 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -49,6 +49,7 @@ use ruma::{ }, int, receipt::ReceiptType, + room_id, serde::{Base64, JsonObject, Raw}, signatures::{CanonicalJsonObject, CanonicalJsonValue}, state_res::{self, RoomVersion, StateMap}, @@ -681,7 +682,7 @@ pub async fn send_transaction_message_route( .roomid_mutex_federation .write() .unwrap() - .entry(room_id.clone()) + .entry(room_id!("!somewhere:example.org").to_owned()) // only allow one room at a time .or_default(), ); let mutex_lock = mutex.lock().await; @@ -1141,7 +1142,7 @@ fn handle_outlier_pdu<'a>( // Build map of auth events let mut auth_events = HashMap::new(); for id in &incoming_pdu.auth_events { - let auth_event = match db.rooms.get_pdu(id).map_err(|e| e.to_string())? { + let auth_event = match db.rooms.get_pdu_async(id).await.map_err(|e| e.to_string())? { Some(e) => e, None => { warn!("Could not find auth event {}", id); @@ -1182,7 +1183,7 @@ fn handle_outlier_pdu<'a>( && incoming_pdu.prev_events == incoming_pdu.auth_events { db.rooms - .get_pdu(&incoming_pdu.auth_events[0]) + .get_pdu_async(&incoming_pdu.auth_events[0]).await .map_err(|e| e.to_string())? .filter(|maybe_create| **maybe_create == *create_event) } else { @@ -1265,10 +1266,13 @@ async fn upgrade_outlier_to_timeline_pdu( if let Some(Ok(mut state)) = state { warn!("Using cached state"); - let prev_pdu = - db.rooms.get_pdu(prev_event).ok().flatten().ok_or_else(|| { - "Could not find prev event, but we know the state.".to_owned() - })?; + let prev_pdu = db + .rooms + .get_pdu_async(prev_event) + .await + .ok() + .flatten() + .ok_or_else(|| "Could not find prev event, but we know the state.".to_owned())?; if let Some(state_key) = &prev_pdu.state_key { let shortstatekey = db @@ -1288,7 +1292,7 @@ async fn upgrade_outlier_to_timeline_pdu( let mut okay = true; for prev_eventid in &incoming_pdu.prev_events { - let prev_event = if let Ok(Some(pdu)) = db.rooms.get_pdu(prev_eventid) { + let prev_event = if let Ok(Some(pdu)) = db.rooms.get_pdu_async(prev_eventid).await { pdu } else { okay = false; @@ -1337,7 +1341,7 @@ async fn upgrade_outlier_to_timeline_pdu( } auth_chain_sets.push( - get_auth_chain(room_id, starting_events, db) + get_auth_chain(room_id, starting_events, db).await .map_err(|_| "Failed to load auth chain.".to_owned())? .collect(), ); @@ -1350,7 +1354,7 @@ async fn upgrade_outlier_to_timeline_pdu( &fork_states, auth_chain_sets, |id| { - let res = db.rooms.get_pdu(id); + let res = db.rooms.get_pdu_async(id).await; if let Err(e) = &res { error!("LOOK AT ME Failed to fetch event: {}", e); } @@ -1462,28 +1466,33 @@ async fn upgrade_outlier_to_timeline_pdu( && incoming_pdu.prev_events == incoming_pdu.auth_events { db.rooms - .get_pdu(&incoming_pdu.auth_events[0]) + .get_pdu_async(&incoming_pdu.auth_events[0]) + .await .map_err(|e| e.to_string())? .filter(|maybe_create| **maybe_create == *create_event) } else { None }; - let check_result = state_res::event_auth::auth_check( - &room_version, - &incoming_pdu, - previous_create.as_ref(), - None::, // TODO: third party invite - |k, s| { - db.rooms - .get_shortstatekey(k, s) - .ok() - .flatten() - .and_then(|shortstatekey| state_at_incoming_event.get(&shortstatekey)) - .and_then(|event_id| db.rooms.get_pdu(event_id).ok().flatten()) - }, - ) - .map_err(|_e| "Auth check failed.".to_owned())?; + let check_result = tokio::task::spawn_blocking(move || { + state_res::event_auth::auth_check( + &room_version, + &incoming_pdu, + previous_create.as_ref(), + None::, // TODO: third party invite + |k, s| { + db.rooms + .get_shortstatekey(k, s) + .ok() + .flatten() + .and_then(|shortstatekey| state_at_incoming_event.get(&shortstatekey)) + .and_then(|event_id| db.rooms.get_pdu(event_id).ok().flatten()) + }, + ) + .map_err(|_e| "Auth check failed.".to_owned()) + }) + .await + .unwrap()?; if !check_result { return Err("Event has failed auth check with state at the event.".into()); @@ -1591,7 +1600,8 @@ async fn upgrade_outlier_to_timeline_pdu( for id in dbg!(&extremities) { match db .rooms - .get_pdu(id) + .get_pdu_async(id) + .await .map_err(|_| "Failed to ask db for pdu.".to_owned())? { Some(leaf_pdu) => { @@ -1664,7 +1674,7 @@ async fn upgrade_outlier_to_timeline_pdu( room_id, state.iter().map(|(_, id)| id.clone()).collect(), db, - ) + ).await .map_err(|_| "Failed to load auth chain.".to_owned())? .collect(), ); @@ -1685,20 +1695,20 @@ async fn upgrade_outlier_to_timeline_pdu( }) .collect(); - let state = match state_res::resolve( - room_version_id, - &fork_states, - auth_chain_sets, - |id| { + let state = match tokio::task::spawn_blocking(move || { + state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { let res = db.rooms.get_pdu(id); if let Err(e) = &res { error!("LOOK AT ME Failed to fetch event: {}", e); } res.ok().flatten() - }, - ) { - Ok(new_state) => new_state, - Err(_) => { + }).ok() + }) + .await + .unwrap() + { + Some(new_state) => new_state, + None => { return Err("State resolution failed, either an event could not be found or deserialization".into()); } }; @@ -1798,7 +1808,7 @@ pub(crate) fn fetch_and_handle_outliers<'a>( // a. Look in the main timeline (pduid_pdu tree) // b. Look at outlier pdu tree // (get_pdu_json checks both) - if let Ok(Some(local_pdu)) = db.rooms.get_pdu(id) { + if let Ok(Some(local_pdu)) = db.rooms.get_pdu_async(id).await { trace!("Found {} in db", id); pdus.push((local_pdu, None)); continue; @@ -1815,7 +1825,7 @@ pub(crate) fn fetch_and_handle_outliers<'a>( continue; } - if let Ok(Some(_)) = db.rooms.get_pdu(&next_id) { + if let Ok(Some(_)) = db.rooms.get_pdu_async(&next_id).await { trace!("Found {} in db", id); continue; } @@ -2153,7 +2163,7 @@ fn append_incoming_pdu<'a>( } #[tracing::instrument(skip(starting_events, db))] -pub(crate) fn get_auth_chain<'a>( +pub(crate) async fn get_auth_chain<'a>( room_id: &RoomId, starting_events: Vec>, db: &'a Database, @@ -2194,7 +2204,7 @@ pub(crate) fn get_auth_chain<'a>( chunk_cache.extend(cached.iter().copied()); } else { misses2 += 1; - let auth_chain = Arc::new(get_auth_chain_inner(room_id, &event_id, db)?); + let auth_chain = Arc::new(get_auth_chain_inner(room_id, &event_id, db).await?); db.rooms .cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?; println!( @@ -2230,7 +2240,7 @@ pub(crate) fn get_auth_chain<'a>( } #[tracing::instrument(skip(event_id, db))] -fn get_auth_chain_inner( +async fn get_auth_chain_inner( room_id: &RoomId, event_id: &EventId, db: &Database, @@ -2239,7 +2249,7 @@ fn get_auth_chain_inner( let mut found = HashSet::new(); while let Some(event_id) = todo.pop() { - match db.rooms.get_pdu(&event_id) { + match db.rooms.get_pdu_async(&event_id).await { Ok(Some(pdu)) => { if pdu.room_id != room_id { return Err(Error::BadRequest(ErrorKind::Forbidden, "Evil event in db")); @@ -2423,7 +2433,7 @@ pub async fn get_event_authorization_route( let room_id = <&RoomId>::try_from(room_id_str) .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; - let auth_chain_ids = get_auth_chain(room_id, vec![Arc::from(&*body.event_id)], &db)?; + let auth_chain_ids = get_auth_chain(room_id, vec![Arc::from(&*body.event_id)], &db).await?; Ok(get_event_authorization::v1::Response { auth_chain: auth_chain_ids @@ -2477,7 +2487,7 @@ pub async fn get_room_state_route( }) .collect(); - let auth_chain_ids = get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)], &db)?; + let auth_chain_ids = get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)], &db).await?; Ok(get_room_state::v1::Response { auth_chain: auth_chain_ids @@ -2532,7 +2542,7 @@ pub async fn get_room_state_ids_route( .map(|(_, id)| (*id).to_owned()) .collect(); - let auth_chain_ids = get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)], &db)?; + let auth_chain_ids = get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)], &db).await?; Ok(get_room_state_ids::v1::Response { auth_chain_ids: auth_chain_ids.map(|id| (*id).to_owned()).collect(), @@ -2795,7 +2805,7 @@ async fn create_join_event( room_id, state_ids.iter().map(|(_, id)| id.clone()).collect(), db, - )?; + ).await?; let servers = db .rooms