feat: incoming invites over federation

This commit is contained in:
Timo Kösters 2021-04-11 21:01:27 +02:00
parent b0ea692706
commit 8773e5013d
No known key found for this signature in database
GPG key ID: 24DA7517711A2BA4
10 changed files with 307 additions and 146 deletions

36
Cargo.lock generated
View file

@ -1625,7 +1625,7 @@ dependencies = [
[[package]]
name = "ruma"
version = "0.0.2"
source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1"
source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4"
dependencies = [
"assign",
"js_int",
@ -1645,7 +1645,7 @@ dependencies = [
[[package]]
name = "ruma-api"
version = "0.17.0-alpha.2"
source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1"
source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4"
dependencies = [
"http",
"percent-encoding",
@ -1660,7 +1660,7 @@ dependencies = [
[[package]]
name = "ruma-api-macros"
version = "0.17.0-alpha.2"
source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1"
source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4"
dependencies = [
"proc-macro-crate",
"proc-macro2",
@ -1671,7 +1671,7 @@ dependencies = [
[[package]]
name = "ruma-appservice-api"
version = "0.2.0-alpha.2"
source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1"
source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4"
dependencies = [
"ruma-api",
"ruma-common",
@ -1685,7 +1685,7 @@ dependencies = [
[[package]]
name = "ruma-client-api"
version = "0.10.0-alpha.2"
source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1"
source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4"
dependencies = [
"assign",
"http",
@ -1704,7 +1704,7 @@ dependencies = [
[[package]]
name = "ruma-common"
version = "0.3.1"
source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1"
source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4"
dependencies = [
"indexmap",
"js_int",
@ -1720,7 +1720,7 @@ dependencies = [
[[package]]
name = "ruma-events"
version = "0.22.0-alpha.2"
source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1"
source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4"
dependencies = [
"js_int",
"ruma-common",
@ -1734,7 +1734,7 @@ dependencies = [
[[package]]
name = "ruma-events-macros"
version = "0.22.0-alpha.2"
source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1"
source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4"
dependencies = [
"proc-macro-crate",
"proc-macro2",
@ -1745,7 +1745,7 @@ dependencies = [
[[package]]
name = "ruma-federation-api"
version = "0.1.0-alpha.1"
source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1"
source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4"
dependencies = [
"js_int",
"ruma-api",
@ -1760,7 +1760,7 @@ dependencies = [
[[package]]
name = "ruma-identifiers"
version = "0.18.1"
source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1"
source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4"
dependencies = [
"paste",
"rand",
@ -1774,7 +1774,7 @@ dependencies = [
[[package]]
name = "ruma-identifiers-macros"
version = "0.18.1"
source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1"
source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4"
dependencies = [
"proc-macro2",
"quote",
@ -1785,12 +1785,12 @@ dependencies = [
[[package]]
name = "ruma-identifiers-validation"
version = "0.2.2"
source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1"
source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4"
[[package]]
name = "ruma-identity-service-api"
version = "0.0.1"
source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1"
source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4"
dependencies = [
"ruma-api",
"ruma-common",
@ -1803,7 +1803,7 @@ dependencies = [
[[package]]
name = "ruma-push-gateway-api"
version = "0.0.1"
source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1"
source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4"
dependencies = [
"js_int",
"ruma-api",
@ -1818,7 +1818,7 @@ dependencies = [
[[package]]
name = "ruma-serde"
version = "0.3.1"
source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1"
source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4"
dependencies = [
"form_urlencoded",
"itoa",
@ -1831,7 +1831,7 @@ dependencies = [
[[package]]
name = "ruma-serde-macros"
version = "0.3.1"
source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1"
source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4"
dependencies = [
"proc-macro-crate",
"proc-macro2",
@ -1842,7 +1842,7 @@ dependencies = [
[[package]]
name = "ruma-signatures"
version = "0.6.0"
source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1"
source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4"
dependencies = [
"base64 0.13.0",
"ring",
@ -2120,7 +2120,7 @@ checksum = "3015a7d0a5fd5105c91c3710d42f9ccf0abfb287d62206484dcc67f9569a6483"
[[package]]
name = "state-res"
version = "0.1.0"
source = "git+https://github.com/timokoesters/state-res?rev=1ec42ea2fc0b0728bf027a5899839ad94bb3091b#1ec42ea2fc0b0728bf027a5899839ad94bb3091b"
source = "git+https://github.com/timokoesters/state-res?rev=2e90b36babeb0d6b99ce8d4b513302a25dcdffc1#2e90b36babeb0d6b99ce8d4b513302a25dcdffc1"
dependencies = [
"itertools 0.10.0",
"log",

View file

@ -18,12 +18,12 @@ rocket = { git = "https://github.com/SergioBenitez/Rocket.git", rev = "93e62c86e
#rocket = { git = "https://github.com/timokoesters/Rocket.git", branch = "empty_parameters", default-features = false, features = ["tls"] }
# Used for matrix spec type definitions and helpers
ruma = { git = "https://github.com/ruma/ruma", rev = "a310ccc318a4eb51062923d570d5a86c1468e8a1", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "unstable-pre-spec", "unstable-exhaustive-types"] }
#ruma = { git = "https://github.com/DevinR528/ruma", features = ["rand", "client-api", "federation-api", "push-gateway-api", "unstable-exhaustive-types", "unstable-pre-spec", "unstable-synapse-quirks"], branch = "verified-export" }
#ruma = { path = "../ruma/ruma", features = ["unstable-exhaustive-types", "rand", "client-api", "federation-api", "push-gateway-api", "unstable-pre-spec", "unstable-synapse-quirks"] }
#ruma = { git = "https://github.com/ruma/ruma", rev = "a310ccc318a4eb51062923d570d5a86c1468e8a1", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "unstable-pre-spec", "unstable-exhaustive-types"] }
ruma = { git = "https://github.com/timokoesters/ruma", rev = "b11de1e1f9d3c15267d09617131cf217f8277fa4", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "unstable-pre-spec", "unstable-exhaustive-types"] }
#ruma = { path = "../ruma/ruma", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "unstable-pre-spec", "unstable-exhaustive-types"] }
# Used when doing state resolution
state-res = { git = "https://github.com/timokoesters/state-res", rev = "1ec42ea2fc0b0728bf027a5899839ad94bb3091b", features = ["unstable-pre-spec"] }
state-res = { git = "https://github.com/timokoesters/state-res", rev = "2e90b36babeb0d6b99ce8d4b513302a25dcdffc1", features = ["unstable-pre-spec"] }
#state-res = { path = "../state-res", features = ["unstable-pre-spec"] }
# Used for long polling and federation sender, should be the same as rocket::tokio

View file

@ -617,11 +617,11 @@ pub async fn deactivate_route(
}
// Leave all joined rooms and reject all invitations
for room_id in db
.rooms
.rooms_joined(&sender_user)
.chain(db.rooms.rooms_invited(&sender_user))
{
for room_id in db.rooms.rooms_joined(&sender_user).chain(
db.rooms
.rooms_invited(&sender_user)
.map(|t| t.map(|(r, _)| r)),
) {
let room_id = room_id?;
let event = member::MemberEventContent {
membership: member::MembershipState::Leave,

View file

@ -599,6 +599,8 @@ async fn join_room_by_id_helper(
Error::BadServerResponse("Invalid user id in send_join response.")
})?;
let invite_state = Vec::new(); // TODO add a few important events
// Update our membership info, we do this here incase a user is invited
// and immediately leaves we need the DB to record the invite event for auth
db.rooms.update_membership(
@ -616,6 +618,7 @@ async fn join_room_by_id_helper(
Error::BadServerResponse("Invalid membership state content.")
})?,
&pdu.sender,
Some(invite_state),
&db.account_data,
&db.globals,
)?;

View file

@ -588,44 +588,23 @@ pub async fn sync_events_route(
}
let mut invited_rooms = BTreeMap::new();
for room_id in db.rooms.rooms_invited(&sender_user) {
let room_id = room_id?;
let mut invited_since_last_sync = false;
for pdu in db.rooms.pdus_since(&sender_user, &room_id, since)? {
let (_, pdu) = pdu?;
if pdu.kind == EventType::RoomMember && pdu.state_key == Some(sender_user.to_string()) {
let content = serde_json::from_value::<
Raw<ruma::events::room::member::MemberEventContent>,
>(pdu.content.clone())
.expect("Raw::from_value always works")
.deserialize()
.map_err(|_| Error::bad_database("Invalid PDU in database."))?;
for result in db.rooms.rooms_invited(&sender_user) {
let (room_id, invite_state_events) = result?;
let invite_count = db.rooms.get_invite_count(&room_id, &sender_user)?;
if content.membership == MembershipState::Invite {
invited_since_last_sync = true;
break;
}
}
}
if !invited_since_last_sync {
// Invited before last sync
if Some(since) >= invite_count {
continue;
}
let invited_room = sync_events::InvitedRoom {
invited_rooms.insert(
room_id.clone(),
sync_events::InvitedRoom {
invite_state: sync_events::InviteState {
events: db
.rooms
.room_state_full(&room_id)?
.into_iter()
.map(|(_, pdu)| pdu.to_stripped_state_event())
.collect(),
events: invite_state_events,
},
};
if !invited_room.is_empty() {
invited_rooms.insert(room_id.clone(), invited_room);
}
},
);
}
for user_id in left_encrypted_users {

View file

@ -161,8 +161,8 @@ impl Database {
userroomid_joined: db.open_tree("userroomid_joined")?,
roomuserid_joined: db.open_tree("roomuserid_joined")?,
roomuseroncejoinedids: db.open_tree("roomuseroncejoinedids")?,
userroomid_invited: db.open_tree("userroomid_invited")?,
roomuserid_invited: db.open_tree("roomuserid_invited")?,
userroomid_invitestate: db.open_tree("userroomid_invitestate")?,
roomuserid_invitecount: db.open_tree("roomuserid_invitecount")?,
userroomid_left: db.open_tree("userroomid_left")?,
statekey_shortstatekey: db.open_tree("statekey_shortstatekey")?,
@ -236,7 +236,11 @@ impl Database {
);
futures.push(self.rooms.userroomid_joined.watch_prefix(&userid_prefix));
futures.push(self.rooms.userroomid_invited.watch_prefix(&userid_prefix));
futures.push(
self.rooms
.userroomid_invitestate
.watch_prefix(&userid_prefix),
);
futures.push(self.rooms.userroomid_left.watch_prefix(&userid_prefix));
// Events for rooms we are in

View file

@ -11,10 +11,10 @@ use ruma::{
events::{
ignored_user_list,
room::{create::CreateEventContent, member, message},
EventType,
AnyStrippedStateEvent, EventType,
},
serde::{to_canonical_value, CanonicalJsonObject, CanonicalJsonValue, Raw},
EventId, RoomAliasId, RoomId, RoomVersionId, ServerName, UserId,
uint, EventId, RoomAliasId, RoomId, RoomVersionId, ServerName, UserId,
};
use sled::IVec;
use state_res::{Event, StateMap};
@ -51,8 +51,8 @@ pub struct Rooms {
pub(super) userroomid_joined: sled::Tree,
pub(super) roomuserid_joined: sled::Tree,
pub(super) roomuseroncejoinedids: sled::Tree,
pub(super) userroomid_invited: sled::Tree,
pub(super) roomuserid_invited: sled::Tree,
pub(super) userroomid_invitestate: sled::Tree,
pub(super) roomuserid_invitecount: sled::Tree,
pub(super) userroomid_left: sled::Tree,
/// Remember the current state hash of a room.
@ -145,12 +145,12 @@ impl Rooms {
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
#[tracing::instrument(skip(self))]
pub fn state_get(
pub fn state_get_id(
&self,
shortstatehash: u64,
event_type: &EventType,
state_key: &str,
) -> Result<Option<PduEvent>> {
) -> Result<Option<EventId>> {
let mut key = event_type.as_ref().as_bytes().to_vec();
key.push(0xff);
key.extend_from_slice(&state_key.as_bytes());
@ -161,7 +161,8 @@ impl Rooms {
let mut stateid = shortstatehash.to_be_bytes().to_vec();
stateid.extend_from_slice(&shortstatekey);
self.stateid_shorteventid
Ok(self
.stateid_shorteventid
.get(&stateid)?
.map(|bytes| self.shorteventid_eventid.get(&bytes).ok().flatten())
.flatten()
@ -178,13 +179,24 @@ impl Rooms {
)
})
.map(|r| r.ok())
.flatten()
.map_or(Ok(None), |event_id| self.get_pdu(&event_id))
.flatten())
} else {
Ok(None)
}
}
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
#[tracing::instrument(skip(self))]
pub fn state_get(
&self,
shortstatehash: u64,
event_type: &EventType,
state_key: &str,
) -> Result<Option<PduEvent>> {
self.state_get_id(shortstatehash, event_type, state_key)?
.map_or(Ok(None), |event_id| self.get_pdu(&event_id))
}
/// Returns the state hash for this pdu.
#[tracing::instrument(skip(self))]
pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> {
@ -354,6 +366,21 @@ impl Rooms {
}
}
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
#[tracing::instrument(skip(self))]
pub fn room_state_get_id(
&self,
room_id: &RoomId,
event_type: &EventType,
state_key: &str,
) -> Result<Option<EventId>> {
if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? {
self.state_get_id(current_shortstatehash, event_type, state_key)
} else {
Ok(None)
}
}
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
#[tracing::instrument(skip(self))]
pub fn room_state_get(
@ -395,7 +422,7 @@ impl Rooms {
}
/// Returns the json of a pdu.
pub fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<serde_json::Value>> {
pub fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.eventid_pduid
.get(event_id.as_bytes())?
.map_or_else::<Result<_>, _, _>(
@ -666,12 +693,8 @@ impl Rooms {
// if the state_key fails
let target_user_id = UserId::try_from(state_key.clone())
.expect("This state_key was previously validated");
// Update our membership info, we do this here incase a user is invited
// and immediately leaves we need the DB to record the invite event for auth
self.update_membership(
&pdu.room_id,
&target_user_id,
serde_json::from_value::<member::MembershipState>(
let membership = serde_json::from_value::<member::MembershipState>(
pdu.content
.get("membership")
.ok_or_else(|| {
@ -687,8 +710,47 @@ impl Rooms {
ErrorKind::InvalidParam,
"Invalid membership state content.",
)
})?,
})?;
let invite_state = match membership {
member::MembershipState::Invite => {
let mut state = Vec::new();
// Add recommended events
if let Some(e) =
self.room_state_get(&pdu.room_id, &EventType::RoomJoinRules, "")?
{
state.push(e.to_stripped_state_event());
}
if let Some(e) = self.room_state_get(
&pdu.room_id,
&EventType::RoomCanonicalAlias,
"",
)? {
state.push(e.to_stripped_state_event());
}
if let Some(e) =
self.room_state_get(&pdu.room_id, &EventType::RoomAvatar, "")?
{
state.push(e.to_stripped_state_event());
}
if let Some(e) =
self.room_state_get(&pdu.room_id, &EventType::RoomName, "")?
{
state.push(e.to_stripped_state_event());
}
Some(state)
}
_ => None,
};
// Update our membership info, we do this here incase a user is invited
// and immediately leaves we need the DB to record the invite event for auth
self.update_membership(
&pdu.room_id,
&target_user_id,
membership,
&pdu.sender,
invite_state,
&db.account_data,
&db.globals,
)?;
@ -1044,10 +1106,10 @@ impl Rooms {
// Our depth is the maximum depth of prev_events + 1
let depth = prev_events
.iter()
.filter_map(|event_id| Some(self.get_pdu_json(event_id).ok()??.get("depth")?.as_u64()?))
.filter_map(|event_id| Some(self.get_pdu(event_id).ok()??.depth))
.max()
.unwrap_or(0_u64)
+ 1;
.unwrap_or(uint!(0))
+ uint!(1);
let mut unsigned = unsigned.unwrap_or_default();
if let Some(state_key) = &state_key {
@ -1071,9 +1133,7 @@ impl Rooms {
content,
state_key,
prev_events,
depth: depth
.try_into()
.map_err(|_| Error::bad_database("Depth is invalid"))?,
depth,
auth_events: auth_events
.iter()
.map(|(_, pdu)| pdu.event_id.clone())
@ -1384,6 +1444,7 @@ impl Rooms {
user_id: &UserId,
membership: member::MembershipState,
sender: &UserId,
invite_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
account_data: &super::account_data::AccountData,
globals: &super::globals::Globals,
) -> Result<()> {
@ -1487,8 +1548,8 @@ impl Rooms {
self.roomserverids.insert(&roomserver_id, &[])?;
self.userroomid_joined.insert(&userroom_id, &[])?;
self.roomuserid_joined.insert(&roomuser_id, &[])?;
self.userroomid_invited.remove(&userroom_id)?;
self.roomuserid_invited.remove(&roomuser_id)?;
self.userroomid_invitestate.remove(&userroom_id)?;
self.roomuserid_invitecount.remove(&roomuser_id)?;
self.userroomid_left.remove(&userroom_id)?;
}
member::MembershipState::Invite => {
@ -1508,8 +1569,13 @@ impl Rooms {
}
self.roomserverids.insert(&roomserver_id, &[])?;
self.userroomid_invited.insert(&userroom_id, &[])?;
self.roomuserid_invited.insert(&roomuser_id, &[])?;
self.userroomid_invitestate.insert(
&userroom_id,
serde_json::to_vec(&invite_state.unwrap_or_default())
.expect("state to bytes always works"),
)?;
self.roomuserid_invitecount
.insert(&roomuser_id, &globals.next_count()?.to_be_bytes())?;
self.userroomid_joined.remove(&userroom_id)?;
self.roomuserid_joined.remove(&roomuser_id)?;
self.userroomid_left.remove(&userroom_id)?;
@ -1526,8 +1592,8 @@ impl Rooms {
self.userroomid_left.insert(&userroom_id, &[])?;
self.userroomid_joined.remove(&userroom_id)?;
self.roomuserid_joined.remove(&roomuser_id)?;
self.userroomid_invited.remove(&userroom_id)?;
self.roomuserid_invited.remove(&roomuser_id)?;
self.userroomid_invitestate.remove(&userroom_id)?;
self.roomuserid_invitecount.remove(&roomuser_id)?;
}
_ => {}
}
@ -1797,7 +1863,7 @@ impl Rooms {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff);
self.roomuserid_invited
self.roomuserid_invitecount
.scan_prefix(prefix)
.keys()
.map(|key| {
@ -1816,6 +1882,22 @@ impl Rooms {
})
}
/// Returns an iterator over all invited members of a room.
#[tracing::instrument(skip(self))]
pub fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xff);
key.extend_from_slice(user_id.as_bytes());
self.roomuserid_invitecount
.get(key)?
.map_or(Ok(None), |bytes| {
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid invitecount in db.")
})?))
})
}
/// Returns an iterator over all rooms this user joined.
#[tracing::instrument(skip(self))]
pub fn rooms_joined(&self, user_id: &UserId) -> impl Iterator<Item = Result<RoomId>> {
@ -1840,18 +1922,18 @@ impl Rooms {
/// Returns an iterator over all rooms a user was invited to.
#[tracing::instrument(skip(self))]
pub fn rooms_invited(&self, user_id: &UserId) -> impl Iterator<Item = Result<RoomId>> {
pub fn rooms_invited(
&self,
user_id: &UserId,
) -> impl Iterator<Item = Result<(RoomId, Vec<Raw<AnyStrippedStateEvent>>)>> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff);
self.userroomid_invited
.scan_prefix(prefix)
.keys()
.map(|key| {
Ok(RoomId::try_from(
self.userroomid_invitestate.scan_prefix(prefix).map(|r| {
let (key, state) = r?;
let room_id = RoomId::try_from(
utils::string_from_bytes(
&key?
.rsplit(|&b| b == 0xff)
&key.rsplit(|&b| b == 0xff)
.next()
.expect("rsplit always returns an element"),
)
@ -1859,7 +1941,12 @@ impl Rooms {
Error::bad_database("Room ID in userroomid_invited is invalid unicode.")
})?,
)
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?)
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?;
let state = serde_json::from_slice(&state)
.map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?;
Ok((room_id, state))
})
}
@ -1906,7 +1993,7 @@ impl Rooms {
userroom_id.push(0xff);
userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.userroomid_invited.get(userroom_id)?.is_some())
Ok(self.userroomid_invitestate.get(userroom_id)?.is_some())
}
pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {

View file

@ -167,6 +167,7 @@ fn setup_rocket() -> (rocket::Rocket, Config) {
server_server::get_event_route,
server_server::get_missing_events_route,
server_server::get_room_state_ids_route,
server_server::create_invite_route,
server_server::get_profile_information_route,
],
)

View file

@ -10,20 +10,24 @@ use ruma::{
federation::{
directory::{get_public_rooms, get_public_rooms_filtered},
discovery::{
get_remote_server_keys, get_server_keys,
get_server_version::v1 as get_server_version, ServerSigningKeys, VerifyKey,
get_remote_server_keys, get_server_keys, get_server_version, ServerSigningKeys,
VerifyKey,
},
event::{get_event, get_missing_events, get_room_state_ids},
membership::create_invite,
query::get_profile_information,
transactions::send_transaction_message,
},
OutgoingRequest,
},
directory::{IncomingFilter, IncomingRoomNetwork},
events::{room::create::CreateEventContent, EventType},
events::{
room::{create::CreateEventContent, member::MembershipState},
EventType,
},
serde::{to_canonical_value, Raw},
signatures::CanonicalJsonValue,
EventId, RoomId, ServerName, ServerSigningKeyId, UserId,
EventId, RoomId, RoomVersionId, ServerName, ServerSigningKeyId, UserId,
};
use state_res::{Event, EventMap, StateMap};
use std::{
@ -332,13 +336,13 @@ pub async fn request_well_known(
#[tracing::instrument(skip(db))]
pub fn get_server_version_route(
db: State<'_, Database>,
) -> ConduitResult<get_server_version::Response> {
) -> ConduitResult<get_server_version::v1::Response> {
if !db.globals.allow_federation() {
return Err(Error::bad_config("Federation is disabled."));
}
Ok(get_server_version::Response {
server: Some(get_server_version::Server {
Ok(get_server_version::v1::Response {
server: Some(get_server_version::v1::Server {
name: Some("Conduit".to_owned()),
version: Some(env!("CARGO_PKG_VERSION").to_owned()),
}),
@ -1406,12 +1410,9 @@ pub fn get_event_route<'a>(
origin: db.globals.server_name().to_owned(),
origin_server_ts: SystemTime::now(),
pdu: PduEvent::convert_to_outgoing_federation_event(
serde_json::from_value(
db.rooms
.get_pdu_json(&body.event_id)?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Event not found."))?,
)
.map_err(|_| Error::bad_database("Invalid pdu in database."))?,
),
}
.into())
@ -1438,9 +1439,10 @@ pub fn get_missing_events_route<'a>(
if let Some(pdu) = db.rooms.get_pdu_json(&queued_events[i])? {
if body.earliest_events.contains(
&serde_json::from_value(
pdu.get("event_id")
.cloned()
.ok_or_else(|| Error::bad_database("Event in db has no event_id field."))?,
serde_json::to_value(pdu.get("event_id").cloned().ok_or_else(|| {
Error::bad_database("Event in db has no event_id field.")
})?)
.expect("canonical json is valid json value"),
)
.map_err(|_| Error::bad_database("Invalid event_id field in pdu in db."))?,
) {
@ -1449,16 +1451,14 @@ pub fn get_missing_events_route<'a>(
}
queued_events.extend_from_slice(
&serde_json::from_value::<Vec<EventId>>(
pdu.get("prev_events").cloned().ok_or_else(|| {
Error::bad_database("Invalid prev_events field of pdu in db.")
})?,
serde_json::to_value(pdu.get("prev_events").cloned().ok_or_else(|| {
Error::bad_database("Event in db has no prev_events field.")
})?)
.expect("canonical json is valid json value"),
)
.map_err(|_| Error::bad_database("Invalid prev_events content in pdu in db."))?,
);
events.push(PduEvent::convert_to_outgoing_federation_event(
serde_json::from_value(pdu)
.map_err(|_| Error::bad_database("Invalid pdu in database."))?,
));
events.push(PduEvent::convert_to_outgoing_federation_event(pdu));
}
i += 1;
}
@ -1518,6 +1518,93 @@ pub fn get_room_state_ids_route<'a>(
.into())
}
#[cfg_attr(
feature = "conduit_bin",
put("/_matrix/federation/v2/invite/<_>/<_>", data = "<body>")
)]
#[tracing::instrument(skip(db, body))]
pub fn create_invite_route<'a>(
db: State<'a, Database>,
body: Ruma<create_invite::v2::Request>,
) -> ConduitResult<create_invite::v2::Response> {
if body.room_version < RoomVersionId::Version6 {
return Err(Error::BadRequest(
ErrorKind::IncompatibleRoomVersion {
room_version: body.room_version.clone(),
},
"Server does not support this room version.",
));
}
let mut signed_event = utils::to_canonical_object(&body.event)
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invite event is invalid."))?;
ruma::signatures::hash_and_sign_event(
db.globals.server_name().as_str(),
db.globals.keypair(),
&mut signed_event,
&body.room_version,
)
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Failed to sign event."))?;
let sender = serde_json::from_value(
serde_json::to_value(
signed_event
.get("sender")
.ok_or_else(|| {
Error::BadRequest(ErrorKind::InvalidParam, "Event had no sender field.")
})?
.clone(),
)
.expect("CanonicalJsonValue to serde_json::Value always works"),
)
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "sender is not a user id."))?;
let invited_user = serde_json::from_value(
serde_json::to_value(
signed_event
.get("state_key")
.ok_or_else(|| {
Error::BadRequest(ErrorKind::InvalidParam, "Event had no state_key field.")
})?
.clone(),
)
.expect("CanonicalJsonValue to serde_json::Value always works"),
)
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "state_key is not a user id."))?;
let mut invite_state = body.invite_room_state.clone();
let mut event = serde_json::from_str::<serde_json::Map<String, serde_json::Value>>(
&body.event.json().to_string(),
)
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid invite event bytes."))?;
event.insert("event_id".to_owned(), "$dummy".into());
invite_state.push(
serde_json::from_value::<PduEvent>(event.into())
.map_err(|e| {
warn!("Invalid invite event: {}", e);
Error::BadRequest(ErrorKind::InvalidParam, "Invalid invite event.")
})?
.to_stripped_state_event(),
);
db.rooms.update_membership(
&body.room_id,
&invited_user,
MembershipState::Invite,
&sender,
Some(invite_state),
&db.account_data,
&db.globals,
)?;
Ok(create_invite::v2::Response {
event: PduEvent::convert_to_outgoing_federation_event(signed_event),
}
.into())
}
#[cfg_attr(
feature = "conduit_bin",
get("/_matrix/federation/v1/query/profile", data = "<body>")