From 5049d0e01b173b7fe1f6cdd6b22fa22a4e223d29 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Timo=20K=C3=B6sters?= <timo@koesters.xyz>
Date: Tue, 13 Apr 2021 18:17:51 +0200
Subject: [PATCH] improvement: check signatures on join

---
 src/client_server/membership.rs | 100 +++++++++++++--------------
 src/server_server.rs            | 115 ++++++++++++++------------------
 2 files changed, 94 insertions(+), 121 deletions(-)

diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs
index d491ca0d..f6489788 100644
--- a/src/client_server/membership.rs
+++ b/src/client_server/membership.rs
@@ -2,7 +2,7 @@ use super::State;
 use crate::{
     client_server,
     pdu::{PduBuilder, PduEvent},
-    utils, ConduitResult, Database, Error, Result, Ruma,
+    server_server, utils, ConduitResult, Database, Error, Result, Ruma,
 };
 use log::{error, warn};
 use ruma::{
@@ -21,7 +21,7 @@ use ruma::{
     serde::{to_canonical_value, CanonicalJsonObject, Raw},
     EventId, RoomId, RoomVersionId, ServerName, UserId,
 };
-use std::{collections::BTreeMap, convert::TryFrom, sync::Arc};
+use std::{collections::BTreeMap, convert::TryFrom};
 
 #[cfg(feature = "conduit_bin")]
 use rocket::{get, post};
@@ -515,27 +515,6 @@ async fn join_room_by_id_helper(
             )
             .await?;
 
-        let add_event_id = |pdu: &Raw<Pdu>| -> Result<(EventId, CanonicalJsonObject)> {
-            let mut value = serde_json::from_str(pdu.json().get()).map_err(|e| {
-                error!("{:?}: {:?}", pdu, e);
-                Error::BadServerResponse("Invalid PDU in server response")
-            })?;
-            let event_id = EventId::try_from(&*format!(
-                "${}",
-                ruma::signatures::reference_hash(&value, &room_version)
-                    .expect("ruma can calculate reference hashes")
-            ))
-            .expect("ruma's reference hashes are valid event ids");
-
-            value.insert(
-                "event_id".to_owned(),
-                to_canonical_value(&event_id)
-                    .expect("a valid EventId can be converted to CanonicalJsonValue"),
-            );
-
-            Ok((event_id, value))
-        };
-
         let count = db.globals.next_count()?;
 
         let mut pdu_id = room_id.as_bytes().to_vec();
@@ -546,23 +525,15 @@ async fn join_room_by_id_helper(
             .map_err(|_| Error::BadServerResponse("Invalid PDU in send_join response."))?;
 
         let mut state = BTreeMap::new();
+        let mut pub_key_map = BTreeMap::new();
+
+        for pdu in send_join_response.room_state.state.iter() {
+            let (event_id, value) = validate_and_add_event_id(pdu, &room_version, &mut pub_key_map, &db).await?;
+            let pdu = PduEvent::from_id_val(&event_id, value.clone()).map_err(|e| {
+                warn!("{:?}: {}", value, e);
+                Error::BadServerResponse("Invalid PDU in send_join response.")
+            })?;
 
-        for pdu in send_join_response
-            .room_state
-            .state
-            .iter()
-            .map(add_event_id)
-            .map(|r| {
-                let (event_id, value) = r?;
-                PduEvent::from_id_val(&event_id, value.clone())
-                    .map(|ev| (event_id, Arc::new(ev)))
-                    .map_err(|e| {
-                        warn!("{:?}: {}", value, e);
-                        Error::BadServerResponse("Invalid PDU in send_join response.")
-                    })
-            })
-        {
-            let (_id, pdu) = pdu?;
             db.rooms.add_pdu_outlier(&pdu)?;
             if let Some(state_key) = &pdu.state_key {
                 if pdu.kind == EventType::RoomMember {
@@ -612,22 +583,12 @@ async fn join_room_by_id_helper(
 
         db.rooms.force_state(room_id, state, &db.globals)?;
 
-        for pdu in send_join_response
-            .room_state
-            .auth_chain
-            .iter()
-            .map(add_event_id)
-            .map(|r| {
-                let (event_id, value) = r?;
-                PduEvent::from_id_val(&event_id, value.clone())
-                    .map(|ev| (event_id, Arc::new(ev)))
-                    .map_err(|e| {
-                        warn!("{:?}: {}", value, e);
-                        Error::BadServerResponse("Invalid PDU in send_join response.")
-                    })
-            })
-        {
-            let (_id, pdu) = pdu?;
+        for pdu in send_join_response.room_state.auth_chain.iter() {
+            let (event_id, value) = validate_and_add_event_id(pdu, &room_version, &mut pub_key_map, &db).await?;
+            let pdu = PduEvent::from_id_val(&event_id, value.clone()).map_err(|e| {
+                warn!("{:?}: {}", value, e);
+                Error::BadServerResponse("Invalid PDU in send_join response.")
+            })?;
             db.rooms.add_pdu_outlier(&pdu)?;
         }
 
@@ -674,3 +635,32 @@ async fn join_room_by_id_helper(
 
     Ok(join_room_by_id::Response::new(room_id.clone()).into())
 }
+
+async fn validate_and_add_event_id(
+    pdu: &Raw<Pdu>,
+    room_version: &RoomVersionId,
+    pub_key_map: &mut BTreeMap<String, BTreeMap<String, String>>,
+    db: &Database,
+) -> Result<(EventId, CanonicalJsonObject)> {
+    let mut value = serde_json::from_str::<CanonicalJsonObject>(pdu.json().get()).map_err(|e| {
+        error!("{:?}: {:?}", pdu, e);
+        Error::BadServerResponse("Invalid PDU in server response")
+    })?;
+
+    server_server::fetch_required_signing_keys(&value, pub_key_map, db).await?;
+
+    let event_id = EventId::try_from(&*format!(
+        "${}",
+        ruma::signatures::reference_hash(&value, &room_version)
+            .expect("ruma can calculate reference hashes")
+    ))
+    .expect("ruma's reference hashes are valid event ids");
+
+    value.insert(
+        "event_id".to_owned(),
+        to_canonical_value(&event_id)
+            .expect("a valid EventId can be converted to CanonicalJsonValue"),
+    );
+
+    Ok((event_id, value))
+}
diff --git a/src/server_server.rs b/src/server_server.rs
index 304bc198..39b626f9 100644
--- a/src/server_server.rs
+++ b/src/server_server.rs
@@ -658,44 +658,7 @@ fn handle_incoming_pdu<'a>(
 
         // We go through all the signatures we see on the value and fetch the corresponding signing
         // keys
-        for (signature_server, signature) in match value
-            .get("signatures")
-            .ok_or_else(|| "No signatures in server response pdu.".to_string())?
-        {
-            CanonicalJsonValue::Object(map) => map,
-            _ => return Err("Invalid signatures object in server response pdu.".to_string()),
-        } {
-            let signature_object = match signature {
-                CanonicalJsonValue::Object(map) => map,
-                _ => {
-                    return Err(
-                        "Invalid signatures content object in server response pdu.".to_string()
-                    )
-                }
-            };
-
-            let signature_ids = signature_object.keys().collect::<Vec<_>>();
-
-            debug!("Fetching signing keys for {}", signature_server);
-            let keys = match fetch_signing_keys(
-                &db,
-                &Box::<ServerName>::try_from(&**signature_server).map_err(|_| {
-                    "Invalid servername in signatures of server response pdu.".to_string()
-                })?,
-                signature_ids,
-            )
-            .await
-            {
-                Ok(keys) => keys,
-                Err(_) => {
-                    return Err(
-                        "Signature verification failed: Could not fetch signing key.".to_string(),
-                    );
-                }
-            };
-
-            pub_key_map.insert(signature_server.clone(), keys);
-        }
+        fetch_required_signing_keys(&value, pub_key_map, db).await.map_err(|e| e.to_string())?;
 
         // 2. Check signatures, otherwise drop
         // 3. check content hash, redact if doesn't match
@@ -1639,38 +1602,58 @@ pub fn get_profile_information_route<'a>(
     .into())
 }
 
-/*
-#[cfg_attr(
-    feature = "conduit_bin",
-    get("/_matrix/federation/v2/invite/<_>/<_>", data = "<body>")
-)]
-pub fn get_user_devices_route<'a>(
-    db: State<'a, Database>,
-    body: Ruma<membership::v1::Request<'_>>,
-) -> ConduitResult<get_profile_information::v1::Response> {
-    if !db.globals.allow_federation() {
-        return Err(Error::bad_config("Federation is disabled."));
-    }
-
-    let mut displayname = None;
-    let mut avatar_url = None;
-
-    match body.field {
-        Some(ProfileField::DisplayName) => displayname = db.users.displayname(&body.user_id)?,
-        Some(ProfileField::AvatarUrl) => avatar_url = db.users.avatar_url(&body.user_id)?,
-        None => {
-            displayname = db.users.displayname(&body.user_id)?;
-            avatar_url = db.users.avatar_url(&body.user_id)?;
+pub async fn fetch_required_signing_keys(
+    event: &BTreeMap<String, CanonicalJsonValue>,
+    pub_key_map: &mut BTreeMap<String, BTreeMap<String, String>>,
+    db: &Database,
+) -> Result<()> {
+    // We go through all the signatures we see on the value and fetch the corresponding signing
+    // keys
+    for (signature_server, signature) in match event
+        .get("signatures")
+        .ok_or_else(|| Error::BadServerResponse("No signatures in server response pdu."))?
+    {
+        CanonicalJsonValue::Object(map) => map,
+        _ => {
+            return Err(Error::BadServerResponse(
+                "Invalid signatures object in server response pdu.",
+            ))
         }
+    } {
+        let signature_object = match signature {
+            CanonicalJsonValue::Object(map) => map,
+            _ => {
+                return Err(Error::BadServerResponse(
+                    "Invalid signatures content object in server response pdu.",
+                ))
+            }
+        };
+
+        let signature_ids = signature_object.keys().collect::<Vec<_>>();
+
+        debug!("Fetching signing keys for {}", signature_server);
+        let keys = match fetch_signing_keys(
+            db,
+            &Box::<ServerName>::try_from(&**signature_server).map_err(|_| {
+                Error::BadServerResponse("Invalid servername in signatures of server response pdu.")
+            })?,
+            signature_ids,
+        )
+        .await
+        {
+            Ok(keys) => keys,
+            Err(_) => {
+                return Err(Error::BadServerResponse(
+                    "Signature verification failed: Could not fetch signing key.",
+                ));
+            }
+        };
+
+        pub_key_map.insert(signature_server.clone(), keys);
     }
 
-    Ok(get_profile_information::v1::Response {
-        displayname,
-        avatar_url,
-    }
-    .into())
+    Ok(())
 }
-*/
 
 #[cfg(test)]
 mod tests {