From a9ba067e7758207d1411b24c537cf755608632e8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Timo=20K=C3=B6sters?= <timo@koesters.xyz>
Date: Sun, 16 Jul 2023 16:50:03 +0200
Subject: [PATCH] fix: e2ee over federation

---
 src/api/client_server/directory.rs |   1 -
 src/api/client_server/keys.rs      |  34 ++++++++--
 src/api/server_server.rs           |  10 +--
 src/database/key_value/users.rs    | 100 +++++++++++++++++------------
 src/service/pusher/mod.rs          |   5 +-
 src/service/users/data.rs          |  16 +++++
 src/service/users/mod.rs           |  34 ++++++++--
 7 files changed, 137 insertions(+), 63 deletions(-)

diff --git a/src/api/client_server/directory.rs b/src/api/client_server/directory.rs
index df1ac40c..a812dbcc 100644
--- a/src/api/client_server/directory.rs
+++ b/src/api/client_server/directory.rs
@@ -20,7 +20,6 @@ use ruma::{
             guest_access::{GuestAccess, RoomGuestAccessEventContent},
             history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent},
             join_rules::{JoinRule, RoomJoinRulesEventContent},
-            name::RoomNameEventContent,
             topic::RoomTopicEventContent,
         },
         StateEventType,
diff --git a/src/api/client_server/keys.rs b/src/api/client_server/keys.rs
index 21f71b6d..3e032211 100644
--- a/src/api/client_server/keys.rs
+++ b/src/api/client_server/keys.rs
@@ -311,15 +311,17 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
             }
         }
 
-        if let Some(master_key) = services()
-            .users
-            .get_master_key(user_id, &allowed_signatures)?
+        if let Some(master_key) =
+            services()
+                .users
+                .get_master_key(sender_user, user_id, &allowed_signatures)?
         {
             master_keys.insert(user_id.to_owned(), master_key);
         }
-        if let Some(self_signing_key) = services()
-            .users
-            .get_self_signing_key(user_id, &allowed_signatures)?
+        if let Some(self_signing_key) =
+            services()
+                .users
+                .get_self_signing_key(sender_user, user_id, &allowed_signatures)?
         {
             self_signing_keys.insert(user_id.to_owned(), self_signing_key);
         }
@@ -357,7 +359,25 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
     while let Some((server, response)) = futures.next().await {
         match response {
             Ok(response) => {
-                master_keys.extend(response.master_keys);
+                for (user, masterkey) in response.master_keys {
+                    let (master_key_id, mut master_key) =
+                        services().users.parse_master_key(&user, &masterkey)?;
+
+                    if let Some(our_master_key) = services().users.get_key(
+                        &master_key_id,
+                        sender_user,
+                        &user,
+                        &allowed_signatures,
+                    )? {
+                        let (_, our_master_key) =
+                            services().users.parse_master_key(&user, &our_master_key)?;
+                        master_key.signatures.extend(our_master_key.signatures);
+                    }
+                    let json = serde_json::to_value(master_key).expect("to_value always works");
+                    let raw = serde_json::from_value(json).expect("Raw::from_value always works");
+                    master_keys.insert(user, raw);
+                }
+
                 self_signing_keys.extend(response.self_signing_keys);
                 device_keys.extend(response.device_keys);
             }
diff --git a/src/api/server_server.rs b/src/api/server_server.rs
index 0177f2ab..2179b16a 100644
--- a/src/api/server_server.rs
+++ b/src/api/server_server.rs
@@ -1806,12 +1806,14 @@ pub async fn get_devices_route(
                 })
             })
             .collect(),
-        master_key: services()
-            .users
-            .get_master_key(&body.user_id, &|u| u.server_name() == sender_servername)?,
+        master_key: services().users.get_master_key(None, &body.user_id, &|u| {
+            u.server_name() == sender_servername
+        })?,
         self_signing_key: services()
             .users
-            .get_self_signing_key(&body.user_id, &|u| u.server_name() == sender_servername)?,
+            .get_self_signing_key(None, &body.user_id, &|u| {
+                u.server_name() == sender_servername
+            })?,
     })
 }
 
diff --git a/src/database/key_value/users.rs b/src/database/key_value/users.rs
index 359a0724..0301cdaa 100644
--- a/src/database/key_value/users.rs
+++ b/src/database/key_value/users.rs
@@ -451,31 +451,10 @@ impl service::users::Data for KeyValueDatabase {
         user_signing_key: &Option<Raw<CrossSigningKey>>,
     ) -> Result<()> {
         // TODO: Check signatures
-
         let mut prefix = user_id.as_bytes().to_vec();
         prefix.push(0xff);
 
-        // Master key
-        let mut master_key_ids = master_key
-            .deserialize()
-            .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?
-            .keys
-            .into_values();
-
-        let master_key_id = master_key_ids.next().ok_or(Error::BadRequest(
-            ErrorKind::InvalidParam,
-            "Master key contained no key.",
-        ))?;
-
-        if master_key_ids.next().is_some() {
-            return Err(Error::BadRequest(
-                ErrorKind::InvalidParam,
-                "Master key contained more than one key.",
-            ));
-        }
-
-        let mut master_key_key = prefix.clone();
-        master_key_key.extend_from_slice(master_key_id.as_bytes());
+        let (master_key_key, _) = self.parse_master_key(user_id, master_key)?;
 
         self.keyid_key
             .insert(&master_key_key, master_key.json().get().as_bytes())?;
@@ -690,45 +669,80 @@ impl service::users::Data for KeyValueDatabase {
         })
     }
 
+    fn parse_master_key(
+        &self,
+        user_id: &UserId,
+        master_key: &Raw<CrossSigningKey>,
+    ) -> Result<(Vec<u8>, CrossSigningKey)> {
+        let mut prefix = user_id.as_bytes().to_vec();
+        prefix.push(0xff);
+
+        let master_key = master_key
+            .deserialize()
+            .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?;
+        let mut master_key_ids = master_key.keys.values();
+        let master_key_id = master_key_ids.next().ok_or(Error::BadRequest(
+            ErrorKind::InvalidParam,
+            "Master key contained no key.",
+        ))?;
+        if master_key_ids.next().is_some() {
+            return Err(Error::BadRequest(
+                ErrorKind::InvalidParam,
+                "Master key contained more than one key.",
+            ));
+        }
+        let mut master_key_key = prefix.clone();
+        master_key_key.extend_from_slice(master_key_id.as_bytes());
+        Ok((master_key_key, master_key))
+    }
+
+    fn get_key(
+        &self,
+        key: &[u8],
+        sender_user: Option<&UserId>,
+        user_id: &UserId,
+        allowed_signatures: &dyn Fn(&UserId) -> bool,
+    ) -> Result<Option<Raw<CrossSigningKey>>> {
+        self.keyid_key.get(key)?.map_or(Ok(None), |bytes| {
+            let mut cross_signing_key = serde_json::from_slice::<serde_json::Value>(&bytes)
+                .map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?;
+            clean_signatures(
+                &mut cross_signing_key,
+                sender_user,
+                user_id,
+                allowed_signatures,
+            )?;
+
+            Ok(Some(Raw::from_json(
+                serde_json::value::to_raw_value(&cross_signing_key)
+                    .expect("Value to RawValue serialization"),
+            )))
+        })
+    }
+
     fn get_master_key(
         &self,
+        sender_user: Option<&UserId>,
         user_id: &UserId,
         allowed_signatures: &dyn Fn(&UserId) -> bool,
     ) -> Result<Option<Raw<CrossSigningKey>>> {
         self.userid_masterkeyid
             .get(user_id.as_bytes())?
             .map_or(Ok(None), |key| {
-                self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| {
-                    let mut cross_signing_key = serde_json::from_slice::<serde_json::Value>(&bytes)
-                        .map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?;
-                    clean_signatures(&mut cross_signing_key, user_id, allowed_signatures)?;
-
-                    Ok(Some(Raw::from_json(
-                        serde_json::value::to_raw_value(&cross_signing_key)
-                            .expect("Value to RawValue serialization"),
-                    )))
-                })
+                self.get_key(&key, sender_user, user_id, allowed_signatures)
             })
     }
 
     fn get_self_signing_key(
         &self,
+        sender_user: Option<&UserId>,
         user_id: &UserId,
         allowed_signatures: &dyn Fn(&UserId) -> bool,
     ) -> Result<Option<Raw<CrossSigningKey>>> {
         self.userid_selfsigningkeyid
             .get(user_id.as_bytes())?
             .map_or(Ok(None), |key| {
-                self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| {
-                    let mut cross_signing_key = serde_json::from_slice::<serde_json::Value>(&bytes)
-                        .map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?;
-                    clean_signatures(&mut cross_signing_key, user_id, allowed_signatures)?;
-
-                    Ok(Some(Raw::from_json(
-                        serde_json::value::to_raw_value(&cross_signing_key)
-                            .expect("Value to RawValue serialization"),
-                    )))
-                })
+                self.get_key(&key, sender_user, user_id, allowed_signatures)
             })
     }
 
@@ -929,6 +943,8 @@ impl service::users::Data for KeyValueDatabase {
     }
 }
 
+impl KeyValueDatabase {}
+
 /// Will only return with Some(username) if the password was not empty and the
 /// username could be successfully parsed.
 /// If utils::string_from_bytes(...) returns an error that username will be skipped
diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs
index 5e4281d2..315c5ef0 100644
--- a/src/service/pusher/mod.rs
+++ b/src/service/pusher/mod.rs
@@ -13,10 +13,7 @@ use ruma::{
         },
         IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken,
     },
-    events::{
-        room::{name::RoomNameEventContent, power_levels::RoomPowerLevelsEventContent},
-        StateEventType, TimelineEventType,
-    },
+    events::{room::power_levels::RoomPowerLevelsEventContent, StateEventType, TimelineEventType},
     push::{Action, PushConditionRoomCtx, PushFormat, Ruleset, Tweak},
     serde::Raw,
     uint, RoomId, UInt, UserId,
diff --git a/src/service/users/data.rs b/src/service/users/data.rs
index 85532109..d01e0702 100644
--- a/src/service/users/data.rs
+++ b/src/service/users/data.rs
@@ -136,14 +136,30 @@ pub trait Data: Send + Sync {
         device_id: &DeviceId,
     ) -> Result<Option<Raw<DeviceKeys>>>;
 
+    fn parse_master_key(
+        &self,
+        user_id: &UserId,
+        master_key: &Raw<CrossSigningKey>,
+    ) -> Result<(Vec<u8>, CrossSigningKey)>;
+
+    fn get_key(
+        &self,
+        key: &[u8],
+        sender_user: Option<&UserId>,
+        user_id: &UserId,
+        allowed_signatures: &dyn Fn(&UserId) -> bool,
+    ) -> Result<Option<Raw<CrossSigningKey>>>;
+
     fn get_master_key(
         &self,
+        sender_user: Option<&UserId>,
         user_id: &UserId,
         allowed_signatures: &dyn Fn(&UserId) -> bool,
     ) -> Result<Option<Raw<CrossSigningKey>>>;
 
     fn get_self_signing_key(
         &self,
+        sender_user: Option<&UserId>,
         user_id: &UserId,
         allowed_signatures: &dyn Fn(&UserId) -> bool,
     ) -> Result<Option<Raw<CrossSigningKey>>>;
diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs
index 6be5c895..2311c308 100644
--- a/src/service/users/mod.rs
+++ b/src/service/users/mod.rs
@@ -226,20 +226,43 @@ impl Service {
         self.db.get_device_keys(user_id, device_id)
     }
 
-    pub fn get_master_key(
+    pub fn parse_master_key(
         &self,
         user_id: &UserId,
+        master_key: &Raw<CrossSigningKey>,
+    ) -> Result<(Vec<u8>, CrossSigningKey)> {
+        self.db.parse_master_key(user_id, master_key)
+    }
+
+    pub fn get_key(
+        &self,
+        key: &[u8],
+        sender_user: Option<&UserId>,
+        user_id: &UserId,
         allowed_signatures: &dyn Fn(&UserId) -> bool,
     ) -> Result<Option<Raw<CrossSigningKey>>> {
-        self.db.get_master_key(user_id, allowed_signatures)
+        self.db
+            .get_key(key, sender_user, user_id, allowed_signatures)
+    }
+
+    pub fn get_master_key(
+        &self,
+        sender_user: Option<&UserId>,
+        user_id: &UserId,
+        allowed_signatures: &dyn Fn(&UserId) -> bool,
+    ) -> Result<Option<Raw<CrossSigningKey>>> {
+        self.db
+            .get_master_key(sender_user, user_id, allowed_signatures)
     }
 
     pub fn get_self_signing_key(
         &self,
+        sender_user: Option<&UserId>,
         user_id: &UserId,
         allowed_signatures: &dyn Fn(&UserId) -> bool,
     ) -> Result<Option<Raw<CrossSigningKey>>> {
-        self.db.get_self_signing_key(user_id, allowed_signatures)
+        self.db
+            .get_self_signing_key(sender_user, user_id, allowed_signatures)
     }
 
     pub fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>> {
@@ -342,6 +365,7 @@ impl Service {
 /// Ensure that a user only sees signatures from themselves and the target user
 pub fn clean_signatures<F: Fn(&UserId) -> bool>(
     cross_signing_key: &mut serde_json::Value,
+    sender_user: Option<&UserId>,
     user_id: &UserId,
     allowed_signatures: F,
 ) -> Result<(), Error> {
@@ -355,9 +379,9 @@ pub fn clean_signatures<F: Fn(&UserId) -> bool>(
         for (user, signature) in
             mem::replace(signatures, serde_json::Map::with_capacity(new_capacity))
         {
-            let id = <&UserId>::try_from(user.as_str())
+            let sid = <&UserId>::try_from(user.as_str())
                 .map_err(|_| Error::bad_database("Invalid user ID in database."))?;
-            if id == user_id || allowed_signatures(id) {
+            if sender_user == Some(user_id) || sid == user_id || allowed_signatures(sid) {
                 signatures.insert(user, signature);
             }
         }