From 069338776919c88f6f90fbd0a6be3fa6b26c2de2 Mon Sep 17 00:00:00 2001
From: timokoesters <timo@koesters.xyz>
Date: Wed, 29 Jul 2020 17:03:04 +0200
Subject: [PATCH] improvement: more efficient /sync and only send device
 updates when sharing a room

---
 src/client_server.rs  | 80 +++++++++++++++++++++++++++----------------
 src/database/rooms.rs | 37 ++++++--------------
 src/database/users.rs | 57 +++++++++++++++++++++++-------
 3 files changed, 105 insertions(+), 69 deletions(-)

diff --git a/src/client_server.rs b/src/client_server.rs
index de76eef7..cd61746b 100644
--- a/src/client_server.rs
+++ b/src/client_server.rs
@@ -1,5 +1,5 @@
 use std::{
-    collections::{hash_map, BTreeMap, HashMap},
+    collections::{hash_map, BTreeMap, HashMap, HashSet},
     convert::{TryFrom, TryInto},
     time::{Duration, SystemTime},
 };
@@ -898,7 +898,7 @@ pub fn upload_keys_route(
         // This check is needed to assure that signatures are kept
         if db.users.get_device_keys(sender_id, device_id)?.is_none() {
             db.users
-                .add_device_keys(sender_id, device_id, device_keys, &db.globals)?;
+                .add_device_keys(sender_id, device_id, device_keys, &db.rooms, &db.globals)?;
         }
     }
 
@@ -2518,20 +2518,41 @@ pub async fn sync_events_route(
         .unwrap_or(0);
 
     let mut presence_updates = HashMap::new();
+    let mut device_list_updates = HashSet::new();
 
     for room_id in db.rooms.rooms_joined(&sender_id) {
         let room_id = room_id?;
 
-        let mut pdus = db
+        let mut non_timeline_pdus = db
             .rooms
             .pdus_since(&sender_id, &room_id, since)?
-            .filter_map(|r| r.ok()) // Filter out buggy events
+            .filter_map(|r| r.ok()); // Filter out buggy events
+
+        // Take the last 10 events for the timeline
+        let timeline_pdus = non_timeline_pdus
+            .by_ref()
+            .rev()
+            .take(10)
+            .collect::<Vec<_>>()
+            .into_iter()
+            .rev()
             .collect::<Vec<_>>();
 
+        // The /sync response doesn't always return all messages, so we say the output is
+        // limited unless there are events in non_timeline_pdus
+        //let mut limited = false;
+
+        let mut state_pdus = Vec::new();
+        for pdu in non_timeline_pdus {
+            if pdu.state_key.is_some() {
+                state_pdus.push(pdu);
+            }
+        }
+
         let mut send_member_count = false;
         let mut joined_since_last_sync = false;
         let mut send_notification_counts = false;
-        for pdu in &pdus {
+        for pdu in db.rooms.pdus_since(&sender_id, &room_id, since)?.filter_map(|r| r.ok()) {
             send_notification_counts = true;
             if pdu.kind == EventType::RoomMember {
                 send_member_count = true;
@@ -2544,8 +2565,8 @@ pub async fn sync_events_route(
                     .map_err(|_| Error::bad_database("Invalid PDU in database."))?;
                     if content.membership == ruma::events::room::member::MembershipState::Join {
                         joined_since_last_sync = true;
-                        // Both send_member_count and joined_since_last_sync are set. There's nothing more
-                        // to do
+                        // Both send_member_count and joined_since_last_sync are set. There's
+                        // nothing more to do
                         break;
                     }
                 }
@@ -2574,7 +2595,7 @@ pub async fn sync_events_route(
                         let content = serde_json::from_value::<
                             Raw<ruma::events::room::member::MemberEventContent>,
                         >(pdu.content.clone())
-                        .map_err(|_| Error::bad_database("Invalid member event in database."))?
+                        .expect("Raw::from_value always works")
                         .deserialize()
                         .map_err(|_| Error::bad_database("Invalid member event in database."))?;
 
@@ -2592,7 +2613,7 @@ pub async fn sync_events_route(
                                     .content
                                     .clone(),
                             )
-                            .map_err(|_| Error::bad_database("Invalid member event in database."))?
+                            .expect("Raw::from_value always works")
                             .deserialize()
                             .map_err(|_| {
                                 Error::bad_database("Invalid member event in database.")
@@ -2659,15 +2680,7 @@ pub async fn sync_events_route(
             None
         };
 
-        // They /sync response doesn't always return all messages, so we say the output is
-        // limited unless there are enough events
-        let mut limited = true;
-        pdus = pdus.split_off(pdus.len().checked_sub(10).unwrap_or_else(|| {
-            limited = false;
-            0
-        }));
-
-        let prev_batch = pdus.first().map_or(Ok::<_, Error>(None), |e| {
+        let prev_batch = timeline_pdus.first().map_or(Ok::<_, Error>(None), |e| {
             Ok(Some(
                 db.rooms
                     .get_pdu_count(&e.event_id)?
@@ -2676,7 +2689,7 @@ pub async fn sync_events_route(
             ))
         })?;
 
-        let room_events = pdus
+        let room_events = timeline_pdus
             .into_iter()
             .map(|pdu| pdu.to_sync_room_event())
             .collect::<Vec<_>>();
@@ -2728,7 +2741,7 @@ pub async fn sync_events_route(
                 notification_count,
             },
             timeline: sync_events::Timeline {
-                limited: limited || joined_since_last_sync,
+                limited: false || joined_since_last_sync,
                 prev_batch,
                 events: room_events,
             },
@@ -2751,6 +2764,13 @@ pub async fn sync_events_route(
             joined_rooms.insert(room_id.clone(), joined_room);
         }
 
+        // Look for device list updates in this room
+        device_list_updates.extend(
+            db.users
+                .keys_changed(&room_id, since)
+                .filter_map(|r| r.ok()),
+        );
+
         // Take presence updates from this room
         for (user_id, presence) in
             db.rooms
@@ -2885,14 +2905,7 @@ pub async fn sync_events_route(
                 .collect::<Vec<_>>(),
         },
         device_lists: sync_events::DeviceLists {
-            changed: if since != 0 {
-                db.users
-                    .keys_changed(since)
-                    .filter_map(|u| u.ok())
-                    .collect() // Filter out buggy events
-            } else {
-                Vec::new()
-            },
+            changed: device_list_updates.into_iter().collect(),
             left: Vec::new(), // TODO
         },
         device_one_time_keys_count: Default::default(), // TODO
@@ -3450,6 +3463,7 @@ pub fn upload_signing_keys_route(
             &master_key,
             &body.self_signing_key,
             &body.user_signing_key,
+            &db.rooms,
             &db.globals,
         )?;
     }
@@ -3500,8 +3514,14 @@ pub fn upload_signatures_route(
                         ))?
                         .to_owned(),
                 );
-                db.users
-                    .sign_key(&user_id, &key_id, signature, &sender_id, &db.globals)?;
+                db.users.sign_key(
+                    &user_id,
+                    &key_id,
+                    signature,
+                    &sender_id,
+                    &db.rooms,
+                    &db.globals,
+                )?;
             }
         }
     }
diff --git a/src/database/rooms.rs b/src/database/rooms.rs
index fe5721c7..4cd47a17 100644
--- a/src/database/rooms.rs
+++ b/src/database/rooms.rs
@@ -611,44 +611,29 @@ impl Rooms {
         self.pdus_since(user_id, room_id, 0)
     }
 
-    /// Returns an iterator over all events in a room that happened after the event with id `since`.
+    /// Returns an iterator over all events in a room that happened after the event with id `since`
+    /// in chronological order; use `.rev()` on the result to start from the newest event.
     pub fn pdus_since(
         &self,
         user_id: &UserId,
         room_id: &RoomId,
         since: u64,
-    ) -> Result<impl Iterator<Item = Result<PduEvent>>> {
-        // Create the first part of the full pdu id
-        let mut pdu_id = room_id.to_string().as_bytes().to_vec();
-        pdu_id.push(0xff);
-        pdu_id.extend_from_slice(&(since).to_be_bytes());
-
-        self.pdus_since_pduid(user_id, room_id, &pdu_id)
-    }
-
-    /// Returns an iterator over all events in a room that happened after the event with id `since`.
-    pub fn pdus_since_pduid(
-        &self,
-        user_id: &UserId,
-        room_id: &RoomId,
-        pdu_id: &[u8],
-    ) -> Result<impl Iterator<Item = Result<PduEvent>>> {
-        // Create the first part of the full pdu id
+    ) -> Result<impl DoubleEndedIterator<Item = Result<PduEvent>>> {
         let mut prefix = room_id.to_string().as_bytes().to_vec();
         prefix.push(0xff);
 
+        // Skip the first pdu if it's exactly at since, because we sent that last time
+        let mut first_pdu_id = prefix.clone();
+        first_pdu_id.extend_from_slice(&(since+1).to_be_bytes());
+
+        let mut last_pdu_id = prefix.clone();
+        last_pdu_id.extend_from_slice(&u64::MAX.to_be_bytes());
+
         let user_id = user_id.clone();
         Ok(self
             .pduid_pdu
-            .range(pdu_id..)
-            // Skip the first pdu if it's exactly at since, because we sent that last time
-            .skip(if self.pduid_pdu.get(pdu_id)?.is_some() {
-                1
-            } else {
-                0
-            })
+            .range(first_pdu_id..last_pdu_id)
             .filter_map(|r| r.ok())
-            .take_while(move |(k, _)| k.starts_with(&prefix))
             .map(move |(_, v)| {
                 let mut pdu = serde_json::from_slice::<PduEvent>(&v)
                     .map_err(|_| Error::bad_database("PDU in db is invalid."))?;
diff --git a/src/database/users.rs b/src/database/users.rs
index 5030f32e..7fbdd806 100644
--- a/src/database/users.rs
+++ b/src/database/users.rs
@@ -9,7 +9,7 @@ use ruma::{
         },
     },
     events::{AnyToDeviceEvent, EventType},
-    DeviceId, Raw, UserId,
+    DeviceId, Raw, UserId, RoomId,
 };
 use std::{collections::BTreeMap, convert::TryFrom, mem, time::SystemTime};
 
@@ -22,7 +22,7 @@ pub struct Users {
     pub(super) token_userdeviceid: sled::Tree,
 
     pub(super) onetimekeyid_onetimekeys: sled::Tree, // OneTimeKeyId = UserId + AlgorithmAndDeviceId
-    pub(super) keychangeid_userid: sled::Tree,       // KeyChangeId = Count
+    pub(super) keychangeid_userid: sled::Tree,       // KeyChangeId = RoomId + Count
     pub(super) keyid_key: sled::Tree,                // KeyId = UserId + KeyId (depends on key type)
     pub(super) userid_masterkeyid: sled::Tree,
     pub(super) userid_selfsigningkeyid: sled::Tree,
@@ -371,6 +371,7 @@ impl Users {
         user_id: &UserId,
         device_id: &DeviceId,
         device_keys: &DeviceKeys,
+        rooms: &super::rooms::Rooms,
         globals: &super::globals::Globals,
     ) -> Result<()> {
         let mut userdeviceid = user_id.to_string().as_bytes().to_vec();
@@ -382,8 +383,15 @@ impl Users {
             &*serde_json::to_string(&device_keys).expect("DeviceKeys::to_string always works"),
         )?;
 
-        self.keychangeid_userid
-            .insert(globals.next_count()?.to_be_bytes(), &*user_id.to_string())?;
+        let count = globals.next_count()?.to_be_bytes();
+        for room_id in rooms.rooms_joined(&user_id) {
+            let mut key = room_id?.to_string().as_bytes().to_vec();
+            key.push(0xff);
+            key.extend_from_slice(&count);
+
+            self.keychangeid_userid
+                .insert(key, &*user_id.to_string())?;
+        }
 
         Ok(())
     }
@@ -394,6 +402,7 @@ impl Users {
         master_key: &CrossSigningKey,
         self_signing_key: &Option<CrossSigningKey>,
         user_signing_key: &Option<CrossSigningKey>,
+        rooms: &super::rooms::Rooms,
         globals: &super::globals::Globals,
     ) -> Result<()> {
         // TODO: Check signatures
@@ -482,8 +491,15 @@ impl Users {
                 .insert(&*user_id.to_string(), user_signing_key_key)?;
         }
 
-        self.keychangeid_userid
-            .insert(globals.next_count()?.to_be_bytes(), &*user_id.to_string())?;
+        let count = globals.next_count()?.to_be_bytes();
+        for room_id in rooms.rooms_joined(&user_id) {
+            let mut key = room_id?.to_string().as_bytes().to_vec();
+            key.push(0xff);
+            key.extend_from_slice(&count);
+
+            self.keychangeid_userid
+                .insert(key, &*user_id.to_string())?;
+        }
 
         Ok(())
     }
@@ -494,6 +510,7 @@ impl Users {
         key_id: &str,
         signature: (String, String),
         sender_id: &UserId,
+        rooms: &super::rooms::Rooms,
         globals: &super::globals::Globals,
     ) -> Result<()> {
         let mut key = target_id.to_string().as_bytes().to_vec();
@@ -525,19 +542,33 @@ impl Users {
                 .expect("CrossSigningKey::to_string always works"),
         )?;
 
-        self.keychangeid_userid
-            .insert(globals.next_count()?.to_be_bytes(), &*target_id.to_string())?;
+        // TODO: Should we notify about this change?
+        let count = globals.next_count()?.to_be_bytes();
+        for room_id in rooms.rooms_joined(&target_id) {
+            let mut key = room_id?.to_string().as_bytes().to_vec();
+            key.push(0xff);
+            key.extend_from_slice(&count);
+
+            self.keychangeid_userid
+                .insert(key, &*target_id.to_string())?;
+        }
 
         Ok(())
     }
 
-    pub fn keys_changed(&self, since: u64) -> impl Iterator<Item = Result<UserId>> {
+    pub fn keys_changed(&self, room_id: &RoomId, since: u64) -> impl Iterator<Item = Result<UserId>> {
+        let mut prefix = room_id.to_string().as_bytes().to_vec();
+        prefix.push(0xff);
+        let mut start = prefix.clone();
+        start.extend_from_slice(&(since + 1).to_be_bytes());
+
         self.keychangeid_userid
-            .range((since + 1).to_be_bytes()..)
-            .values()
-            .map(|bytes| {
+            .range(start..)
+            .filter_map(|r| r.ok())
+            .take_while(move |(k, _)| k.starts_with(&prefix))
+            .map(|(_, bytes)| {
                 Ok(
-                    UserId::try_from(utils::string_from_bytes(&bytes?).map_err(|_| {
+                    UserId::try_from(utils::string_from_bytes(&bytes).map_err(|_| {
                         Error::bad_database(
                             "User ID in devicekeychangeid_userid is invalid unicode.",
                         )