From 82e7f57b389d011bc8d80f9142f723b3cd1e1ad2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Sun, 14 Aug 2022 13:38:21 +0200 Subject: [PATCH] refactor state accessor, state cache, user, uiaa --- .../key_value/rooms/state_accessor.rs | 27 +-- src/database/key_value/rooms/state_cache.rs | 8 + src/database/key_value/rooms/user.rs | 20 +- src/database/key_value/uiaa.rs | 149 +------------ src/service/rooms/state_accessor/data.rs | 141 ++---------- src/service/rooms/state_accessor/mod.rs | 115 ++-------- src/service/rooms/state_cache/data.rs | 3 + src/service/rooms/state_cache/mod.rs | 65 +----- src/service/rooms/user/data.rs | 113 +--------- src/service/rooms/user/mod.rs | 104 ++------- src/service/uiaa/data.rs | 208 +----------------- src/service/uiaa/mod.rs | 96 +------- 12 files changed, 116 insertions(+), 933 deletions(-) create mode 100644 src/database/key_value/rooms/state_cache.rs create mode 100644 src/service/rooms/state_cache/data.rs diff --git a/src/database/key_value/rooms/state_accessor.rs b/src/database/key_value/rooms/state_accessor.rs index ae26a7c4..db81967d 100644 --- a/src/database/key_value/rooms/state_accessor.rs +++ b/src/database/key_value/rooms/state_accessor.rs @@ -1,7 +1,5 @@ - /// Builds a StateMap by iterating over all keys that start - /// with state_hash, this gives the full state for the given state_hash. - #[tracing::instrument(skip(self))] - pub async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { +impl service::room::state_accessor::Data for KeyValueDatabase { + async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { let full_state = self .load_shortstatehash_info(shortstatehash)? .pop() @@ -21,8 +19,7 @@ Ok(result) } - #[tracing::instrument(skip(self))] - pub async fn state_full( + async fn state_full( &self, shortstatehash: u64, ) -> Result>> { @@ -59,8 +56,7 @@ } /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). - #[tracing::instrument(skip(self))] - pub fn state_get_id( + fn state_get_id( &self, shortstatehash: u64, event_type: &StateEventType, @@ -86,8 +82,7 @@ } /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). - #[tracing::instrument(skip(self))] - pub fn state_get( + fn state_get( &self, shortstatehash: u64, event_type: &StateEventType, @@ -98,7 +93,7 @@ } /// Returns the state hash for this pdu. - pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { + fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { self.eventid_shorteventid .get(event_id.as_bytes())? .map_or(Ok(None), |shorteventid| { @@ -116,8 +111,7 @@ } /// Returns the full room state. - #[tracing::instrument(skip(self))] - pub async fn room_state_full( + async fn room_state_full( &self, room_id: &RoomId, ) -> Result>> { @@ -129,8 +123,7 @@ } /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). - #[tracing::instrument(skip(self))] - pub fn room_state_get_id( + fn room_state_get_id( &self, room_id: &RoomId, event_type: &StateEventType, @@ -144,8 +137,7 @@ } /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). - #[tracing::instrument(skip(self))] - pub fn room_state_get( + fn room_state_get( &self, room_id: &RoomId, event_type: &StateEventType, @@ -157,4 +149,3 @@ Ok(None) } } - diff --git a/src/database/key_value/rooms/state_cache.rs b/src/database/key_value/rooms/state_cache.rs new file mode 100644 index 00000000..37814020 --- /dev/null +++ b/src/database/key_value/rooms/state_cache.rs @@ -0,0 +1,8 @@ +impl service::room::state_cache::Data for KeyValueDatabase { + fn mark_as_once_joined(user_id: &UserId, room_id: &RoomId) -> Result<()> { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xff); + userroom_id.extend_from_slice(room_id.as_bytes()); + self.roomuseroncejoinedids.insert(&userroom_id, &[])?; + } +} diff --git a/src/database/key_value/rooms/user.rs b/src/database/key_value/rooms/user.rs index 976ab5b3..52145ced 100644 --- a/src/database/key_value/rooms/user.rs +++ b/src/database/key_value/rooms/user.rs @@ -1,6 +1,5 @@ - - #[tracing::instrument(skip(self))] - pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { +impl service::room::user::Data for KeyValueDatabase { + fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xff); userroom_id.extend_from_slice(room_id.as_bytes()); @@ -13,8 +12,7 @@ Ok(()) } - #[tracing::instrument(skip(self))] - pub fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { + fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xff); userroom_id.extend_from_slice(room_id.as_bytes()); @@ -28,8 +26,7 @@ .unwrap_or(Ok(0)) } - #[tracing::instrument(skip(self))] - pub fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { + fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xff); userroom_id.extend_from_slice(room_id.as_bytes()); @@ -43,7 +40,7 @@ .unwrap_or(Ok(0)) } - pub fn associate_token_shortstatehash( + fn associate_token_shortstatehash( &self, room_id: &RoomId, token: u64, @@ -58,7 +55,7 @@ .insert(&key, &shortstatehash.to_be_bytes()) } - pub fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result> { + fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result> { let shortroomid = self.get_shortroomid(room_id)?.expect("room exists"); let mut key = shortroomid.to_be_bytes().to_vec(); @@ -74,8 +71,7 @@ .transpose() } - #[tracing::instrument(skip(self))] - pub fn get_shared_rooms<'a>( + fn get_shared_rooms<'a>( &'a self, users: Vec>, ) -> Result>> + 'a> { @@ -111,4 +107,4 @@ .map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined.")) })) } - +} diff --git a/src/database/key_value/uiaa.rs b/src/database/key_value/uiaa.rs index 12373139..4d1dac57 100644 --- a/src/database/key_value/uiaa.rs +++ b/src/database/key_value/uiaa.rs @@ -1,149 +1,4 @@ -use std::{ - collections::BTreeMap, - sync::{Arc, RwLock}, -}; - -use crate::{client_server::SESSION_ID_LENGTH, utils, Error, Result}; -use ruma::{ - api::client::{ - error::ErrorKind, - uiaa::{ - AuthType, IncomingAuthData, IncomingPassword, - IncomingUserIdentifier::UserIdOrLocalpart, UiaaInfo, - }, - }, - signatures::CanonicalJsonValue, - DeviceId, UserId, -}; -use tracing::error; - -use super::abstraction::Tree; - -pub struct Uiaa { - pub(super) userdevicesessionid_uiaainfo: Arc, // User-interactive authentication - pub(super) userdevicesessionid_uiaarequest: - RwLock, Box, String), CanonicalJsonValue>>, -} - -impl Uiaa { - /// Creates a new Uiaa session. Make sure the session token is unique. - pub fn create( - &self, - user_id: &UserId, - device_id: &DeviceId, - uiaainfo: &UiaaInfo, - json_body: &CanonicalJsonValue, - ) -> Result<()> { - self.set_uiaa_request( - user_id, - device_id, - uiaainfo.session.as_ref().expect("session should be set"), // TODO: better session error handling (why is it optional in ruma?) - json_body, - )?; - self.update_uiaa_session( - user_id, - device_id, - uiaainfo.session.as_ref().expect("session should be set"), - Some(uiaainfo), - ) - } - - pub fn try_auth( - &self, - user_id: &UserId, - device_id: &DeviceId, - auth: &IncomingAuthData, - uiaainfo: &UiaaInfo, - users: &super::users::Users, - globals: &super::globals::Globals, - ) -> Result<(bool, UiaaInfo)> { - let mut uiaainfo = auth - .session() - .map(|session| self.get_uiaa_session(user_id, device_id, session)) - .unwrap_or_else(|| Ok(uiaainfo.clone()))?; - - if uiaainfo.session.is_none() { - uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - } - - match auth { - // Find out what the user completed - IncomingAuthData::Password(IncomingPassword { - identifier, - password, - .. - }) => { - let username = match identifier { - UserIdOrLocalpart(username) => username, - _ => { - return Err(Error::BadRequest( - ErrorKind::Unrecognized, - "Identifier type not recognized.", - )) - } - }; - - let user_id = - UserId::parse_with_server_name(username.clone(), globals.server_name()) - .map_err(|_| { - Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid.") - })?; - - // Check if password is correct - if let Some(hash) = users.password_hash(&user_id)? { - let hash_matches = - argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false); - - if !hash_matches { - uiaainfo.auth_error = Some(ruma::api::client::error::ErrorBody { - kind: ErrorKind::Forbidden, - message: "Invalid username or password.".to_owned(), - }); - return Ok((false, uiaainfo)); - } - } - - // Password was correct! Let's add it to `completed` - uiaainfo.completed.push(AuthType::Password); - } - IncomingAuthData::Dummy(_) => { - uiaainfo.completed.push(AuthType::Dummy); - } - k => error!("type not supported: {:?}", k), - } - - // Check if a flow now succeeds - let mut completed = false; - 'flows: for flow in &mut uiaainfo.flows { - for stage in &flow.stages { - if !uiaainfo.completed.contains(stage) { - continue 'flows; - } - } - // We didn't break, so this flow succeeded! - completed = true; - } - - if !completed { - self.update_uiaa_session( - user_id, - device_id, - uiaainfo.session.as_ref().expect("session is always set"), - Some(&uiaainfo), - )?; - return Ok((false, uiaainfo)); - } - - // UIAA was successful! Remove this session and return true - self.update_uiaa_session( - user_id, - device_id, - uiaainfo.session.as_ref().expect("session is always set"), - None, - )?; - Ok((true, uiaainfo)) - } - +impl service::uiaa::Data for KeyValueDatabase { fn set_uiaa_request( &self, user_id: &UserId, @@ -162,7 +17,7 @@ impl Uiaa { Ok(()) } - pub fn get_uiaa_request( + fn get_uiaa_request( &self, user_id: &UserId, device_id: &DeviceId, diff --git a/src/service/rooms/state_accessor/data.rs b/src/service/rooms/state_accessor/data.rs index ae26a7c4..a2b76e46 100644 --- a/src/service/rooms/state_accessor/data.rs +++ b/src/service/rooms/state_accessor/data.rs @@ -1,160 +1,51 @@ +pub trait Data { /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. - #[tracing::instrument(skip(self))] - pub async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { - let full_state = self - .load_shortstatehash_info(shortstatehash)? - .pop() - .expect("there is always one layer") - .1; - let mut result = BTreeMap::new(); - let mut i = 0; - for compressed in full_state.into_iter() { - let parsed = self.parse_compressed_state_event(compressed)?; - result.insert(parsed.0, parsed.1); + async fn state_full_ids(&self, shortstatehash: u64) -> Result>>; - i += 1; - if i % 100 == 0 { - tokio::task::yield_now().await; - } - } - Ok(result) - } - - #[tracing::instrument(skip(self))] - pub async fn state_full( + async fn state_full( &self, shortstatehash: u64, - ) -> Result>> { - let full_state = self - .load_shortstatehash_info(shortstatehash)? - .pop() - .expect("there is always one layer") - .1; - - let mut result = HashMap::new(); - let mut i = 0; - for compressed in full_state { - let (_, eventid) = self.parse_compressed_state_event(compressed)?; - if let Some(pdu) = self.get_pdu(&eventid)? { - result.insert( - ( - pdu.kind.to_string().into(), - pdu.state_key - .as_ref() - .ok_or_else(|| Error::bad_database("State event has no state key."))? - .clone(), - ), - pdu, - ); - } - - i += 1; - if i % 100 == 0 { - tokio::task::yield_now().await; - } - } - - Ok(result) - } + ) -> Result>>; /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). - #[tracing::instrument(skip(self))] - pub fn state_get_id( + fn state_get_id( &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - let shortstatekey = match self.get_shortstatekey(event_type, state_key)? { - Some(s) => s, - None => return Ok(None), - }; - let full_state = self - .load_shortstatehash_info(shortstatehash)? - .pop() - .expect("there is always one layer") - .1; - Ok(full_state - .into_iter() - .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) - .and_then(|compressed| { - self.parse_compressed_state_event(compressed) - .ok() - .map(|(_, id)| id) - })) - } + ) -> Result>>; /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). - #[tracing::instrument(skip(self))] - pub fn state_get( + fn state_get( &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - self.state_get_id(shortstatehash, event_type, state_key)? - .map_or(Ok(None), |event_id| self.get_pdu(&event_id)) - } + ) -> Result>>; /// Returns the state hash for this pdu. - pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { - self.eventid_shorteventid - .get(event_id.as_bytes())? - .map_or(Ok(None), |shorteventid| { - self.shorteventid_shortstatehash - .get(&shorteventid)? - .map(|bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database( - "Invalid shortstatehash bytes in shorteventid_shortstatehash", - ) - }) - }) - .transpose() - }) - } + fn pdu_shortstatehash(&self, event_id: &EventId) -> Result>; /// Returns the full room state. - #[tracing::instrument(skip(self))] - pub async fn room_state_full( + async fn room_state_full( &self, room_id: &RoomId, - ) -> Result>> { - if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { - self.state_full(current_shortstatehash).await - } else { - Ok(HashMap::new()) - } - } + ) -> Result>>; /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). - #[tracing::instrument(skip(self))] - pub fn room_state_get_id( + fn room_state_get_id( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { - self.state_get_id(current_shortstatehash, event_type, state_key) - } else { - Ok(None) - } - } + ) -> Result>>; /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). - #[tracing::instrument(skip(self))] - pub fn room_state_get( + fn room_state_get( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, - ) -> Result>> { - if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { - self.state_get(current_shortstatehash, event_type, state_key) - } else { - Ok(None) - } - } - + ) -> Result>>; +} diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index ae26a7c4..28a49a98 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -1,24 +1,18 @@ +mod data; +pub use data::Data; + +use crate::service::*; + +pub struct Service { + db: D, +} + +impl Service<_> { /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. #[tracing::instrument(skip(self))] pub async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { - let full_state = self - .load_shortstatehash_info(shortstatehash)? - .pop() - .expect("there is always one layer") - .1; - let mut result = BTreeMap::new(); - let mut i = 0; - for compressed in full_state.into_iter() { - let parsed = self.parse_compressed_state_event(compressed)?; - result.insert(parsed.0, parsed.1); - - i += 1; - if i % 100 == 0 { - tokio::task::yield_now().await; - } - } - Ok(result) + self.db.state_full_ids(shortstatehash) } #[tracing::instrument(skip(self))] @@ -26,36 +20,7 @@ &self, shortstatehash: u64, ) -> Result>> { - let full_state = self - .load_shortstatehash_info(shortstatehash)? - .pop() - .expect("there is always one layer") - .1; - - let mut result = HashMap::new(); - let mut i = 0; - for compressed in full_state { - let (_, eventid) = self.parse_compressed_state_event(compressed)?; - if let Some(pdu) = self.get_pdu(&eventid)? { - result.insert( - ( - pdu.kind.to_string().into(), - pdu.state_key - .as_ref() - .ok_or_else(|| Error::bad_database("State event has no state key."))? - .clone(), - ), - pdu, - ); - } - - i += 1; - if i % 100 == 0 { - tokio::task::yield_now().await; - } - } - - Ok(result) + self.db.state_full(shortstatehash) } /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). @@ -66,23 +31,7 @@ event_type: &StateEventType, state_key: &str, ) -> Result>> { - let shortstatekey = match self.get_shortstatekey(event_type, state_key)? { - Some(s) => s, - None => return Ok(None), - }; - let full_state = self - .load_shortstatehash_info(shortstatehash)? - .pop() - .expect("there is always one layer") - .1; - Ok(full_state - .into_iter() - .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) - .and_then(|compressed| { - self.parse_compressed_state_event(compressed) - .ok() - .map(|(_, id)| id) - })) + self.db.state_get_id(shortstatehash, event_type, state_key) } /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). @@ -93,26 +42,12 @@ event_type: &StateEventType, state_key: &str, ) -> Result>> { - self.state_get_id(shortstatehash, event_type, state_key)? - .map_or(Ok(None), |event_id| self.get_pdu(&event_id)) + self.db.pdu_state_get(event_id) } /// Returns the state hash for this pdu. pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { - self.eventid_shorteventid - .get(event_id.as_bytes())? - .map_or(Ok(None), |shorteventid| { - self.shorteventid_shortstatehash - .get(&shorteventid)? - .map(|bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database( - "Invalid shortstatehash bytes in shorteventid_shortstatehash", - ) - }) - }) - .transpose() - }) + self.db.pdu_shortstatehash(event_id) } /// Returns the full room state. @@ -121,11 +56,7 @@ &self, room_id: &RoomId, ) -> Result>> { - if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { - self.state_full(current_shortstatehash).await - } else { - Ok(HashMap::new()) - } + self.db.room_state_full(room_id) } /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). @@ -136,11 +67,7 @@ event_type: &StateEventType, state_key: &str, ) -> Result>> { - if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { - self.state_get_id(current_shortstatehash, event_type, state_key) - } else { - Ok(None) - } + self.db.room_state_get_id(room_id, event_type, state_key) } /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). @@ -151,10 +78,6 @@ event_type: &StateEventType, state_key: &str, ) -> Result>> { - if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { - self.state_get(current_shortstatehash, event_type, state_key) - } else { - Ok(None) - } + self.db.room_state_get(room_id, event_type, state_key) } - +} diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs new file mode 100644 index 00000000..166d4f6b --- /dev/null +++ b/src/service/rooms/state_cache/data.rs @@ -0,0 +1,3 @@ +pub trait Data { + fn mark_as_once_joined(user_id: &UserId, room_id: &RoomId) -> Result<()>; +} diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index e7f457e6..778679de 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -1,4 +1,13 @@ +mod data; +pub use data::Data; +use crate::service::*; + +pub struct Service { + db: D, +} + +impl Service<_> { /// Update current membership data. #[tracing::instrument(skip(self, last_state, db))] pub fn update_membership( @@ -25,10 +34,6 @@ serverroom_id.push(0xff); serverroom_id.extend_from_slice(room_id.as_bytes()); - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - let mut roomuser_id = room_id.as_bytes().to_vec(); roomuser_id.push(0xff); roomuser_id.extend_from_slice(user_id.as_bytes()); @@ -38,7 +43,7 @@ // Check if the user never joined this room if !self.once_joined(user_id, room_id)? { // Add the user ID to the join list then - self.roomuseroncejoinedids.insert(&userroom_id, &[])?; + self.db.mark_as_once_joined(user_id, room_id)?; // Check if the room has a predecessor if let Some(predecessor) = self @@ -116,10 +121,6 @@ } } - if update_joined_count { - self.roomserverids.insert(&roomserver_id, &[])?; - self.serverroomids.insert(&serverroom_id, &[])?; - } self.userroomid_joined.insert(&userroom_id, &[])?; self.roomuserid_joined.insert(&roomuser_id, &[])?; self.userroomid_invitestate.remove(&userroom_id)?; @@ -150,10 +151,6 @@ return Ok(()); } - if update_joined_count { - self.roomserverids.insert(&roomserver_id, &[])?; - self.serverroomids.insert(&serverroom_id, &[])?; - } self.userroomid_invitestate.insert( &userroom_id, &serde_json::to_vec(&last_state.unwrap_or_default()) @@ -167,16 +164,6 @@ self.roomuserid_leftcount.remove(&roomuser_id)?; } MembershipState::Leave | MembershipState::Ban => { - if update_joined_count - && self - .room_members(room_id) - .chain(self.room_members_invited(room_id)) - .filter_map(|r| r.ok()) - .all(|u| u.server_name() != user_id.server_name()) - { - self.roomserverids.remove(&roomserver_id)?; - self.serverroomids.remove(&serverroom_id)?; - } self.userroomid_leftstate.insert( &userroom_id, &serde_json::to_vec(&Vec::>::new()).unwrap(), @@ -231,36 +218,6 @@ .unwrap() .insert(room_id.to_owned(), Arc::new(real_users)); - for old_joined_server in self.room_servers(room_id).filter_map(|r| r.ok()) { - if !joined_servers.remove(&old_joined_server) { - // Server not in room anymore - let mut roomserver_id = room_id.as_bytes().to_vec(); - roomserver_id.push(0xff); - roomserver_id.extend_from_slice(old_joined_server.as_bytes()); - - let mut serverroom_id = old_joined_server.as_bytes().to_vec(); - serverroom_id.push(0xff); - serverroom_id.extend_from_slice(room_id.as_bytes()); - - self.roomserverids.remove(&roomserver_id)?; - self.serverroomids.remove(&serverroom_id)?; - } - } - - // Now only new servers are in joined_servers anymore - for server in joined_servers { - let mut roomserver_id = room_id.as_bytes().to_vec(); - roomserver_id.push(0xff); - roomserver_id.extend_from_slice(server.as_bytes()); - - let mut serverroom_id = server.as_bytes().to_vec(); - serverroom_id.push(0xff); - serverroom_id.extend_from_slice(room_id.as_bytes()); - - self.roomserverids.insert(&roomserver_id, &[])?; - self.serverroomids.insert(&serverroom_id, &[])?; - } - self.appservice_in_room_cache .write() .unwrap() @@ -714,4 +671,4 @@ Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some()) } - +} diff --git a/src/service/rooms/user/data.rs b/src/service/rooms/user/data.rs index 976ab5b3..47a44eef 100644 --- a/src/service/rooms/user/data.rs +++ b/src/service/rooms/user/data.rs @@ -1,114 +1,21 @@ +pub trait Data { + fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; - #[tracing::instrument(skip(self))] - pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); + fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result; - self.userroomid_notificationcount - .insert(&userroom_id, &0_u64.to_be_bytes())?; - self.userroomid_highlightcount - .insert(&userroom_id, &0_u64.to_be_bytes())?; + fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result; - Ok(()) - } - - #[tracing::instrument(skip(self))] - pub fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - - self.userroomid_notificationcount - .get(&userroom_id)? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid notification count in db.")) - }) - .unwrap_or(Ok(0)) - } - - #[tracing::instrument(skip(self))] - pub fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - - self.userroomid_highlightcount - .get(&userroom_id)? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid highlight count in db.")) - }) - .unwrap_or(Ok(0)) - } - - pub fn associate_token_shortstatehash( + fn associate_token_shortstatehash( &self, room_id: &RoomId, token: u64, shortstatehash: u64, - ) -> Result<()> { - let shortroomid = self.get_shortroomid(room_id)?.expect("room exists"); + ) -> Result<()>; - let mut key = shortroomid.to_be_bytes().to_vec(); - key.extend_from_slice(&token.to_be_bytes()); + fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result>; - self.roomsynctoken_shortstatehash - .insert(&key, &shortstatehash.to_be_bytes()) - } - - pub fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result> { - let shortroomid = self.get_shortroomid(room_id)?.expect("room exists"); - - let mut key = shortroomid.to_be_bytes().to_vec(); - key.extend_from_slice(&token.to_be_bytes()); - - self.roomsynctoken_shortstatehash - .get(&key)? - .map(|bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash") - }) - }) - .transpose() - } - - #[tracing::instrument(skip(self))] - pub fn get_shared_rooms<'a>( + fn get_shared_rooms<'a>( &'a self, users: Vec>, - ) -> Result>> + 'a> { - let iterators = users.into_iter().map(move |user_id| { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - - self.userroomid_joined - .scan_prefix(prefix) - .map(|(key, _)| { - let roomid_index = key - .iter() - .enumerate() - .find(|(_, &b)| b == 0xff) - .ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))? - .0 - + 1; // +1 because the room id starts AFTER the separator - - let room_id = key[roomid_index..].to_vec(); - - Ok::<_, Error>(room_id) - }) - .filter_map(|r| r.ok()) - }); - - // We use the default compare function because keys are sorted correctly (not reversed) - Ok(utils::common_elements(iterators, Ord::cmp) - .expect("users is not empty") - .map(|bytes| { - RoomId::parse(utils::string_from_bytes(&*bytes).map_err(|_| { - Error::bad_database("Invalid RoomId bytes in userroomid_joined") - })?) - .map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined.")) - })) - } - + ) -> Result>> + 'a>; +} diff --git a/src/service/rooms/user/mod.rs b/src/service/rooms/user/mod.rs index 976ab5b3..45fb3551 100644 --- a/src/service/rooms/user/mod.rs +++ b/src/service/rooms/user/mod.rs @@ -1,46 +1,23 @@ +mod data; +pub use data::Data; - #[tracing::instrument(skip(self))] +use crate::service::*; + +pub struct Service { + db: D, +} + +impl Service<_> { pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - - self.userroomid_notificationcount - .insert(&userroom_id, &0_u64.to_be_bytes())?; - self.userroomid_highlightcount - .insert(&userroom_id, &0_u64.to_be_bytes())?; - - Ok(()) + self.db.reset_notification_counts(user_id, room_id) } - #[tracing::instrument(skip(self))] pub fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - - self.userroomid_notificationcount - .get(&userroom_id)? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid notification count in db.")) - }) - .unwrap_or(Ok(0)) + self.db.notification_count(user_id, room_id) } - #[tracing::instrument(skip(self))] pub fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - - self.userroomid_highlightcount - .get(&userroom_id)? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid highlight count in db.")) - }) - .unwrap_or(Ok(0)) + self.db.highlight_count(user_id, room_id) } pub fn associate_token_shortstatehash( @@ -49,66 +26,17 @@ token: u64, shortstatehash: u64, ) -> Result<()> { - let shortroomid = self.get_shortroomid(room_id)?.expect("room exists"); - - let mut key = shortroomid.to_be_bytes().to_vec(); - key.extend_from_slice(&token.to_be_bytes()); - - self.roomsynctoken_shortstatehash - .insert(&key, &shortstatehash.to_be_bytes()) + self.db.associate_token_shortstatehash(user_id, room_id) } pub fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result> { - let shortroomid = self.get_shortroomid(room_id)?.expect("room exists"); - - let mut key = shortroomid.to_be_bytes().to_vec(); - key.extend_from_slice(&token.to_be_bytes()); - - self.roomsynctoken_shortstatehash - .get(&key)? - .map(|bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash") - }) - }) - .transpose() + self.db.get_token_shortstatehash(room_id, token) } - #[tracing::instrument(skip(self))] pub fn get_shared_rooms<'a>( &'a self, users: Vec>, ) -> Result>> + 'a> { - let iterators = users.into_iter().map(move |user_id| { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - - self.userroomid_joined - .scan_prefix(prefix) - .map(|(key, _)| { - let roomid_index = key - .iter() - .enumerate() - .find(|(_, &b)| b == 0xff) - .ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))? - .0 - + 1; // +1 because the room id starts AFTER the separator - - let room_id = key[roomid_index..].to_vec(); - - Ok::<_, Error>(room_id) - }) - .filter_map(|r| r.ok()) - }); - - // We use the default compare function because keys are sorted correctly (not reversed) - Ok(utils::common_elements(iterators, Ord::cmp) - .expect("users is not empty") - .map(|bytes| { - RoomId::parse(utils::string_from_bytes(&*bytes).map_err(|_| { - Error::bad_database("Invalid RoomId bytes in userroomid_joined") - })?) - .map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined.")) - })) + self.db.get_shared_rooms(users) } - +} diff --git a/src/service/uiaa/data.rs b/src/service/uiaa/data.rs index 12373139..40e69bda 100644 --- a/src/service/uiaa/data.rs +++ b/src/service/uiaa/data.rs @@ -1,179 +1,18 @@ -use std::{ - collections::BTreeMap, - sync::{Arc, RwLock}, -}; - -use crate::{client_server::SESSION_ID_LENGTH, utils, Error, Result}; -use ruma::{ - api::client::{ - error::ErrorKind, - uiaa::{ - AuthType, IncomingAuthData, IncomingPassword, - IncomingUserIdentifier::UserIdOrLocalpart, UiaaInfo, - }, - }, - signatures::CanonicalJsonValue, - DeviceId, UserId, -}; -use tracing::error; - -use super::abstraction::Tree; - -pub struct Uiaa { - pub(super) userdevicesessionid_uiaainfo: Arc, // User-interactive authentication - pub(super) userdevicesessionid_uiaarequest: - RwLock, Box, String), CanonicalJsonValue>>, -} - -impl Uiaa { - /// Creates a new Uiaa session. Make sure the session token is unique. - pub fn create( - &self, - user_id: &UserId, - device_id: &DeviceId, - uiaainfo: &UiaaInfo, - json_body: &CanonicalJsonValue, - ) -> Result<()> { - self.set_uiaa_request( - user_id, - device_id, - uiaainfo.session.as_ref().expect("session should be set"), // TODO: better session error handling (why is it optional in ruma?) - json_body, - )?; - self.update_uiaa_session( - user_id, - device_id, - uiaainfo.session.as_ref().expect("session should be set"), - Some(uiaainfo), - ) - } - - pub fn try_auth( - &self, - user_id: &UserId, - device_id: &DeviceId, - auth: &IncomingAuthData, - uiaainfo: &UiaaInfo, - users: &super::users::Users, - globals: &super::globals::Globals, - ) -> Result<(bool, UiaaInfo)> { - let mut uiaainfo = auth - .session() - .map(|session| self.get_uiaa_session(user_id, device_id, session)) - .unwrap_or_else(|| Ok(uiaainfo.clone()))?; - - if uiaainfo.session.is_none() { - uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - } - - match auth { - // Find out what the user completed - IncomingAuthData::Password(IncomingPassword { - identifier, - password, - .. - }) => { - let username = match identifier { - UserIdOrLocalpart(username) => username, - _ => { - return Err(Error::BadRequest( - ErrorKind::Unrecognized, - "Identifier type not recognized.", - )) - } - }; - - let user_id = - UserId::parse_with_server_name(username.clone(), globals.server_name()) - .map_err(|_| { - Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid.") - })?; - - // Check if password is correct - if let Some(hash) = users.password_hash(&user_id)? { - let hash_matches = - argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false); - - if !hash_matches { - uiaainfo.auth_error = Some(ruma::api::client::error::ErrorBody { - kind: ErrorKind::Forbidden, - message: "Invalid username or password.".to_owned(), - }); - return Ok((false, uiaainfo)); - } - } - - // Password was correct! Let's add it to `completed` - uiaainfo.completed.push(AuthType::Password); - } - IncomingAuthData::Dummy(_) => { - uiaainfo.completed.push(AuthType::Dummy); - } - k => error!("type not supported: {:?}", k), - } - - // Check if a flow now succeeds - let mut completed = false; - 'flows: for flow in &mut uiaainfo.flows { - for stage in &flow.stages { - if !uiaainfo.completed.contains(stage) { - continue 'flows; - } - } - // We didn't break, so this flow succeeded! - completed = true; - } - - if !completed { - self.update_uiaa_session( - user_id, - device_id, - uiaainfo.session.as_ref().expect("session is always set"), - Some(&uiaainfo), - )?; - return Ok((false, uiaainfo)); - } - - // UIAA was successful! Remove this session and return true - self.update_uiaa_session( - user_id, - device_id, - uiaainfo.session.as_ref().expect("session is always set"), - None, - )?; - Ok((true, uiaainfo)) - } - +pub trait Data { fn set_uiaa_request( &self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue, - ) -> Result<()> { - self.userdevicesessionid_uiaarequest - .write() - .unwrap() - .insert( - (user_id.to_owned(), device_id.to_owned(), session.to_owned()), - request.to_owned(), - ); + ) -> Result<()>; - Ok(()) - } - - pub fn get_uiaa_request( + fn get_uiaa_request( &self, user_id: &UserId, device_id: &DeviceId, session: &str, - ) -> Option { - self.userdevicesessionid_uiaarequest - .read() - .unwrap() - .get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned())) - .map(|j| j.to_owned()) - } + ) -> Option; fn update_uiaa_session( &self, @@ -181,47 +20,12 @@ impl Uiaa { device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>, - ) -> Result<()> { - let mut userdevicesessionid = user_id.as_bytes().to_vec(); - userdevicesessionid.push(0xff); - userdevicesessionid.extend_from_slice(device_id.as_bytes()); - userdevicesessionid.push(0xff); - userdevicesessionid.extend_from_slice(session.as_bytes()); - - if let Some(uiaainfo) = uiaainfo { - self.userdevicesessionid_uiaainfo.insert( - &userdevicesessionid, - &serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"), - )?; - } else { - self.userdevicesessionid_uiaainfo - .remove(&userdevicesessionid)?; - } - - Ok(()) - } + ) -> Result<()>; fn get_uiaa_session( &self, user_id: &UserId, device_id: &DeviceId, session: &str, - ) -> Result { - let mut userdevicesessionid = user_id.as_bytes().to_vec(); - userdevicesessionid.push(0xff); - userdevicesessionid.extend_from_slice(device_id.as_bytes()); - userdevicesessionid.push(0xff); - userdevicesessionid.extend_from_slice(session.as_bytes()); - - serde_json::from_slice( - &self - .userdevicesessionid_uiaainfo - .get(&userdevicesessionid)? - .ok_or(Error::BadRequest( - ErrorKind::Forbidden, - "UIAA session does not exist.", - ))?, - ) - .map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid.")) - } + ) -> Result; } diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index 12373139..593ea5f2 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -1,31 +1,13 @@ -use std::{ - collections::BTreeMap, - sync::{Arc, RwLock}, -}; +mod data; +pub use data::Data; -use crate::{client_server::SESSION_ID_LENGTH, utils, Error, Result}; -use ruma::{ - api::client::{ - error::ErrorKind, - uiaa::{ - AuthType, IncomingAuthData, IncomingPassword, - IncomingUserIdentifier::UserIdOrLocalpart, UiaaInfo, - }, - }, - signatures::CanonicalJsonValue, - DeviceId, UserId, -}; -use tracing::error; +use crate::service::*; -use super::abstraction::Tree; - -pub struct Uiaa { - pub(super) userdevicesessionid_uiaainfo: Arc, // User-interactive authentication - pub(super) userdevicesessionid_uiaarequest: - RwLock, Box, String), CanonicalJsonValue>>, +pub struct Service { + db: D, } -impl Uiaa { +impl Service<_> { /// Creates a new Uiaa session. Make sure the session token is unique. pub fn create( &self, @@ -144,35 +126,13 @@ impl Uiaa { Ok((true, uiaainfo)) } - fn set_uiaa_request( - &self, - user_id: &UserId, - device_id: &DeviceId, - session: &str, - request: &CanonicalJsonValue, - ) -> Result<()> { - self.userdevicesessionid_uiaarequest - .write() - .unwrap() - .insert( - (user_id.to_owned(), device_id.to_owned(), session.to_owned()), - request.to_owned(), - ); - - Ok(()) - } - pub fn get_uiaa_request( &self, user_id: &UserId, device_id: &DeviceId, session: &str, ) -> Option { - self.userdevicesessionid_uiaarequest - .read() - .unwrap() - .get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned())) - .map(|j| j.to_owned()) + self.db.get_uiaa_request(user_id, device_id, session) } fn update_uiaa_session( @@ -182,46 +142,6 @@ impl Uiaa { session: &str, uiaainfo: Option<&UiaaInfo>, ) -> Result<()> { - let mut userdevicesessionid = user_id.as_bytes().to_vec(); - userdevicesessionid.push(0xff); - userdevicesessionid.extend_from_slice(device_id.as_bytes()); - userdevicesessionid.push(0xff); - userdevicesessionid.extend_from_slice(session.as_bytes()); - - if let Some(uiaainfo) = uiaainfo { - self.userdevicesessionid_uiaainfo.insert( - &userdevicesessionid, - &serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"), - )?; - } else { - self.userdevicesessionid_uiaainfo - .remove(&userdevicesessionid)?; - } - - Ok(()) - } - - fn get_uiaa_session( - &self, - user_id: &UserId, - device_id: &DeviceId, - session: &str, - ) -> Result { - let mut userdevicesessionid = user_id.as_bytes().to_vec(); - userdevicesessionid.push(0xff); - userdevicesessionid.extend_from_slice(device_id.as_bytes()); - userdevicesessionid.push(0xff); - userdevicesessionid.extend_from_slice(session.as_bytes()); - - serde_json::from_slice( - &self - .userdevicesessionid_uiaainfo - .get(&userdevicesessionid)? - .ok_or(Error::BadRequest( - ErrorKind::Forbidden, - "UIAA session does not exist.", - ))?, - ) - .map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid.")) + self.db.update_uiaa_session(user_id, device_id, session, uiaainfo) } }