diff --git a/src/database.rs b/src/database.rs index 1bf9434f..7996057d 100644 --- a/src/database.rs +++ b/src/database.rs @@ -272,6 +272,7 @@ impl Database { referencedevents: builder.open_tree("referencedevents")?, pdu_cache: Mutex::new(LruCache::new(100_000)), auth_chain_cache: Mutex::new(LruCache::new(100_000)), + shorteventid_cache: Mutex::new(LruCache::new(100_000)), }, account_data: account_data::AccountData { roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, diff --git a/src/database/rooms.rs b/src/database/rooms.rs index c53fa9ea..246aa0ba 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -89,6 +89,7 @@ pub struct Rooms { pub(super) pdu_cache: Mutex<LruCache<EventId, Arc<PduEvent>>>, pub(super) auth_chain_cache: Mutex<LruCache<u64, HashSet<u64>>>, + pub(super) shorteventid_cache: Mutex<LruCache<u64, EventId>>, } impl Rooms { @@ -447,17 +448,28 @@ impl Rooms { } pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result<EventId> { + if let Some(id) = self.shorteventid_cache.lock().unwrap().get_mut(&shorteventid) { + return Ok(id.clone()); + } + let bytes = self .shorteventid_eventid .get(&shorteventid.to_be_bytes())? .ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?; - EventId::try_from( + let event_id = EventId::try_from( utils::string_from_bytes(&bytes).map_err(|_| { Error::bad_database("EventID in roomid_pduleaves is invalid unicode.") })?, ) - .map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid.")) + .map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))?; + + self.shorteventid_cache + .lock() + .unwrap() + .insert(shorteventid, event_id.clone()); + + Ok(event_id) } /// Returns the full room state. diff --git a/src/server_server.rs b/src/server_server.rs index 23c80ee6..0b9a7e68 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -1311,7 +1311,7 @@ pub fn handle_incoming_pdu<'a>( for state in fork_states { auth_chain_sets.push( get_auth_chain(state.iter().map(|(_, id)| id.clone()).collect(), db) - .map_err(|_| "Failed to load auth chain.".to_owned())?, + .map_err(|_| "Failed to load auth chain.".to_owned())?.collect(), ); } @@ -1756,15 +1756,15 @@ fn append_incoming_pdu( Ok(pdu_id) } -fn get_auth_chain(starting_events: Vec<EventId>, db: &Database) -> Result<HashSet<EventId>> { +fn get_auth_chain(starting_events: Vec<EventId>, db: &Database) -> Result<impl Iterator<Item = EventId> + '_> { let mut full_auth_chain = HashSet::new(); let starting_events = starting_events .iter() .map(|id| { - (db.rooms + db.rooms .get_or_create_shorteventid(id, &db.globals) - .map(|s| (s, id))) + .map(|s| (s, id)) }) .collect::<Result<Vec<_>>>()?; @@ -1783,10 +1783,11 @@ fn get_auth_chain(starting_events: Vec<EventId>, db: &Database) -> Result<HashSe }; } - full_auth_chain + drop(cache); + + Ok(full_auth_chain .into_iter() - .map(|sid| db.rooms.get_eventid_from_short(sid)) - .collect() + .filter_map(move |sid| db.rooms.get_eventid_from_short(sid).ok())) } fn get_auth_chain_recursive( @@ -1909,7 +1910,6 @@ pub fn get_event_authorization_route( Ok(get_event_authorization::v1::Response { auth_chain: auth_chain_ids - .into_iter() .filter_map(|id| Some(db.rooms.get_pdu_json(&id).ok()??)) .map(|event| PduEvent::convert_to_outgoing_federation_event(event)) .collect(), @@ -1953,7 +1953,6 @@ pub fn get_room_state_route( Ok(get_room_state::v1::Response { auth_chain: auth_chain_ids - .into_iter() .map(|id| { Ok::<_, Error>(PduEvent::convert_to_outgoing_federation_event( db.rooms.get_pdu_json(&id)?.unwrap(), @@ -1996,7 +1995,7 @@ pub fn get_room_state_ids_route( let auth_chain_ids = get_auth_chain(vec![body.event_id.clone()], &db)?; Ok(get_room_state_ids::v1::Response { - auth_chain_ids: auth_chain_ids.into_iter().collect(), + auth_chain_ids: auth_chain_ids.collect(), pdu_ids, } .into()) @@ -2265,7 +2264,6 @@ pub async fn create_join_event_route( Ok(create_join_event::v2::Response { room_state: RoomState { auth_chain: auth_chain_ids - .iter() .filter_map(|id| db.rooms.get_pdu_json(&id).ok().flatten()) .map(PduEvent::convert_to_outgoing_federation_event) .collect(),