From 265fab843a42d6eaef7a777104a72d101a2e91f1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Timo=20K=C3=B6sters?= <timo@koesters.xyz>
Date: Sun, 24 Jan 2021 16:05:52 +0100
Subject: [PATCH] feature: push rule settings

---
 src/client_server/push.rs | 624 +++++++++++++++++++++++++++++++++++++-
 src/main.rs               |   5 +
 2 files changed, 616 insertions(+), 13 deletions(-)

diff --git a/src/client_server/push.rs b/src/client_server/push.rs
index 05ba8d06..667d6677 100644
--- a/src/client_server/push.rs
+++ b/src/client_server/push.rs
@@ -1,16 +1,22 @@
 use super::State;
 use crate::{ConduitResult, Database, Error, Ruma};
-use log::warn;
 use ruma::{
     api::client::{
         error::ErrorKind,
-        r0::push::{get_pushers, get_pushrules_all, set_pushrule, set_pushrule_enabled},
+        r0::push::{
+            delete_pushrule, get_pushers, get_pushrule, get_pushrule_actions, get_pushrule_enabled,
+            get_pushrules_all, set_pushrule, set_pushrule_actions, set_pushrule_enabled, RuleKind,
+        },
     },
     events::EventType,
+    push::{
+        ConditionalPushRuleInit, ContentPushRule, OverridePushRule, PatternedPushRuleInit,
+        RoomPushRule, SenderPushRule, SimplePushRuleInit, UnderridePushRule,
+    },
 };
 
 #[cfg(feature = "conduit_bin")]
-use rocket::{get, post, put};
+use rocket::{delete, get, post, put};
 
 #[cfg_attr(
     feature = "conduit_bin",
@@ -36,16 +42,201 @@ pub async fn get_pushrules_all_route(
     .into())
 }
 
-#[cfg_attr(feature = "conduit_bin", put(
-    "/_matrix/client/r0/pushrules/<_>/<_>/<_>",
-    //data = "<body>"
-))]
+#[cfg_attr(
+    feature = "conduit_bin",
+    get("/_matrix/client/r0/pushrules/<_>/<_>/<_>", data = "<body>")
+)]
+pub async fn get_pushrule_route(
+    db: State<'_, Database>,
+    body: Ruma<get_pushrule::Request<'_>>,
+) -> ConduitResult<get_pushrule::Response> {
+    let sender_user = body.sender_user.as_ref().expect("user is authenticated");
+
+    let event = db
+        .account_data
+        .get::<ruma::events::push_rules::PushRulesEvent>(None, &sender_user, EventType::PushRules)?
+        .ok_or(Error::BadRequest(
+            ErrorKind::NotFound,
+            "PushRules event not found.",
+        ))?;
+
+    let global = event.content.global;
+    let rule = match body.kind {
+        RuleKind::Override => global
+            .override_
+            .iter()
+            .find(|rule| rule.0.rule_id == body.rule_id)
+            .map(|rule| rule.0.clone().into()),
+        RuleKind::Underride => global
+            .underride
+            .iter()
+            .find(|rule| rule.0.rule_id == body.rule_id)
+            .map(|rule| rule.0.clone().into()),
+        RuleKind::Sender => global
+            .sender
+            .iter()
+            .find(|rule| rule.0.rule_id == body.rule_id)
+            .map(|rule| rule.0.clone().into()),
+        RuleKind::Room => global
+            .room
+            .iter()
+            .find(|rule| rule.0.rule_id == body.rule_id)
+            .map(|rule| rule.0.clone().into()),
+        RuleKind::Content => global
+            .content
+            .iter()
+            .find(|rule| rule.0.rule_id == body.rule_id)
+            .map(|rule| rule.0.clone().into()),
+        RuleKind::_Custom(_) => None,
+    };
+
+    if let Some(rule) = rule {
+        Ok(get_pushrule::Response { rule }.into())
+    } else {
+        Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found.").into())
+    }
+}
+
+#[cfg_attr(
+    feature = "conduit_bin",
+    put("/_matrix/client/r0/pushrules/<_>/<_>/<_>", data = "<body>")
+)]
 pub async fn set_pushrule_route(
     db: State<'_, Database>,
-    //body: Ruma<set_pushrule::Request>,
+    body: Ruma<set_pushrule::Request<'_>>,
 ) -> ConduitResult<set_pushrule::Response> {
-    // TODO
-    warn!("TODO: set_pushrule_route");
+    let sender_user = body.sender_user.as_ref().expect("user is authenticated");
+
+    if body.scope != "global" {
+        return Err(Error::BadRequest(
+            ErrorKind::InvalidParam,
+            "Scopes other than 'global' are not supported.",
+        ));
+    }
+
+    let mut event = db
+        .account_data
+        .get::<ruma::events::push_rules::PushRulesEvent>(None, &sender_user, EventType::PushRules)?
+        .ok_or(Error::BadRequest(
+            ErrorKind::NotFound,
+            "PushRules event not found.",
+        ))?;
+
+    let global = &mut event.content.global;
+    match body.kind {
+        RuleKind::Override => {
+            if let Some(rule) = global
+                .override_
+                .iter()
+                .find(|rule| rule.0.rule_id == body.rule_id)
+                .cloned()
+            {
+                global.override_.remove(&rule);
+            }
+
+            global.override_.insert(OverridePushRule(
+                ConditionalPushRuleInit {
+                    actions: body.actions.clone(),
+                    default: false,
+                    enabled: true,
+                    rule_id: body.rule_id.clone(),
+                    conditions: body.conditions.clone(),
+                }
+                .into(),
+            ));
+        }
+        RuleKind::Underride => {
+            if let Some(rule) = global
+                .underride
+                .iter()
+                .find(|rule| rule.0.rule_id == body.rule_id)
+                .cloned()
+            {
+                global.underride.remove(&rule);
+            }
+
+            global.underride.insert(UnderridePushRule(
+                ConditionalPushRuleInit {
+                    actions: body.actions.clone(),
+                    default: false,
+                    enabled: true,
+                    rule_id: body.rule_id.clone(),
+                    conditions: body.conditions.clone(),
+                }
+                .into(),
+            ));
+        }
+        RuleKind::Sender => {
+            if let Some(rule) = global
+                .sender
+                .iter()
+                .find(|rule| rule.0.rule_id == body.rule_id)
+                .cloned()
+            {
+                global.sender.remove(&rule);
+            }
+
+            global.sender.insert(SenderPushRule(
+                SimplePushRuleInit {
+                    actions: body.actions.clone(),
+                    default: false,
+                    enabled: true,
+                    rule_id: body.rule_id.clone(),
+                }
+                .into(),
+            ));
+        }
+        RuleKind::Room => {
+            if let Some(rule) = global
+                .room
+                .iter()
+                .find(|rule| rule.0.rule_id == body.rule_id)
+                .cloned()
+            {
+                global.room.remove(&rule);
+            }
+
+            global.room.insert(RoomPushRule(
+                SimplePushRuleInit {
+                    actions: body.actions.clone(),
+                    default: false,
+                    enabled: true,
+                    rule_id: body.rule_id.clone(),
+                }
+                .into(),
+            ));
+        }
+        RuleKind::Content => {
+            if let Some(rule) = global
+                .content
+                .iter()
+                .find(|rule| rule.0.rule_id == body.rule_id)
+                .cloned()
+            {
+                global.content.remove(&rule);
+            }
+
+            global.content.insert(ContentPushRule(
+                PatternedPushRuleInit {
+                    actions: body.actions.clone(),
+                    default: false,
+                    enabled: true,
+                    rule_id: body.rule_id.clone(),
+                    pattern: body.pattern.clone().unwrap_or_default(),
+                }
+                .into(),
+            ));
+        }
+        RuleKind::_Custom(_) => {}
+    }
+
+    db.account_data.update(
+        None,
+        &sender_user,
+        EventType::PushRules,
+        &event,
+        &db.globals,
+    )?;
 
     db.flush().await?;
 
@@ -54,19 +245,426 @@ pub async fn set_pushrule_route(
 
 #[cfg_attr(
     feature = "conduit_bin",
-    put("/_matrix/client/r0/pushrules/<_>/<_>/<_>/enabled")
+    get("/_matrix/client/r0/pushrules/<_>/<_>/<_>/actions", data = "<body>")
+)]
+pub async fn get_pushrule_actions_route(
+    db: State<'_, Database>,
+    body: Ruma<get_pushrule_actions::Request<'_>>,
+) -> ConduitResult<get_pushrule_actions::Response> {
+    let sender_user = body.sender_user.as_ref().expect("user is authenticated");
+
+    if body.scope != "global" {
+        return Err(Error::BadRequest(
+            ErrorKind::InvalidParam,
+            "Scopes other than 'global' are not supported.",
+        ));
+    }
+
+    let mut event = db
+        .account_data
+        .get::<ruma::events::push_rules::PushRulesEvent>(None, &sender_user, EventType::PushRules)?
+        .ok_or(Error::BadRequest(
+            ErrorKind::NotFound,
+            "PushRules event not found.",
+        ))?;
+
+    let global = &mut event.content.global;
+    let actions = match body.kind {
+        RuleKind::Override => global
+            .override_
+            .iter()
+            .find(|rule| rule.0.rule_id == body.rule_id)
+            .map(|rule| rule.0.actions.clone()),
+        RuleKind::Underride => global
+            .underride
+            .iter()
+            .find(|rule| rule.0.rule_id == body.rule_id)
+            .map(|rule| rule.0.actions.clone()),
+        RuleKind::Sender => global
+            .sender
+            .iter()
+            .find(|rule| rule.0.rule_id == body.rule_id)
+            .map(|rule| rule.0.actions.clone()),
+        RuleKind::Room => global
+            .room
+            .iter()
+            .find(|rule| rule.0.rule_id == body.rule_id)
+            .map(|rule| rule.0.actions.clone()),
+        RuleKind::Content => global
+            .content
+            .iter()
+            .find(|rule| rule.0.rule_id == body.rule_id)
+            .map(|rule| rule.0.actions.clone()),
+        RuleKind::_Custom(_) => None,
+    };
+
+    db.flush().await?;
+
+    Ok(get_pushrule_actions::Response {
+        actions: actions.unwrap_or_default(),
+    }
+    .into())
+}
+
+#[cfg_attr(
+    feature = "conduit_bin",
+    put("/_matrix/client/r0/pushrules/<_>/<_>/<_>/actions", data = "<body>")
+)]
+pub async fn set_pushrule_actions_route(
+    db: State<'_, Database>,
+    body: Ruma<set_pushrule_actions::Request<'_>>,
+) -> ConduitResult<set_pushrule_actions::Response> {
+    let sender_user = body.sender_user.as_ref().expect("user is authenticated");
+
+    if body.scope != "global" {
+        return Err(Error::BadRequest(
+            ErrorKind::InvalidParam,
+            "Scopes other than 'global' are not supported.",
+        ));
+    }
+
+    let mut event = db
+        .account_data
+        .get::<ruma::events::push_rules::PushRulesEvent>(None, &sender_user, EventType::PushRules)?
+        .ok_or(Error::BadRequest(
+            ErrorKind::NotFound,
+            "PushRules event not found.",
+        ))?;
+
+    let global = &mut event.content.global;
+    match body.kind {
+        RuleKind::Override => {
+            if let Some(mut rule) = global
+                .override_
+                .iter()
+                .find(|rule| rule.0.rule_id == body.rule_id)
+                .cloned()
+            {
+                global.override_.remove(&rule);
+                rule.0.actions = body.actions.clone();
+                global.override_.insert(rule);
+            }
+        }
+        RuleKind::Underride => {
+            if let Some(mut rule) = global
+                .underride
+                .iter()
+                .find(|rule| rule.0.rule_id == body.rule_id)
+                .cloned()
+            {
+                global.underride.remove(&rule);
+                rule.0.actions = body.actions.clone();
+                global.underride.insert(rule);
+            }
+        }
+        RuleKind::Sender => {
+            if let Some(mut rule) = global
+                .sender
+                .iter()
+                .find(|rule| rule.0.rule_id == body.rule_id)
+                .cloned()
+            {
+                global.sender.remove(&rule);
+                rule.0.actions = body.actions.clone();
+                global.sender.insert(rule);
+            }
+        }
+        RuleKind::Room => {
+            if let Some(mut rule) = global
+                .room
+                .iter()
+                .find(|rule| rule.0.rule_id == body.rule_id)
+                .cloned()
+            {
+                global.room.remove(&rule);
+                rule.0.actions = body.actions.clone();
+                global.room.insert(rule);
+            }
+        }
+        RuleKind::Content => {
+            if let Some(mut rule) = global
+                .content
+                .iter()
+                .find(|rule| rule.0.rule_id == body.rule_id)
+                .cloned()
+            {
+                global.content.remove(&rule);
+                rule.0.actions = body.actions.clone();
+                global.content.insert(rule);
+            }
+        }
+        RuleKind::_Custom(_) => {}
+    };
+
+    db.account_data.update(
+        None,
+        &sender_user,
+        EventType::PushRules,
+        &event,
+        &db.globals,
+    )?;
+
+    db.flush().await?;
+
+    Ok(set_pushrule_actions::Response.into())
+}
+
+#[cfg_attr(
+    feature = "conduit_bin",
+    get("/_matrix/client/r0/pushrules/<_>/<_>/<_>/enabled", data = "<body>")
+)]
+pub async fn get_pushrule_enabled_route(
+    db: State<'_, Database>,
+    body: Ruma<get_pushrule_enabled::Request<'_>>,
+) -> ConduitResult<get_pushrule_enabled::Response> {
+    let sender_user = body.sender_user.as_ref().expect("user is authenticated");
+
+    if body.scope != "global" {
+        return Err(Error::BadRequest(
+            ErrorKind::InvalidParam,
+            "Scopes other than 'global' are not supported.",
+        ));
+    }
+
+    let mut event = db
+        .account_data
+        .get::<ruma::events::push_rules::PushRulesEvent>(None, &sender_user, EventType::PushRules)?
+        .ok_or(Error::BadRequest(
+            ErrorKind::NotFound,
+            "PushRules event not found.",
+        ))?;
+
+    let global = &mut event.content.global;
+    let enabled = match body.kind {
+        RuleKind::Override => global
+            .override_
+            .iter()
+            .find(|rule| rule.0.rule_id == body.rule_id)
+            .map_or(false, |rule| rule.0.enabled),
+        RuleKind::Underride => global
+            .underride
+            .iter()
+            .find(|rule| rule.0.rule_id == body.rule_id)
+            .map_or(false, |rule| rule.0.enabled),
+        RuleKind::Sender => global
+            .sender
+            .iter()
+            .find(|rule| rule.0.rule_id == body.rule_id)
+            .map_or(false, |rule| rule.0.enabled),
+        RuleKind::Room => global
+            .room
+            .iter()
+            .find(|rule| rule.0.rule_id == body.rule_id)
+            .map_or(false, |rule| rule.0.enabled),
+        RuleKind::Content => global
+            .content
+            .iter()
+            .find(|rule| rule.0.rule_id == body.rule_id)
+            .map_or(false, |rule| rule.0.enabled),
+        RuleKind::_Custom(_) => false,
+    };
+
+    db.flush().await?;
+
+    Ok(get_pushrule_enabled::Response { enabled }.into())
+}
+
+#[cfg_attr(
+    feature = "conduit_bin",
+    put("/_matrix/client/r0/pushrules/<_>/<_>/<_>/enabled", data = "<body>")
 )]
 pub async fn set_pushrule_enabled_route(
     db: State<'_, Database>,
+    body: Ruma<set_pushrule_enabled::Request<'_>>,
 ) -> ConduitResult<set_pushrule_enabled::Response> {
-    // TODO
-    warn!("TODO: set_pushrule_enabled_route");
+    let sender_user = body.sender_user.as_ref().expect("user is authenticated");
+
+    if body.scope != "global" {
+        return Err(Error::BadRequest(
+            ErrorKind::InvalidParam,
+            "Scopes other than 'global' are not supported.",
+        ));
+    }
+
+    let mut event = db
+        .account_data
+        .get::<ruma::events::push_rules::PushRulesEvent>(None, &sender_user, EventType::PushRules)?
+        .ok_or(Error::BadRequest(
+            ErrorKind::NotFound,
+            "PushRules event not found.",
+        ))?;
+
+    let global = &mut event.content.global;
+    match body.kind {
+        RuleKind::Override => {
+            if let Some(mut rule) = global
+                .override_
+                .iter()
+                .find(|rule| rule.0.rule_id == body.rule_id)
+                .cloned()
+            {
+                global.override_.remove(&rule);
+                rule.0.enabled = body.enabled;
+                global.override_.insert(rule);
+            }
+        }
+        RuleKind::Underride => {
+            if let Some(mut rule) = global
+                .underride
+                .iter()
+                .find(|rule| rule.0.rule_id == body.rule_id)
+                .cloned()
+            {
+                global.underride.remove(&rule);
+                rule.0.enabled = body.enabled;
+                global.underride.insert(rule);
+            }
+        }
+        RuleKind::Sender => {
+            if let Some(mut rule) = global
+                .sender
+                .iter()
+                .find(|rule| rule.0.rule_id == body.rule_id)
+                .cloned()
+            {
+                global.sender.remove(&rule);
+                rule.0.enabled = body.enabled;
+                global.sender.insert(rule);
+            }
+        }
+        RuleKind::Room => {
+            if let Some(mut rule) = global
+                .room
+                .iter()
+                .find(|rule| rule.0.rule_id == body.rule_id)
+                .cloned()
+            {
+                global.room.remove(&rule);
+                rule.0.enabled = body.enabled;
+                global.room.insert(rule);
+            }
+        }
+        RuleKind::Content => {
+            if let Some(mut rule) = global
+                .content
+                .iter()
+                .find(|rule| rule.0.rule_id == body.rule_id)
+                .cloned()
+            {
+                global.content.remove(&rule);
+                rule.0.enabled = body.enabled;
+                global.content.insert(rule);
+            }
+        }
+        RuleKind::_Custom(_) => {}
+    }
+
+    db.account_data.update(
+        None,
+        &sender_user,
+        EventType::PushRules,
+        &event,
+        &db.globals,
+    )?;
 
     db.flush().await?;
 
     Ok(set_pushrule_enabled::Response.into())
 }
 
+#[cfg_attr(
+    feature = "conduit_bin",
+    delete("/_matrix/client/r0/pushrules/<_>/<_>/<_>", data = "<body>")
+)]
+pub async fn delete_pushrule_route(
+    db: State<'_, Database>,
+    body: Ruma<delete_pushrule::Request<'_>>,
+) -> ConduitResult<delete_pushrule::Response> {
+    let sender_user = body.sender_user.as_ref().expect("user is authenticated");
+
+    if body.scope != "global" {
+        return Err(Error::BadRequest(
+            ErrorKind::InvalidParam,
+            "Scopes other than 'global' are not supported.",
+        ));
+    }
+
+    let mut event = db
+        .account_data
+        .get::<ruma::events::push_rules::PushRulesEvent>(None, &sender_user, EventType::PushRules)?
+        .ok_or(Error::BadRequest(
+            ErrorKind::NotFound,
+            "PushRules event not found.",
+        ))?;
+
+    let global = &mut event.content.global;
+    match body.kind {
+        RuleKind::Override => {
+            if let Some(rule) = global
+                .override_
+                .iter()
+                .find(|rule| rule.0.rule_id == body.rule_id)
+                .cloned()
+            {
+                global.override_.remove(&rule);
+            }
+        }
+        RuleKind::Underride => {
+            if let Some(rule) = global
+                .underride
+                .iter()
+                .find(|rule| rule.0.rule_id == body.rule_id)
+                .cloned()
+            {
+                global.underride.remove(&rule);
+            }
+        }
+        RuleKind::Sender => {
+            if let Some(rule) = global
+                .sender
+                .iter()
+                .find(|rule| rule.0.rule_id == body.rule_id)
+                .cloned()
+            {
+                global.sender.remove(&rule);
+            }
+        }
+        RuleKind::Room => {
+            if let Some(rule) = global
+                .room
+                .iter()
+                .find(|rule| rule.0.rule_id == body.rule_id)
+                .cloned()
+            {
+                global.room.remove(&rule);
+            }
+        }
+        RuleKind::Content => {
+            if let Some(rule) = global
+                .content
+                .iter()
+                .find(|rule| rule.0.rule_id == body.rule_id)
+                .cloned()
+            {
+                global.content.remove(&rule);
+            }
+        }
+        RuleKind::_Custom(_) => {}
+    }
+
+    db.account_data.update(
+        None,
+        &sender_user,
+        EventType::PushRules,
+        &event,
+        &db.globals,
+    )?;
+
+    db.flush().await?;
+
+    Ok(delete_pushrule::Response.into())
+}
+
 #[cfg_attr(feature = "conduit_bin", get("/_matrix/client/r0/pushers"))]
 pub async fn get_pushers_route() -> ConduitResult<get_pushers::Response> {
     Ok(get_pushers::Response {
diff --git a/src/main.rs b/src/main.rs
index 9c0eab65..93ab5605 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -55,7 +55,12 @@ fn setup_rocket() -> rocket::Rocket {
                 client_server::get_capabilities_route,
                 client_server::get_pushrules_all_route,
                 client_server::set_pushrule_route,
+                client_server::get_pushrule_route,
                 client_server::set_pushrule_enabled_route,
+                client_server::get_pushrule_enabled_route,
+                client_server::get_pushrule_actions_route,
+                client_server::set_pushrule_actions_route,
+                client_server::delete_pushrule_route,
                 client_server::get_room_event_route,
                 client_server::get_filter_route,
                 client_server::create_filter_route,