From 72eb1972c18307d77c9dbfaddc2cb92bd9f38efa Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Timo=20K=C3=B6sters?= <timo@koesters.xyz>
Date: Mon, 26 Jun 2023 12:38:51 +0200
Subject: [PATCH] Add relations endpoints, edits and threads work now

---
 src/api/client_server/mod.rs                 |   2 +
 src/api/client_server/relations.rs           | 150 ++++++++++++++++-
 src/database/key_value/rooms/pdu_metadata.rs |  59 ++++++-
 src/database/mod.rs                          |   4 +-
 src/main.rs                                  |   3 +
 src/service/rooms/pdu_metadata/data.rs       |  11 +-
 src/service/rooms/pdu_metadata/mod.rs        | 161 ++++++++++++++++++-
 src/service/rooms/timeline/mod.rs            |  24 ++-
 8 files changed, 385 insertions(+), 29 deletions(-)

diff --git a/src/api/client_server/mod.rs b/src/api/client_server/mod.rs
index 4a77f236..2ab3a98a 100644
--- a/src/api/client_server/mod.rs
+++ b/src/api/client_server/mod.rs
@@ -16,6 +16,7 @@ mod profile;
 mod push;
 mod read_marker;
 mod redact;
+mod relations;
 mod report;
 mod room;
 mod search;
@@ -49,6 +50,7 @@ pub use profile::*;
 pub use push::*;
 pub use read_marker::*;
 pub use redact::*;
+pub use relations::*;
 pub use report::*;
 pub use room::*;
 pub use search::*;
diff --git a/src/api/client_server/relations.rs b/src/api/client_server/relations.rs
index 4d2af477..a7cea786 100644
--- a/src/api/client_server/relations.rs
+++ b/src/api/client_server/relations.rs
@@ -1,10 +1,146 @@
-use crate::{services, Result, Ruma};
-use std::time::{Duration, SystemTime};
+use ruma::api::client::relations::{
+    get_relating_events, get_relating_events_with_rel_type,
+    get_relating_events_with_rel_type_and_event_type,
+};
 
-/// # `GET /_matrix/client/r0/todo`
-pub async fn get_relating_events_route(
-    body: Ruma<get_turn_server_info::v3::Request>,
-) -> Result<get_turn_server_info::v3::Response> {
+use crate::{service::rooms::timeline::PduCount, services, Result, Ruma};
+
+/// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}`
+pub async fn get_relating_events_with_rel_type_and_event_type_route(
+    body: Ruma<get_relating_events_with_rel_type_and_event_type::v1::Request>,
+) -> Result<get_relating_events_with_rel_type_and_event_type::v1::Response> {
     let sender_user = body.sender_user.as_ref().expect("user is authenticated");
-    todo!();
+
+    let from = match body.from.clone() {
+        Some(from) => PduCount::try_from_string(&from)?,
+        None => match ruma::api::Direction::Backward {
+            // TODO: fix ruma so `body.dir` exists
+            ruma::api::Direction::Forward => PduCount::min(),
+            ruma::api::Direction::Backward => PduCount::max(),
+        },
+    };
+
+    let to = body
+        .to
+        .as_ref()
+        .and_then(|t| PduCount::try_from_string(&t).ok());
+
+    // Use limit or else 10, with maximum 100
+    let limit = body
+        .limit
+        .and_then(|u| u32::try_from(u).ok())
+        .map_or(10_usize, |u| u as usize)
+        .min(100);
+
+    let res = services()
+        .rooms
+        .pdu_metadata
+        .paginate_relations_with_filter(
+            sender_user,
+            &body.room_id,
+            &body.event_id,
+            Some(body.event_type.clone()),
+            Some(body.rel_type.clone()),
+            from,
+            to,
+            limit,
+        )?;
+
+    Ok(
+        get_relating_events_with_rel_type_and_event_type::v1::Response {
+            chunk: res.chunk,
+            next_batch: res.next_batch,
+            prev_batch: res.prev_batch,
+        },
+    )
+}
+
+/// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}`
+pub async fn get_relating_events_with_rel_type_route(
+    body: Ruma<get_relating_events_with_rel_type::v1::Request>,
+) -> Result<get_relating_events_with_rel_type::v1::Response> {
+    let sender_user = body.sender_user.as_ref().expect("user is authenticated");
+
+    let from = match body.from.clone() {
+        Some(from) => PduCount::try_from_string(&from)?,
+        None => match ruma::api::Direction::Backward {
+            // TODO: fix ruma so `body.dir` exists
+            ruma::api::Direction::Forward => PduCount::min(),
+            ruma::api::Direction::Backward => PduCount::max(),
+        },
+    };
+
+    let to = body
+        .to
+        .as_ref()
+        .and_then(|t| PduCount::try_from_string(&t).ok());
+
+    // Use limit or else 10, with maximum 100
+    let limit = body
+        .limit
+        .and_then(|u| u32::try_from(u).ok())
+        .map_or(10_usize, |u| u as usize)
+        .min(100);
+
+    let res = services()
+        .rooms
+        .pdu_metadata
+        .paginate_relations_with_filter(
+            sender_user,
+            &body.room_id,
+            &body.event_id,
+            None,
+            Some(body.rel_type.clone()),
+            from,
+            to,
+            limit,
+        )?;
+
+    Ok(get_relating_events_with_rel_type::v1::Response {
+        chunk: res.chunk,
+        next_batch: res.next_batch,
+        prev_batch: res.prev_batch,
+    })
+}
+
+/// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}`
+pub async fn get_relating_events_route(
+    body: Ruma<get_relating_events::v1::Request>,
+) -> Result<get_relating_events::v1::Response> {
+    let sender_user = body.sender_user.as_ref().expect("user is authenticated");
+
+    let from = match body.from.clone() {
+        Some(from) => PduCount::try_from_string(&from)?,
+        None => match ruma::api::Direction::Backward {
+            // TODO: fix ruma so `body.dir` exists
+            ruma::api::Direction::Forward => PduCount::min(),
+            ruma::api::Direction::Backward => PduCount::max(),
+        },
+    };
+
+    let to = body
+        .to
+        .as_ref()
+        .and_then(|t| PduCount::try_from_string(&t).ok());
+
+    // Use limit or else 10, with maximum 100
+    let limit = body
+        .limit
+        .and_then(|u| u32::try_from(u).ok())
+        .map_or(10_usize, |u| u as usize)
+        .min(100);
+
+    services()
+        .rooms
+        .pdu_metadata
+        .paginate_relations_with_filter(
+            sender_user,
+            &body.room_id,
+            &body.event_id,
+            None,
+            None,
+            from,
+            to,
+            limit,
+        )
 }
diff --git a/src/database/key_value/rooms/pdu_metadata.rs b/src/database/key_value/rooms/pdu_metadata.rs
index 4b3f810d..0641f9d8 100644
--- a/src/database/key_value/rooms/pdu_metadata.rs
+++ b/src/database/key_value/rooms/pdu_metadata.rs
@@ -1,17 +1,64 @@
-use std::sync::Arc;
+use std::{mem, sync::Arc};
 
-use ruma::{EventId, RoomId};
+use ruma::{EventId, RoomId, UserId};
 
-use crate::{database::KeyValueDatabase, service, Result};
+use crate::{
+    database::KeyValueDatabase,
+    service::{self, rooms::timeline::PduCount},
+    services, utils, Error, PduEvent, Result,
+};
 
 impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
     fn add_relation(&self, from: u64, to: u64) -> Result<()> {
-        let mut key = from.to_be_bytes().to_vec();
-        key.extend_from_slice(&to.to_be_bytes());
-        self.fromto_relation.insert(&key, &[])?;
+        let mut key = to.to_be_bytes().to_vec();
+        key.extend_from_slice(&from.to_be_bytes());
+        self.tofrom_relation.insert(&key, &[])?;
         Ok(())
     }
 
+    fn relations_until<'a>(
+        &'a self,
+        user_id: &'a UserId,
+        shortroomid: u64,
+        target: u64,
+        until: PduCount,
+    ) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
+        let prefix = target.to_be_bytes().to_vec();
+        let mut current = prefix.clone();
+
+        let count_raw = match until {
+            PduCount::Normal(x) => x - 1,
+            PduCount::Backfilled(x) => {
+                current.extend_from_slice(&0_u64.to_be_bytes());
+                u64::MAX - x - 1
+            }
+        };
+        current.extend_from_slice(&count_raw.to_be_bytes());
+
+        Ok(Box::new(
+            self.tofrom_relation
+                .iter_from(&current, true)
+                .take_while(move |(k, _)| k.starts_with(&prefix))
+                .map(move |(tofrom, _data)| {
+                    let from = utils::u64_from_bytes(&tofrom[(mem::size_of::<u64>())..])
+                        .map_err(|_| Error::bad_database("Invalid count in tofrom_relation."))?;
+
+                    let mut pduid = shortroomid.to_be_bytes().to_vec();
+                    pduid.extend_from_slice(&from.to_be_bytes());
+
+                    let mut pdu = services()
+                        .rooms
+                        .timeline
+                        .get_pdu_from_id(&pduid)?
+                        .ok_or_else(|| Error::bad_database("Pdu in tofrom_relation is invalid."))?;
+                    if pdu.sender != user_id {
+                        pdu.remove_transaction_id()?;
+                    }
+                    Ok((PduCount::Normal(from), pdu))
+                }),
+        ))
+    }
+
     fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> {
         for prev in event_ids {
             let mut key = room_id.as_bytes().to_vec();
diff --git a/src/database/mod.rs b/src/database/mod.rs
index b864cebb..5d89d4af 100644
--- a/src/database/mod.rs
+++ b/src/database/mod.rs
@@ -131,7 +131,7 @@ pub struct KeyValueDatabase {
     pub(super) softfailedeventids: Arc<dyn KvTree>,
 
     /// ShortEventId + ShortEventId -> ().
-    pub(super) fromto_relation: Arc<dyn KvTree>,
+    pub(super) tofrom_relation: Arc<dyn KvTree>,
     /// RoomId + EventId -> Parent PDU EventId.
     pub(super) referencedevents: Arc<dyn KvTree>,
 
@@ -348,7 +348,7 @@ impl KeyValueDatabase {
             eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?,
             softfailedeventids: builder.open_tree("softfailedeventids")?,
 
-            fromto_relation: builder.open_tree("fromto_relation")?,
+            tofrom_relation: builder.open_tree("tofrom_relation")?,
             referencedevents: builder.open_tree("referencedevents")?,
             roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?,
             roomusertype_roomuserdataid: builder.open_tree("roomusertype_roomuserdataid")?,
diff --git a/src/main.rs b/src/main.rs
index 20fab912..f9f88f49 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -383,6 +383,9 @@ fn routes() -> Router {
         // .ruma_route(client_server::third_party_route)
         .ruma_route(client_server::upgrade_room_route)
         .ruma_route(client_server::get_threads_route)
+        .ruma_route(client_server::get_relating_events_with_rel_type_and_event_type_route)
+        .ruma_route(client_server::get_relating_events_with_rel_type_route)
+        .ruma_route(client_server::get_relating_events_route)
         .ruma_route(server_server::get_server_version_route)
         .route(
             "/_matrix/key/v2/server",
diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs
index 5577b3e3..6c4cb3ce 100644
--- a/src/service/rooms/pdu_metadata/data.rs
+++ b/src/service/rooms/pdu_metadata/data.rs
@@ -1,10 +1,17 @@
 use std::sync::Arc;
 
-use crate::Result;
-use ruma::{EventId, RoomId};
+use crate::{service::rooms::timeline::PduCount, PduEvent, Result};
+use ruma::{EventId, RoomId, UserId};
 
 pub trait Data: Send + Sync {
     fn add_relation(&self, from: u64, to: u64) -> Result<()>;
+    fn relations_until<'a>(
+        &'a self,
+        user_id: &'a UserId,
+        room_id: u64,
+        target: u64,
+        until: PduCount,
+    ) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>>;
     fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()>;
     fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool>;
     fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()>;
diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs
index a82b9a6d..9ce74f4d 100644
--- a/src/service/rooms/pdu_metadata/mod.rs
+++ b/src/service/rooms/pdu_metadata/mod.rs
@@ -2,20 +2,169 @@ mod data;
 use std::sync::Arc;
 
 pub use data::Data;
-use ruma::{EventId, RoomId};
+use ruma::{
+    api::client::relations::get_relating_events,
+    events::{relation::RelationType, TimelineEventType},
+    EventId, RoomId, UserId,
+};
+use serde::Deserialize;
 
-use crate::{services, Result};
+use crate::{services, PduEvent, Result};
+
+use super::timeline::PduCount;
 
 pub struct Service {
     pub db: &'static dyn Data,
 }
 
+#[derive(Clone, Debug, Deserialize)]
+struct ExtractRelType {
+    rel_type: RelationType,
+}
+#[derive(Clone, Debug, Deserialize)]
+struct ExtractRelatesToEventId {
+    #[serde(rename = "m.relates_to")]
+    relates_to: ExtractRelType,
+}
+
 impl Service {
     #[tracing::instrument(skip(self, from, to))]
-    pub fn add_relation(&self, from: &EventId, to: &EventId) -> Result<()> {
-        let from = services().rooms.short.get_or_create_shorteventid(from)?;
-        let to = services().rooms.short.get_or_create_shorteventid(to)?;
-        self.db.add_relation(from, to)
+    pub fn add_relation(&self, from: PduCount, to: PduCount) -> Result<()> {
+        match (from, to) {
+            (PduCount::Normal(f), PduCount::Normal(t)) => self.db.add_relation(f, t),
+            _ => {
+                // TODO: Relations with backfilled pdus
+
+                Ok(())
+            }
+        }
+    }
+
+    pub fn paginate_relations_with_filter(
+        &self,
+        sender_user: &UserId,
+        room_id: &RoomId,
+        target: &EventId,
+        filter_event_type: Option<TimelineEventType>,
+        filter_rel_type: Option<RelationType>,
+        from: PduCount,
+        to: Option<PduCount>,
+        limit: usize,
+    ) -> Result<get_relating_events::v1::Response> {
+        let next_token;
+
+        //TODO: Fix ruma: match body.dir {
+        match ruma::api::Direction::Backward {
+            ruma::api::Direction::Forward => {
+                let events_after: Vec<_> = services()
+                    .rooms
+                    .pdu_metadata
+                    .relations_until(sender_user, room_id, target, from)? // TODO: should be relations_after
+                    .filter(|r| {
+                        r.as_ref().map_or(true, |(_, pdu)| {
+                            filter_event_type.as_ref().map_or(true, |t| &pdu.kind == t)
+                                && if let Ok(content) =
+                                    serde_json::from_str::<ExtractRelatesToEventId>(
+                                        pdu.content.get(),
+                                    )
+                                {
+                                    filter_rel_type
+                                        .as_ref()
+                                        .map_or(true, |r| &content.relates_to.rel_type == r)
+                                } else {
+                                    false
+                                }
+                        })
+                    })
+                    .take(limit)
+                    .filter_map(|r| r.ok()) // Filter out buggy events
+                    .filter(|(_, pdu)| {
+                        services()
+                            .rooms
+                            .state_accessor
+                            .user_can_see_event(sender_user, &room_id, &pdu.event_id)
+                            .unwrap_or(false)
+                    })
+                    .take_while(|&(k, _)| Some(k) != to) // Stop at `to`
+                    .collect();
+
+                next_token = events_after.last().map(|(count, _)| count).copied();
+
+                let events_after: Vec<_> = events_after
+                    .into_iter()
+                    .rev() // relations are always most recent first
+                    .map(|(_, pdu)| pdu.to_message_like_event())
+                    .collect();
+
+                Ok(get_relating_events::v1::Response {
+                    chunk: events_after,
+                    next_batch: next_token.map(|t| t.stringify()),
+                    prev_batch: Some(from.stringify()),
+                })
+            }
+            ruma::api::Direction::Backward => {
+                let events_before: Vec<_> = services()
+                    .rooms
+                    .pdu_metadata
+                    .relations_until(sender_user, &room_id, target, from)?
+                    .filter(|r| {
+                        r.as_ref().map_or(true, |(_, pdu)| {
+                            filter_event_type.as_ref().map_or(true, |t| &pdu.kind == t)
+                                && if let Ok(content) =
+                                    serde_json::from_str::<ExtractRelatesToEventId>(
+                                        pdu.content.get(),
+                                    )
+                                {
+                                    filter_rel_type
+                                        .as_ref()
+                                        .map_or(true, |r| &content.relates_to.rel_type == r)
+                                } else {
+                                    false
+                                }
+                        })
+                    })
+                    .take(limit)
+                    .filter_map(|r| r.ok()) // Filter out buggy events
+                    .filter(|(_, pdu)| {
+                        services()
+                            .rooms
+                            .state_accessor
+                            .user_can_see_event(sender_user, &room_id, &pdu.event_id)
+                            .unwrap_or(false)
+                    })
+                    .take_while(|&(k, _)| Some(k) != to) // Stop at `to`
+                    .collect();
+
+                next_token = events_before.last().map(|(count, _)| count).copied();
+
+                let events_before: Vec<_> = events_before
+                    .into_iter()
+                    .map(|(_, pdu)| pdu.to_message_like_event())
+                    .collect();
+
+                Ok(get_relating_events::v1::Response {
+                    chunk: events_before,
+                    next_batch: next_token.map(|t| t.stringify()),
+                    prev_batch: Some(from.stringify()),
+                })
+            }
+        }
+    }
+
+    pub fn relations_until<'a>(
+        &'a self,
+        user_id: &'a UserId,
+        room_id: &'a RoomId,
+        target: &'a EventId,
+        until: PduCount,
+    ) -> Result<impl Iterator<Item = Result<(PduCount, PduEvent)>> + 'a> {
+        let room_id = services().rooms.short.get_or_create_shortroomid(room_id)?;
+        let target = match services().rooms.timeline.get_pdu_count(target)? {
+            Some(PduCount::Normal(c)) => c,
+            // TODO: Support backfilled relations
+            _ => 0, // This will result in an empty iterator
+        };
+        self.db.relations_until(user_id, room_id, target, until)
     }
 
     #[tracing::instrument(skip(self, room_id, event_ids))]
diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs
index 625d3468..2356a00e 100644
--- a/src/service/rooms/timeline/mod.rs
+++ b/src/service/rooms/timeline/mod.rs
@@ -478,10 +478,16 @@ impl Service {
         }
 
         if let Ok(content) = serde_json::from_str::<ExtractRelatesToEventId>(pdu.content.get()) {
-            services()
+            if let Some(related_pducount) = services()
                 .rooms
-                .pdu_metadata
-                .add_relation(&pdu.event_id, &content.relates_to.event_id)?;
+                .timeline
+                .get_pdu_count(&content.relates_to.event_id)?
+            {
+                services()
+                    .rooms
+                    .pdu_metadata
+                    .add_relation(PduCount::Normal(count2), related_pducount)?;
+            }
         }
 
         if let Ok(content) = serde_json::from_str::<ExtractRelatesTo>(pdu.content.get()) {
@@ -489,10 +495,16 @@ impl Service {
                 Relation::Reply { in_reply_to } => {
                     // We need to do it again here, because replies don't have
                     // event_id as a top level field
-                    services()
+                    if let Some(related_pducount) = services()
                         .rooms
-                        .pdu_metadata
-                        .add_relation(&pdu.event_id, &in_reply_to.event_id)?;
+                        .timeline
+                        .get_pdu_count(&in_reply_to.event_id)?
+                    {
+                        services()
+                            .rooms
+                            .pdu_metadata
+                            .add_relation(PduCount::Normal(count2), related_pducount)?;
+                    }
                 }
                 Relation::Thread(thread) => {
                     services()