From 8b40e0a85ffec3ad9a712fd5175944158ac46f5d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Timo=20K=C3=B6sters?= <timo@koesters.xyz>
Date: Tue, 13 Apr 2021 21:34:31 +0200
Subject: [PATCH] improvement: fetch signing keys in parallel when joining a
 room

---
 Cargo.lock                      | 48 +++++++++----------
 Cargo.toml                      |  4 +-
 src/client_server/membership.rs | 31 +++++++++++--
 src/server_server.rs            | 82 ++++++++++++++++++++++-----------
 4 files changed, 105 insertions(+), 60 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index d3da6fbf..d153c286 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1625,7 +1625,7 @@ dependencies = [
 [[package]]
 name = "ruma"
 version = "0.0.2"
-source = "git+https://github.com/ruma/ruma?rev=6394609feb4af5c43b840fab85b824b13cebb156#6394609feb4af5c43b840fab85b824b13cebb156"
+source = "git+https://github.com/ruma/ruma?rev=c1693569f15920e408aa6a26b7f3cc7fc6693a63#c1693569f15920e408aa6a26b7f3cc7fc6693a63"
 dependencies = [
  "assign",
  "js_int",
@@ -1645,9 +1645,8 @@ dependencies = [
 [[package]]
 name = "ruma-api"
 version = "0.17.0-alpha.2"
-source = "git+https://github.com/ruma/ruma?rev=6394609feb4af5c43b840fab85b824b13cebb156#6394609feb4af5c43b840fab85b824b13cebb156"
+source = "git+https://github.com/ruma/ruma?rev=c1693569f15920e408aa6a26b7f3cc7fc6693a63#c1693569f15920e408aa6a26b7f3cc7fc6693a63"
 dependencies = [
- "bytes",
  "http",
  "percent-encoding",
  "ruma-api-macros",
@@ -1661,7 +1660,7 @@ dependencies = [
 [[package]]
 name = "ruma-api-macros"
 version = "0.17.0-alpha.2"
-source = "git+https://github.com/ruma/ruma?rev=6394609feb4af5c43b840fab85b824b13cebb156#6394609feb4af5c43b840fab85b824b13cebb156"
+source = "git+https://github.com/ruma/ruma?rev=c1693569f15920e408aa6a26b7f3cc7fc6693a63#c1693569f15920e408aa6a26b7f3cc7fc6693a63"
 dependencies = [
  "proc-macro-crate",
  "proc-macro2",
@@ -1672,7 +1671,7 @@ dependencies = [
 [[package]]
 name = "ruma-appservice-api"
 version = "0.2.0-alpha.2"
-source = "git+https://github.com/ruma/ruma?rev=6394609feb4af5c43b840fab85b824b13cebb156#6394609feb4af5c43b840fab85b824b13cebb156"
+source = "git+https://github.com/ruma/ruma?rev=c1693569f15920e408aa6a26b7f3cc7fc6693a63#c1693569f15920e408aa6a26b7f3cc7fc6693a63"
 dependencies = [
  "ruma-api",
  "ruma-common",
@@ -1686,10 +1685,9 @@ dependencies = [
 [[package]]
 name = "ruma-client-api"
 version = "0.10.0-alpha.2"
-source = "git+https://github.com/ruma/ruma?rev=6394609feb4af5c43b840fab85b824b13cebb156#6394609feb4af5c43b840fab85b824b13cebb156"
+source = "git+https://github.com/ruma/ruma?rev=c1693569f15920e408aa6a26b7f3cc7fc6693a63#c1693569f15920e408aa6a26b7f3cc7fc6693a63"
 dependencies = [
  "assign",
- "bytes",
  "http",
  "js_int",
  "maplit",
@@ -1705,8 +1703,8 @@ dependencies = [
 
 [[package]]
 name = "ruma-common"
-version = "0.3.1"
-source = "git+https://github.com/ruma/ruma?rev=6394609feb4af5c43b840fab85b824b13cebb156#6394609feb4af5c43b840fab85b824b13cebb156"
+version = "0.4.0"
+source = "git+https://github.com/ruma/ruma?rev=c1693569f15920e408aa6a26b7f3cc7fc6693a63#c1693569f15920e408aa6a26b7f3cc7fc6693a63"
 dependencies = [
  "indexmap",
  "js_int",
@@ -1722,7 +1720,7 @@ dependencies = [
 [[package]]
 name = "ruma-events"
 version = "0.22.0-alpha.2"
-source = "git+https://github.com/ruma/ruma?rev=6394609feb4af5c43b840fab85b824b13cebb156#6394609feb4af5c43b840fab85b824b13cebb156"
+source = "git+https://github.com/ruma/ruma?rev=c1693569f15920e408aa6a26b7f3cc7fc6693a63#c1693569f15920e408aa6a26b7f3cc7fc6693a63"
 dependencies = [
  "js_int",
  "ruma-common",
@@ -1736,7 +1734,7 @@ dependencies = [
 [[package]]
 name = "ruma-events-macros"
 version = "0.22.0-alpha.2"
-source = "git+https://github.com/ruma/ruma?rev=6394609feb4af5c43b840fab85b824b13cebb156#6394609feb4af5c43b840fab85b824b13cebb156"
+source = "git+https://github.com/ruma/ruma?rev=c1693569f15920e408aa6a26b7f3cc7fc6693a63#c1693569f15920e408aa6a26b7f3cc7fc6693a63"
 dependencies = [
  "proc-macro-crate",
  "proc-macro2",
@@ -1747,7 +1745,7 @@ dependencies = [
 [[package]]
 name = "ruma-federation-api"
 version = "0.1.0-alpha.1"
-source = "git+https://github.com/ruma/ruma?rev=6394609feb4af5c43b840fab85b824b13cebb156#6394609feb4af5c43b840fab85b824b13cebb156"
+source = "git+https://github.com/ruma/ruma?rev=c1693569f15920e408aa6a26b7f3cc7fc6693a63#c1693569f15920e408aa6a26b7f3cc7fc6693a63"
 dependencies = [
  "js_int",
  "ruma-api",
@@ -1761,8 +1759,8 @@ dependencies = [
 
 [[package]]
 name = "ruma-identifiers"
-version = "0.18.1"
-source = "git+https://github.com/ruma/ruma?rev=6394609feb4af5c43b840fab85b824b13cebb156#6394609feb4af5c43b840fab85b824b13cebb156"
+version = "0.19.0"
+source = "git+https://github.com/ruma/ruma?rev=c1693569f15920e408aa6a26b7f3cc7fc6693a63#c1693569f15920e408aa6a26b7f3cc7fc6693a63"
 dependencies = [
  "paste",
  "rand",
@@ -1775,8 +1773,8 @@ dependencies = [
 
 [[package]]
 name = "ruma-identifiers-macros"
-version = "0.18.1"
-source = "git+https://github.com/ruma/ruma?rev=6394609feb4af5c43b840fab85b824b13cebb156#6394609feb4af5c43b840fab85b824b13cebb156"
+version = "0.19.0"
+source = "git+https://github.com/ruma/ruma?rev=c1693569f15920e408aa6a26b7f3cc7fc6693a63#c1693569f15920e408aa6a26b7f3cc7fc6693a63"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -1786,13 +1784,13 @@ dependencies = [
 
 [[package]]
 name = "ruma-identifiers-validation"
-version = "0.2.2"
-source = "git+https://github.com/ruma/ruma?rev=6394609feb4af5c43b840fab85b824b13cebb156#6394609feb4af5c43b840fab85b824b13cebb156"
+version = "0.2.3"
+source = "git+https://github.com/ruma/ruma?rev=c1693569f15920e408aa6a26b7f3cc7fc6693a63#c1693569f15920e408aa6a26b7f3cc7fc6693a63"
 
 [[package]]
 name = "ruma-identity-service-api"
 version = "0.0.1"
-source = "git+https://github.com/ruma/ruma?rev=6394609feb4af5c43b840fab85b824b13cebb156#6394609feb4af5c43b840fab85b824b13cebb156"
+source = "git+https://github.com/ruma/ruma?rev=c1693569f15920e408aa6a26b7f3cc7fc6693a63#c1693569f15920e408aa6a26b7f3cc7fc6693a63"
 dependencies = [
  "ruma-api",
  "ruma-common",
@@ -1805,7 +1803,7 @@ dependencies = [
 [[package]]
 name = "ruma-push-gateway-api"
 version = "0.0.1"
-source = "git+https://github.com/ruma/ruma?rev=6394609feb4af5c43b840fab85b824b13cebb156#6394609feb4af5c43b840fab85b824b13cebb156"
+source = "git+https://github.com/ruma/ruma?rev=c1693569f15920e408aa6a26b7f3cc7fc6693a63#c1693569f15920e408aa6a26b7f3cc7fc6693a63"
 dependencies = [
  "js_int",
  "ruma-api",
@@ -1820,7 +1818,7 @@ dependencies = [
 [[package]]
 name = "ruma-serde"
 version = "0.3.1"
-source = "git+https://github.com/ruma/ruma?rev=6394609feb4af5c43b840fab85b824b13cebb156#6394609feb4af5c43b840fab85b824b13cebb156"
+source = "git+https://github.com/ruma/ruma?rev=c1693569f15920e408aa6a26b7f3cc7fc6693a63#c1693569f15920e408aa6a26b7f3cc7fc6693a63"
 dependencies = [
  "form_urlencoded",
  "itoa",
@@ -1833,7 +1831,7 @@ dependencies = [
 [[package]]
 name = "ruma-serde-macros"
 version = "0.3.1"
-source = "git+https://github.com/ruma/ruma?rev=6394609feb4af5c43b840fab85b824b13cebb156#6394609feb4af5c43b840fab85b824b13cebb156"
+source = "git+https://github.com/ruma/ruma?rev=c1693569f15920e408aa6a26b7f3cc7fc6693a63#c1693569f15920e408aa6a26b7f3cc7fc6693a63"
 dependencies = [
  "proc-macro-crate",
  "proc-macro2",
@@ -1843,8 +1841,8 @@ dependencies = [
 
 [[package]]
 name = "ruma-signatures"
-version = "0.6.0"
-source = "git+https://github.com/ruma/ruma?rev=6394609feb4af5c43b840fab85b824b13cebb156#6394609feb4af5c43b840fab85b824b13cebb156"
+version = "0.7.0"
+source = "git+https://github.com/ruma/ruma?rev=c1693569f15920e408aa6a26b7f3cc7fc6693a63#c1693569f15920e408aa6a26b7f3cc7fc6693a63"
 dependencies = [
  "base64 0.13.0",
  "ring",
@@ -2122,7 +2120,7 @@ checksum = "3015a7d0a5fd5105c91c3710d42f9ccf0abfb287d62206484dcc67f9569a6483"
 [[package]]
 name = "state-res"
 version = "0.1.0"
-source = "git+https://github.com/timokoesters/state-res?rev=94534b8ff3e71b544ae36206abc182321e9d41f1#94534b8ff3e71b544ae36206abc182321e9d41f1"
+source = "git+https://github.com/timokoesters/state-res?rev=84e70c062708213d01281438598e16f13dffeda4#84e70c062708213d01281438598e16f13dffeda4"
 dependencies = [
  "itertools 0.10.0",
  "log",
diff --git a/Cargo.toml b/Cargo.toml
index 84e40d28..9aa9ceee 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -18,12 +18,12 @@ rocket = { git = "https://github.com/SergioBenitez/Rocket.git", rev = "93e62c86e
 #rocket = { git = "https://github.com/timokoesters/Rocket.git", branch = "empty_parameters", default-features = false, features = ["tls"] }
 
 # Used for matrix spec type definitions and helpers
-ruma = { git = "https://github.com/ruma/ruma", rev = "6394609feb4af5c43b840fab85b824b13cebb156", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "unstable-pre-spec", "unstable-exhaustive-types"] }
+ruma = { git = "https://github.com/ruma/ruma", rev = "c1693569f15920e408aa6a26b7f3cc7fc6693a63", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "unstable-pre-spec", "unstable-exhaustive-types"] }
 #ruma = { git = "https://github.com/timokoesters/ruma", rev = "220d5b4a76b3b781f7f8297fbe6b14473b04214b", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "unstable-pre-spec", "unstable-exhaustive-types"] }
 #ruma = { path = "../ruma/ruma", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "unstable-pre-spec", "unstable-exhaustive-types"] }
 
 # Used when doing state resolution
-state-res = { git = "https://github.com/timokoesters/state-res", rev = "94534b8ff3e71b544ae36206abc182321e9d41f1", features = ["unstable-pre-spec"] }
+state-res = { git = "https://github.com/timokoesters/state-res", rev = "84e70c062708213d01281438598e16f13dffeda4", features = ["unstable-pre-spec"] }
 #state-res = { path = "../state-res", features = ["unstable-pre-spec"] }
 
 # Used for long polling and federation sender, should be the same as rocket::tokio
diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs
index f6489788..c3484096 100644
--- a/src/client_server/membership.rs
+++ b/src/client_server/membership.rs
@@ -5,6 +5,7 @@ use crate::{
     server_server, utils, ConduitResult, Database, Error, Result, Ruma,
 };
 use log::{error, warn};
+use rocket::futures;
 use ruma::{
     api::{
         client::{
@@ -21,6 +22,7 @@ use ruma::{
     serde::{to_canonical_value, CanonicalJsonObject, Raw},
     EventId, RoomId, RoomVersionId, ServerName, UserId,
 };
+use std::sync::RwLock;
 use std::{collections::BTreeMap, convert::TryFrom};
 
 #[cfg(feature = "conduit_bin")]
@@ -525,10 +527,18 @@ async fn join_room_by_id_helper(
             .map_err(|_| Error::BadServerResponse("Invalid PDU in send_join response."))?;
 
         let mut state = BTreeMap::new();
-        let mut pub_key_map = BTreeMap::new();
+        let mut pub_key_map = RwLock::new(BTreeMap::new());
 
-        for pdu in send_join_response.room_state.state.iter() {
-            let (event_id, value) = validate_and_add_event_id(pdu, &room_version, &mut pub_key_map, &db).await?;
+        for result in futures::future::join_all(
+            send_join_response
+                .room_state
+                .state
+                .iter()
+                .map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map, &db)),
+        )
+        .await
+        {
+            let (event_id, value) = result?;
             let pdu = PduEvent::from_id_val(&event_id, value.clone()).map_err(|e| {
                 warn!("{:?}: {}", value, e);
                 Error::BadServerResponse("Invalid PDU in send_join response.")
@@ -584,7 +594,8 @@ async fn join_room_by_id_helper(
         db.rooms.force_state(room_id, state, &db.globals)?;
 
         for pdu in send_join_response.room_state.auth_chain.iter() {
-            let (event_id, value) = validate_and_add_event_id(pdu, &room_version, &mut pub_key_map, &db).await?;
+            let (event_id, value) =
+                validate_and_add_event_id(pdu, &room_version, &mut pub_key_map, &db).await?;
             let pdu = PduEvent::from_id_val(&event_id, value.clone()).map_err(|e| {
                 warn!("{:?}: {}", value, e);
                 Error::BadServerResponse("Invalid PDU in send_join response.")
@@ -639,7 +650,7 @@ async fn join_room_by_id_helper(
 async fn validate_and_add_event_id(
     pdu: &Raw<Pdu>,
     room_version: &RoomVersionId,
-    pub_key_map: &mut BTreeMap<String, BTreeMap<String, String>>,
+    pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, String>>>,
     db: &Database,
 ) -> Result<(EventId, CanonicalJsonObject)> {
     let mut value = serde_json::from_str::<CanonicalJsonObject>(pdu.json().get()).map_err(|e| {
@@ -648,6 +659,16 @@ async fn validate_and_add_event_id(
     })?;
 
     server_server::fetch_required_signing_keys(&value, pub_key_map, db).await?;
+    if let Err(e) = ruma::signatures::verify_event(
+        &*pub_key_map
+            .read()
+            .map_err(|_| Error::bad_database("RwLock is poisoned."))?,
+        &value,
+        room_version,
+    ) {
+        warn!("Event failed verification: {}", e);
+        return Err(Error::BadServerResponse("Event failed verification."));
+    }
 
     let event_id = EventId::try_from(&*format!(
         "${}",
diff --git a/src/server_server.rs b/src/server_server.rs
index 39b626f9..791ec1ca 100644
--- a/src/server_server.rs
+++ b/src/server_server.rs
@@ -38,7 +38,7 @@ use std::{
     net::{IpAddr, SocketAddr},
     pin::Pin,
     result::Result as StdResult,
-    sync::Arc,
+    sync::{Arc, RwLock},
     time::{Duration, SystemTime},
 };
 
@@ -543,7 +543,7 @@ pub async fn send_transaction_message_route<'a>(
 
     let mut resolved_map = BTreeMap::new();
 
-    let mut pub_key_map = BTreeMap::new();
+    let pub_key_map = RwLock::new(BTreeMap::new());
 
     // This is all the auth_events that have been recursively fetched so they don't have to be
     // deserialized over and over again.
@@ -569,7 +569,7 @@ pub async fn send_transaction_message_route<'a>(
             value,
             true,
             &db,
-            &mut pub_key_map,
+            &pub_key_map,
             &mut auth_cache,
         )
         .await
@@ -622,7 +622,7 @@ fn handle_incoming_pdu<'a>(
     value: BTreeMap<String, CanonicalJsonValue>,
     is_timeline_event: bool,
     db: &'a Database,
-    pub_key_map: &'a mut BTreeMap<String, BTreeMap<String, String>>,
+    pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, String>>>,
     auth_cache: &'a mut EventMap<Arc<PduEvent>>,
 ) -> AsyncRecursiveResult<'a, Arc<PduEvent>> {
     Box::pin(async move {
@@ -658,7 +658,9 @@ fn handle_incoming_pdu<'a>(
 
         // We go through all the signatures we see on the value and fetch the corresponding signing
         // keys
-        fetch_required_signing_keys(&value, pub_key_map, db).await.map_err(|e| e.to_string())?;
+        fetch_required_signing_keys(&value, &pub_key_map, db)
+            .await
+            .map_err(|e| e.to_string())?;
 
         // 2. Check signatures, otherwise drop
         // 3. check content hash, redact if doesn't match
@@ -676,7 +678,11 @@ fn handle_incoming_pdu<'a>(
 
         let room_version = create_event_content.room_version;
 
-        let mut val = match ruma::signatures::verify_event(&pub_key_map, &value, &room_version) {
+        let mut val = match ruma::signatures::verify_event(
+            &*pub_key_map.read().map_err(|_| "RwLock is poisoned.")?,
+            &value,
+            &room_version,
+        ) {
             Err(e) => {
                 // Drop
                 error!("{:?}: {}", value, e);
@@ -1106,7 +1112,7 @@ pub(crate) async fn fetch_and_handle_events(
     db: &Database,
     origin: &ServerName,
     events: &[EventId],
-    pub_key_map: &mut BTreeMap<String, BTreeMap<String, String>>,
+    pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, String>>>,
     auth_cache: &mut EventMap<Arc<PduEvent>>,
 ) -> Result<Vec<Arc<PduEvent>>> {
     let mut pdus = vec![];
@@ -1256,6 +1262,7 @@ pub(crate) async fn fetch_signing_keys(
         }
     }
 
+    warn!("Failed to find public key for server: {}", origin);
     Err(Error::BadServerResponse(
         "Failed to find public key for server",
     ))
@@ -1486,7 +1493,7 @@ pub fn get_room_state_ids_route<'a>(
     put("/_matrix/federation/v2/invite/<_>/<_>", data = "<body>")
 )]
 #[tracing::instrument(skip(db, body))]
-pub fn create_invite_route<'a>(
+pub async fn create_invite_route<'a>(
     db: State<'a, Database>,
     body: Ruma<create_invite::v2::Request>,
 ) -> ConduitResult<create_invite::v2::Response> {
@@ -1510,6 +1517,20 @@ pub fn create_invite_route<'a>(
     )
     .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Failed to sign event."))?;
 
+    // Generate event id
+    let event_id = EventId::try_from(&*format!(
+        "${}",
+        ruma::signatures::reference_hash(&signed_event, &body.room_version)
+            .expect("ruma can calculate reference hashes")
+    ))
+    .expect("ruma's reference hashes are valid event ids");
+
+    // Add event_id back
+    signed_event.insert(
+        "event_id".to_owned(),
+        to_canonical_value(&event_id).expect("EventId is a valid CanonicalJsonValue"),
+    );
+
     let sender = serde_json::from_value(
         serde_json::to_value(
             signed_event
@@ -1543,24 +1564,26 @@ pub fn create_invite_route<'a>(
     .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid invite event bytes."))?;
 
     event.insert("event_id".to_owned(), "$dummy".into());
-    invite_state.push(
-        serde_json::from_value::<PduEvent>(event.into())
-            .map_err(|e| {
-                warn!("Invalid invite event: {}", e);
-                Error::BadRequest(ErrorKind::InvalidParam, "Invalid invite event.")
-            })?
-            .to_stripped_state_event(),
-    );
 
-    db.rooms.update_membership(
-        &body.room_id,
-        &invited_user,
-        MembershipState::Invite,
-        &sender,
-        Some(invite_state),
-        &db.account_data,
-        &db.globals,
-    )?;
+    let pdu = serde_json::from_value::<PduEvent>(event.into()).map_err(|e| {
+        warn!("Invalid invite event: {}", e);
+        Error::BadRequest(ErrorKind::InvalidParam, "Invalid invite event.")
+    })?;
+
+    invite_state.push(pdu.to_stripped_state_event());
+
+    // If the room already exists, the remote server will notify us about the join via /send
+    if !db.rooms.exists(&pdu.room_id)? {
+        db.rooms.update_membership(
+            &body.room_id,
+            &invited_user,
+            MembershipState::Invite,
+            &sender,
+            Some(invite_state),
+            &db.account_data,
+            &db.globals,
+        )?;
+    }
 
     Ok(create_invite::v2::Response {
         event: PduEvent::convert_to_outgoing_federation_event(signed_event),
@@ -1604,7 +1627,7 @@ pub fn get_profile_information_route<'a>(
 
 pub async fn fetch_required_signing_keys(
     event: &BTreeMap<String, CanonicalJsonValue>,
-    pub_key_map: &mut BTreeMap<String, BTreeMap<String, String>>,
+    pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, String>>>,
     db: &Database,
 ) -> Result<()> {
     // We go through all the signatures we see on the value and fetch the corresponding signing
@@ -1642,14 +1665,17 @@ pub async fn fetch_required_signing_keys(
         .await
         {
             Ok(keys) => keys,
-            Err(_) => {
+            Err(e) => {
                 return Err(Error::BadServerResponse(
                     "Signature verification failed: Could not fetch signing key.",
                 ));
             }
         };
 
-        pub_key_map.insert(signature_server.clone(), keys);
+        pub_key_map
+            .write()
+            .map_err(|_| Error::bad_database("RwLock is poisoned."))?
+            .insert(signature_server.clone(), keys);
     }
 
     Ok(())