From 8557278b9056fa181e0716d87f3d997022b87135 Mon Sep 17 00:00:00 2001
From: timokoesters <timo@koesters.xyz>
Date: Wed, 8 Apr 2020 15:05:00 +0200
Subject: [PATCH] better riot.im support

---
 README.md           | 11 +++++--
 src/data.rs         | 58 +++++++++++++++++++++++++++++------
 src/main.rs         | 74 ++++++++++++++++++++++++++++++---------------
 src/ruma_wrapper.rs |  2 +-
 src/utils.rs        |  5 +--
 5 files changed, 110 insertions(+), 40 deletions(-)

diff --git a/README.md b/README.md
index 4e428525..0502e304 100644
--- a/README.md
+++ b/README.md
@@ -15,5 +15,12 @@ A Matrix Homeserver that's faster than others.
 - [x] Register, login, authentication tokens
 - [x] Create room messages
 - [x] Sync room messages
-- [ ] Join rooms, lookup room ids
-- [ ] Riot web support
\ No newline at end of file
+- [x] Join rooms, lookup room ids
+- [x] Basic Riot web support
+- [ ] Riot room discovery
+- [ ] Riot read receipts
+- [ ] Riot presence
+- [ ] Proper room creation
+- [ ] Riot E2EE
+- [ ] Basic federation
+- [ ] State resolution
diff --git a/src/data.rs b/src/data.rs
index 56878d7d..a9a0ba5c 100644
--- a/src/data.rs
+++ b/src/data.rs
@@ -103,7 +103,11 @@ impl Data {
             .unwrap();
     }
 
-    pub fn room_join(&self, room_id: &RoomId, user_id: &UserId) {
+    pub fn room_join(&self, room_id: &RoomId, user_id: &UserId) -> bool {
+        if !self.room_exists(room_id) {
+            return false;
+        }
+
         self.db.userid_roomids.add(
             user_id.to_string().as_bytes(),
             room_id.to_string().as_bytes().into(),
@@ -112,6 +116,8 @@ impl Data {
             room_id.to_string().as_bytes(),
             user_id.to_string().as_bytes().into(),
         );
+
+        true
     }
 
     pub fn rooms_joined(&self, user_id: &UserId) -> Vec<RoomId> {
@@ -126,6 +132,24 @@ impl Data {
             .collect()
     }
 
+    /// Check if a room exists by looking for PDUs in that room.
+    pub fn room_exists(&self, room_id: &RoomId) -> bool {
+        // Create the first part of the full pdu id
+        let mut prefix = vec![b'd'];
+        prefix.extend_from_slice(room_id.to_string().as_bytes());
+        prefix.push(b'#'); // Add delimiter so we don't find rooms starting with the same id
+
+        if let Some((key, _)) = self.db.pduid_pdus.get_gt(&prefix).unwrap() {
+            if key.starts_with(&prefix) {
+                true
+            } else {
+                false
+            }
+        } else {
+            false
+        }
+    }
+
     pub fn pdu_get(&self, event_id: &EventId) -> Option<RoomV3Pdu> {
         self.db
             .eventid_pduid
@@ -177,6 +201,7 @@ impl Data {
         sender: UserId,
         event_type: EventType,
         content: serde_json::Value,
+        unsigned: Option<serde_json::Map<String, serde_json::Value>>,
         state_key: Option<String>,
     ) -> EventId {
         // prev_events are the leaves of the current graph. This method removes all leaves from the
@@ -210,7 +235,7 @@ impl Data {
             depth: depth.try_into().unwrap(),
             auth_events: Vec::new(),
             redacts: None,
-            unsigned: Default::default(), // TODO
+            unsigned: unsigned.unwrap_or_default(),
             hashes: ruma_federation_api::EventHash {
                 sha256: "aaa".to_owned(),
             },
@@ -263,29 +288,42 @@ impl Data {
 
     /// Returns a vector of all PDUs in a room.
     pub fn pdus_all(&self, room_id: &RoomId) -> Vec<PduEvent> {
-        self.pdus_since(room_id, "".to_owned())
+        self.pdus_since(room_id, 0)
+    }
+
+    pub fn last_pdu_index(&self) -> u64 {
+        let count_key: Vec<u8> = vec![b'n'];
+        utils::u64_from_bytes(
+            &self
+                .db
+                .pduid_pdus
+                .get(&count_key)
+                .unwrap()
+                .unwrap_or_else(|| (&0_u64.to_be_bytes()).into()),
+        )
     }
 
     /// Returns a vector of all events in a room that happened after the event with id `since`.
-    pub fn pdus_since(&self, room_id: &RoomId, since: String) -> Vec<PduEvent> {
+    pub fn pdus_since(&self, room_id: &RoomId, since: u64) -> Vec<PduEvent> {
         let mut pdus = Vec::new();
 
         // Create the first part of the full pdu id
-        let mut pdu_id = vec![b'd'];
-        pdu_id.extend_from_slice(room_id.to_string().as_bytes());
-        pdu_id.push(b'#'); // Add delimiter so we don't find rooms starting with the same id
+        let mut prefix = vec![b'd'];
+        prefix.extend_from_slice(room_id.to_string().as_bytes());
+        prefix.push(b'#'); // Add delimiter so we don't find rooms starting with the same id
 
-        let mut current = pdu_id.clone();
-        current.extend_from_slice(since.as_bytes());
+        let mut current = prefix.clone();
+        current.extend_from_slice(&since.to_be_bytes());
 
         while let Some((key, value)) = self.db.pduid_pdus.get_gt(&current).unwrap() {
-            if key.starts_with(&pdu_id) {
+            if key.starts_with(&prefix) {
                 current = key.to_vec();
                 pdus.push(serde_json::from_slice(&value).expect("pdu in db is valid"));
             } else {
                 break;
             }
         }
+
         pdus
     }
 
diff --git a/src/main.rs b/src/main.rs
index b9bd3520..a2cfc105 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -33,7 +33,7 @@ use ruma_events::EventType;
 use ruma_identifiers::{RoomId, RoomIdOrAliasId, UserId};
 use ruma_wrapper::{MatrixResult, Ruma};
 use serde_json::json;
-use std::{collections::HashMap, convert::TryInto, path::PathBuf};
+use std::{collections::HashMap, convert::TryInto, path::PathBuf, time::Duration};
 
 const GUEST_NAME_LENGTH: usize = 10;
 const DEVICE_ID_LENGTH: usize = 10;
@@ -67,7 +67,8 @@ fn register_route(
             .to_string(),
             status_code: http::StatusCode::UNAUTHORIZED,
         }));
-    }*/
+    }
+    */
 
     // Validate user id
     let user_id: UserId = match (*format!(
@@ -272,17 +273,18 @@ fn create_room_route(
     // TODO: check if room is unique
     let room_id = RoomId::new(data.hostname()).expect("host is valid");
 
-    data.room_join(
-        &room_id,
-        body.user_id.as_ref().expect("user is authenticated"),
-    );
-
     data.pdu_append(
         room_id.clone(),
         body.user_id.clone().expect("user is authenticated"),
         EventType::RoomMessage,
         json!({"msgtype": "m.text", "body": "Hello"}),
         None,
+        None,
+    );
+
+    data.room_join(
+        &room_id,
+        body.user_id.as_ref().expect("user is authenticated"),
     );
 
     MatrixResult(Ok(create_room::Response { room_id }))
@@ -317,13 +319,20 @@ fn join_room_by_id_route(
     body: Ruma<join_room_by_id::Request>,
     _room_id: String,
 ) -> MatrixResult<join_room_by_id::Response> {
-    data.room_join(
+    if data.room_join(
         &body.room_id,
         body.user_id.as_ref().expect("user is authenticated"),
-    );
-    MatrixResult(Ok(join_room_by_id::Response {
-        room_id: body.room_id.clone(),
-    }))
+    ) {
+        MatrixResult(Ok(join_room_by_id::Response {
+            room_id: body.room_id.clone(),
+        }))
+    } else {
+        MatrixResult(Err(Error {
+            kind: ErrorKind::NotFound,
+            message: "Room not found.".to_owned(),
+            status_code: http::StatusCode::NOT_FOUND,
+        }))
+    }
 }
 
 #[post("/_matrix/client/r0/join/<_room_id_or_alias>", data = "<body>")]
@@ -348,13 +357,18 @@ fn join_room_by_id_or_alias_route(
         RoomIdOrAliasId::RoomId(id) => id.clone(),
     };
 
-    data.room_join(
+    if data.room_join(
         &room_id,
         body.user_id.as_ref().expect("user is authenticated"),
-    );
-    MatrixResult(Ok(join_room_by_id_or_alias::Response {
-        room_id: room_id.clone(),
-    }))
+    ) {
+        MatrixResult(Ok(join_room_by_id_or_alias::Response { room_id }))
+    } else {
+        MatrixResult(Err(Error {
+            kind: ErrorKind::NotFound,
+            message: "Room not found.".to_owned(),
+            status_code: http::StatusCode::NOT_FOUND,
+        }))
+    }
 }
 
 #[put(
@@ -368,11 +382,15 @@ fn create_message_event_route(
     _txn_id: String,
     body: Ruma<create_message_event::Request>,
 ) -> MatrixResult<create_message_event::Response> {
+    let mut unsigned = serde_json::Map::new();
+    unsigned.insert("transaction_id".to_owned(), body.txn_id.clone().into());
+
     let event_id = data.pdu_append(
         body.room_id.clone(),
         body.user_id.clone().expect("user is authenticated"),
         body.event_type.clone(),
-        body.json_body,
+        body.json_body.clone(),
+        Some(unsigned),
         None,
     );
     MatrixResult(Ok(create_message_event::Response { event_id }))
@@ -395,6 +413,7 @@ fn create_state_event_for_key_route(
         body.user_id.clone().expect("user is authenticated"),
         body.event_type.clone(),
         body.json_body.clone(),
+        None,
         Some(body.state_key.clone()),
     );
     MatrixResult(Ok(create_state_event_for_key::Response { event_id }))
@@ -416,6 +435,7 @@ fn create_state_event_for_empty_key_route(
         body.user_id.clone().expect("user is authenticated"),
         body.event_type.clone(),
         body.json_body,
+        None,
         Some("".to_owned()),
     );
     MatrixResult(Ok(create_state_event_for_empty_key::Response { event_id }))
@@ -426,17 +446,21 @@ fn sync_route(
     data: State<Data>,
     body: Ruma<sync_events::Request>,
 ) -> MatrixResult<sync_events::Response> {
+    std::thread::sleep(Duration::from_millis(200));
+    let next_batch = data.last_pdu_index().to_string();
+
     let mut joined_rooms = HashMap::new();
     let joined_roomids = data.rooms_joined(body.user_id.as_ref().expect("user is authenticated"));
     for room_id in joined_roomids {
-        let room_events = data
-            .pdus_all(&room_id)
-            .into_iter()
-            .map(|pdu| pdu.to_room_event())
-            .collect();
+        let pdus = if let Some(since) = body.since.clone().and_then(|string| string.parse().ok()) {
+            data.pdus_since(&room_id, since)
+        } else {
+            data.pdus_all(&room_id)
+        };
+        let room_events = pdus.into_iter().map(|pdu| pdu.to_room_event()).collect();
 
         joined_rooms.insert(
-            "!roomid:localhost".try_into().unwrap(),
+            room_id.try_into().unwrap(),
             sync_events::JoinedRoom {
                 account_data: sync_events::AccountData { events: Vec::new() },
                 summary: sync_events::RoomSummary {
@@ -460,7 +484,7 @@ fn sync_route(
     }
 
     MatrixResult(Ok(sync_events::Response {
-        next_batch: String::new(),
+        next_batch,
         rooms: sync_events::Rooms {
             leave: Default::default(),
             join: joined_rooms,
diff --git a/src/ruma_wrapper.rs b/src/ruma_wrapper.rs
index e73c4da7..0bdcfae4 100644
--- a/src/ruma_wrapper.rs
+++ b/src/ruma_wrapper.rs
@@ -95,7 +95,7 @@ where
             let http_request = http_request.body(body.clone()).unwrap();
             log::info!("{:?}", http_request);
 
-            match T::Incoming::try_from(dbg!(http_request)) {
+            match T::Incoming::try_from(http_request) {
                 Ok(t) => Success(Ruma {
                     body: t,
                     user_id,
diff --git a/src/utils.rs b/src/utils.rs
index b32b0f68..e08e09fa 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -8,8 +8,9 @@ pub fn millis_since_unix_epoch() -> js_int::UInt {
     (SystemTime::now()
         .duration_since(UNIX_EPOCH)
         .unwrap()
-        .as_millis() as u32)
-        .into()
+        .as_millis() as u64)
+        .try_into()
+        .expect("time millis are <= MAX_SAFE_UINT")
 }
 
 pub fn increment(old: Option<&[u8]>) -> Option<Vec<u8>> {