From fe3b5d32a7bcbcd1097b925a022255e54f24c140 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Timo=20K=C3=B6sters?= <timo@koesters.xyz>
Date: Tue, 20 Jul 2021 21:17:15 +0200
Subject: [PATCH] feat: send to-device events over federation

---
 src/client_server/keys.rs      | 20 +++++++++++++++-----
 src/client_server/to_device.rs | 30 +++++++++++++++++++++++++++++-
 src/database/sending.rs        | 32 +++++++++++++++++++++++++-------
 3 files changed, 69 insertions(+), 13 deletions(-)

diff --git a/src/client_server/keys.rs b/src/client_server/keys.rs
index 1ae9f80c..418e41af 100644
--- a/src/client_server/keys.rs
+++ b/src/client_server/keys.rs
@@ -302,6 +302,7 @@ pub async fn get_keys_helper<F: Fn(&UserId) -> bool>(
                 .entry(user_id.server_name())
                 .or_insert_with(Vec::new)
                 .push((user_id, device_ids));
+            continue;
         }
 
         if device_ids.is_empty() {
@@ -364,20 +365,29 @@ pub async fn get_keys_helper<F: Fn(&UserId) -> bool>(
     let mut failures = BTreeMap::new();
 
     for (server, vec) in get_over_federation {
-        let mut device_keys = BTreeMap::new();
+        let mut device_keys_input_fed = BTreeMap::new();
         for (user_id, keys) in vec {
-            device_keys.insert(user_id.clone(), keys.clone());
+            device_keys_input_fed.insert(user_id.clone(), keys.clone());
         }
-        if let Err(_e) = db
+        match db
             .sending
             .send_federation_request(
                 &db.globals,
                 server,
-                federation::keys::get_keys::v1::Request { device_keys },
+                federation::keys::get_keys::v1::Request {
+                    device_keys: device_keys_input_fed,
+                },
             )
             .await
         {
-            failures.insert(server.to_string(), json!({}));
+            Ok(response) => {
+                master_keys.extend(response.master_keys);
+                self_signing_keys.extend(response.self_signing_keys);
+                device_keys.extend(response.device_keys);
+            }
+            Err(_e) => {
+                failures.insert(server.to_string(), json!({}));
+            }
         }
     }
 
diff --git a/src/client_server/to_device.rs b/src/client_server/to_device.rs
index 7896af93..e3fd780c 100644
--- a/src/client_server/to_device.rs
+++ b/src/client_server/to_device.rs
@@ -1,6 +1,12 @@
+use std::collections::BTreeMap;
+
 use crate::{database::DatabaseGuard, ConduitResult, Error, Ruma};
 use ruma::{
-    api::client::{error::ErrorKind, r0::to_device::send_event_to_device},
+    api::{
+        client::{error::ErrorKind, r0::to_device::send_event_to_device},
+        federation::{self, transactions::edu::DirectDeviceContent},
+    },
+    events::EventType,
     to_device::DeviceIdOrAllDevices,
 };
 
@@ -33,6 +39,28 @@ pub async fn send_event_to_device_route(
 
     for (target_user_id, map) in &body.messages {
         for (target_device_id_maybe, event) in map {
+            if target_user_id.server_name() != db.globals.server_name() {
+                let mut map = BTreeMap::new();
+                map.insert(target_device_id_maybe.clone(), event.clone());
+                let mut messages = BTreeMap::new();
+                messages.insert(target_user_id.clone(), map);
+
+                db.sending.send_reliable_edu(
+                    target_user_id.server_name(),
+                    &serde_json::to_vec(&federation::transactions::edu::Edu::DirectToDevice(
+                        DirectDeviceContent {
+                            sender: sender_user.clone(),
+                            ev_type: EventType::from(&body.event_type),
+                            message_id: body.txn_id.clone(),
+                            messages,
+                        },
+                    ))
+                    .expect("DirectToDevice EDU can be serialized"),
+                )?;
+
+                continue;
+            }
+
             match target_device_id_maybe {
                 DeviceIdOrAllDevices::DeviceId(target_device_id) => db.users.add_to_device_event(
                     sender_user,
diff --git a/src/database/sending.rs b/src/database/sending.rs
index 7c9cf644..8dfcbee4 100644
--- a/src/database/sending.rs
+++ b/src/database/sending.rs
@@ -164,9 +164,10 @@ impl Sending {
                                 // Find events that have been added since starting the last request
                                 let new_events = guard.sending.servernamepduids
                                     .scan_prefix(prefix.clone())
-                                    .map(|(k, _)| {
-                                        SendingEventType::Pdu(k[prefix.len()..].to_vec())
+                                    .filter_map(|(k, _)| {
+                                        Self::parse_servercurrentevent(&k).ok()
                                     })
+                                    .map(|(_, event)| event)
                                     .take(30)
                                     .collect::<Vec<_>>();
 
@@ -290,7 +291,14 @@ impl Sending {
 
             if let OutgoingKind::Normal(server_name) = outgoing_kind {
                 if let Ok((select_edus, last_count)) = Self::select_edus(db, server_name) {
-                    events.extend_from_slice(&select_edus);
+                    for edu in &select_edus {
+                        let mut full_key = vec![b'*'];
+                        full_key.extend_from_slice(&edu);
+                        db.sending.servercurrentevents.insert(&full_key, &[])?;
+                    }
+
+                    events.extend(select_edus.into_iter().map(SendingEventType::Edu));
+
                     db.sending
                         .servername_educount
                         .insert(server_name.as_bytes(), &last_count.to_be_bytes())?;
@@ -301,7 +309,7 @@ impl Sending {
         Ok(Some(events))
     }
 
-    pub fn select_edus(db: &Database, server: &ServerName) -> Result<(Vec<SendingEventType>, u64)> {
+    pub fn select_edus(db: &Database, server: &ServerName) -> Result<(Vec<Vec<u8>>, u64)> {
         // u64: count of last edu
         let since = db
             .sending
@@ -366,9 +374,7 @@ impl Sending {
                     }
                 };
 
-                events.push(SendingEventType::Edu(
-                    serde_json::to_vec(&federation_event).expect("json can be serialized"),
-                ));
+                events.push(serde_json::to_vec(&federation_event).expect("json can be serialized"));
 
                 if events.len() >= 20 {
                     break 'outer;
@@ -402,6 +408,18 @@ impl Sending {
         Ok(())
     }
 
+    #[tracing::instrument(skip(self))]
+    pub fn send_reliable_edu(&self, server: &ServerName, serialized: &[u8]) -> Result<()> {
+        let mut key = server.as_bytes().to_vec();
+        key.push(0xff);
+        key.push(b'*');
+        key.extend_from_slice(serialized);
+        self.servernamepduids.insert(&key, b"")?;
+        self.sender.unbounded_send(key).unwrap();
+
+        Ok(())
+    }
+
     #[tracing::instrument(skip(self))]
     pub fn send_pdu_appservice(&self, appservice_id: &str, pdu_id: &[u8]) -> Result<()> {
         let mut key = b"+".to_vec();