improvement: use u64s in auth chain cache

This commit is contained in:
Timo Kösters 2021-08-11 19:15:38 +02:00
parent 096e0971f1
commit c2c6a8673e
No known key found for this signature in database
GPG key ID: 356E705610F626D5
2 changed files with 131 additions and 142 deletions

View file

@ -88,7 +88,7 @@ pub struct Rooms {
pub(super) referencedevents: Arc<dyn Tree>, pub(super) referencedevents: Arc<dyn Tree>,
pub(super) pdu_cache: Mutex<LruCache<EventId, Arc<PduEvent>>>, pub(super) pdu_cache: Mutex<LruCache<EventId, Arc<PduEvent>>>,
pub(super) auth_chain_cache: Mutex<LruCache<Vec<EventId>, HashSet<EventId>>>, pub(super) auth_chain_cache: Mutex<LruCache<u64, HashSet<u64>>>,
} }
impl Rooms { impl Rooms {
@ -315,19 +315,7 @@ impl Rooms {
); );
let (shortstatehash, already_existed) = let (shortstatehash, already_existed) =
match self.statehash_shortstatehash.get(&state_hash)? { self.get_or_create_shortstatehash(&state_hash, &db.globals)?;
Some(shortstatehash) => (
utils::u64_from_bytes(&shortstatehash)
.map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?,
true,
),
None => {
let shortstatehash = db.globals.next_count()?;
self.statehash_shortstatehash
.insert(&state_hash, &shortstatehash.to_be_bytes())?;
(shortstatehash, false)
}
};
let new_state = if !already_existed { let new_state = if !already_existed {
let mut new_state = HashSet::new(); let mut new_state = HashSet::new();
@ -352,25 +340,14 @@ impl Rooms {
} }
}; };
let shorteventid = let shorteventid = self
match self.eventid_shorteventid.get(eventid.as_bytes()).ok()? { .get_or_create_shorteventid(&eventid, &db.globals)
Some(shorteventid) => shorteventid.to_vec(),
None => {
let shorteventid = db.globals.next_count().ok()?;
self.eventid_shorteventid
.insert(eventid.as_bytes(), &shorteventid.to_be_bytes())
.ok()?; .ok()?;
self.shorteventid_eventid
.insert(&shorteventid.to_be_bytes(), eventid.as_bytes())
.ok()?;
shorteventid.to_be_bytes().to_vec()
}
};
let mut state_id = shortstatehash.to_be_bytes().to_vec(); let mut state_id = shortstatehash.to_be_bytes().to_vec();
state_id.extend_from_slice(&shortstatekey); state_id.extend_from_slice(&shortstatekey);
Some((state_id, shorteventid)) Some((state_id, shorteventid.to_be_bytes().to_vec()))
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@ -428,6 +405,61 @@ impl Rooms {
Ok(()) Ok(())
} }
/// Returns (shortstatehash, already_existed)
fn get_or_create_shortstatehash(
&self,
state_hash: &StateHashId,
globals: &super::globals::Globals,
) -> Result<(u64, bool)> {
Ok(match self.statehash_shortstatehash.get(&state_hash)? {
Some(shortstatehash) => (
utils::u64_from_bytes(&shortstatehash)
.map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?,
true,
),
None => {
let shortstatehash = globals.next_count()?;
self.statehash_shortstatehash
.insert(&state_hash, &shortstatehash.to_be_bytes())?;
(shortstatehash, false)
}
})
}
/// Returns (shortstatehash, already_existed)
pub fn get_or_create_shorteventid(
&self,
event_id: &EventId,
globals: &super::globals::Globals,
) -> Result<u64> {
Ok(match self.eventid_shorteventid.get(event_id.as_bytes())? {
Some(shorteventid) => utils::u64_from_bytes(&shorteventid)
.map_err(|_| Error::bad_database("Invalid shorteventid in db."))?,
None => {
let shorteventid = globals.next_count()?;
self.eventid_shorteventid
.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
self.shorteventid_eventid
.insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?;
shorteventid
}
})
}
pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result<EventId> {
let bytes = self
.shorteventid_eventid
.get(&shorteventid.to_be_bytes())?
.ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?;
EventId::try_from(
utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("EventID in roomid_pduleaves is invalid unicode.")
})?,
)
.map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))
}
/// Returns the full room state. /// Returns the full room state.
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn room_state_full( pub fn room_state_full(
@ -1116,17 +1148,7 @@ impl Rooms {
state: &StateMap<Arc<PduEvent>>, state: &StateMap<Arc<PduEvent>>,
globals: &super::globals::Globals, globals: &super::globals::Globals,
) -> Result<()> { ) -> Result<()> {
let shorteventid = match self.eventid_shorteventid.get(event_id.as_bytes())? { let shorteventid = self.get_or_create_shorteventid(&event_id, globals)?;
Some(shorteventid) => shorteventid.to_vec(),
None => {
let shorteventid = globals.next_count()?;
self.eventid_shorteventid
.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
self.shorteventid_eventid
.insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?;
shorteventid.to_be_bytes().to_vec()
}
};
let state_hash = self.calculate_hash( let state_hash = self.calculate_hash(
&state &state
@ -1135,21 +1157,10 @@ impl Rooms {
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
); );
let shortstatehash = match self.statehash_shortstatehash.get(&state_hash)? { let (shortstatehash, already_existed) =
Some(shortstatehash) => { self.get_or_create_shortstatehash(&state_hash, globals)?;
// State already existed in db
self.shorteventid_shortstatehash
.insert(&shorteventid, &*shortstatehash)?;
return Ok(());
}
None => {
let shortstatehash = globals.next_count()?;
self.statehash_shortstatehash
.insert(&state_hash, &shortstatehash.to_be_bytes())?;
shortstatehash.to_be_bytes().to_vec()
}
};
if !already_existed {
let batch = state let batch = state
.iter() .iter()
.filter_map(|((event_type, state_key), pdu)| { .filter_map(|((event_type, state_key), pdu)| {
@ -1168,36 +1179,23 @@ impl Rooms {
} }
}; };
let shorteventid = match self let shorteventid = self
.eventid_shorteventid .get_or_create_shorteventid(&pdu.event_id, globals)
.get(pdu.event_id.as_bytes())
.ok()?
{
Some(shorteventid) => shorteventid.to_vec(),
None => {
let shorteventid = globals.next_count().ok()?;
self.eventid_shorteventid
.insert(pdu.event_id.as_bytes(), &shorteventid.to_be_bytes())
.ok()?; .ok()?;
self.shorteventid_eventid
.insert(&shorteventid.to_be_bytes(), pdu.event_id.as_bytes())
.ok()?;
shorteventid.to_be_bytes().to_vec()
}
};
let mut state_id = shortstatehash.clone(); let mut state_id = shortstatehash.to_be_bytes().to_vec();
state_id.extend_from_slice(&shortstatekey); state_id.extend_from_slice(&shortstatekey);
Some((state_id, shorteventid)) Some((state_id, shorteventid.to_be_bytes().to_vec()))
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
self.stateid_shorteventid self.stateid_shorteventid
.insert_batch(&mut batch.into_iter())?; .insert_batch(&mut batch.into_iter())?;
}
self.shorteventid_shortstatehash self.shorteventid_shortstatehash
.insert(&shorteventid, &*shortstatehash)?; .insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?;
Ok(()) Ok(())
} }
@ -1212,26 +1210,16 @@ impl Rooms {
new_pdu: &PduEvent, new_pdu: &PduEvent,
globals: &super::globals::Globals, globals: &super::globals::Globals,
) -> Result<u64> { ) -> Result<u64> {
let shorteventid = self.get_or_create_shorteventid(&new_pdu.event_id, globals)?;
let old_state = if let Some(old_shortstatehash) = let old_state = if let Some(old_shortstatehash) =
self.roomid_shortstatehash.get(new_pdu.room_id.as_bytes())? self.roomid_shortstatehash.get(new_pdu.room_id.as_bytes())?
{ {
// Store state for event. The state does not include the event itself. // Store state for event. The state does not include the event itself.
// Instead it's the state before the pdu, so the room's old state. // Instead it's the state before the pdu, so the room's old state.
let shorteventid = match self.eventid_shorteventid.get(new_pdu.event_id.as_bytes())? {
Some(shorteventid) => shorteventid.to_vec(),
None => {
let shorteventid = globals.next_count()?;
self.eventid_shorteventid
.insert(new_pdu.event_id.as_bytes(), &shorteventid.to_be_bytes())?;
self.shorteventid_eventid
.insert(&shorteventid.to_be_bytes(), new_pdu.event_id.as_bytes())?;
shorteventid.to_be_bytes().to_vec()
}
};
self.shorteventid_shortstatehash self.shorteventid_shortstatehash
.insert(&shorteventid, &old_shortstatehash)?; .insert(&shorteventid.to_be_bytes(), &old_shortstatehash)?;
if new_pdu.state_key.is_none() { if new_pdu.state_key.is_none() {
return utils::u64_from_bytes(&old_shortstatehash).map_err(|_| { return utils::u64_from_bytes(&old_shortstatehash).map_err(|_| {
Error::bad_database("Invalid shortstatehash in roomid_shortstatehash.") Error::bad_database("Invalid shortstatehash in roomid_shortstatehash.")
@ -1264,19 +1252,7 @@ impl Rooms {
} }
}; };
let shorteventid = match self.eventid_shorteventid.get(new_pdu.event_id.as_bytes())? { new_state.insert(shortstatekey, shorteventid.to_be_bytes().to_vec());
Some(shorteventid) => shorteventid.to_vec(),
None => {
let shorteventid = globals.next_count()?;
self.eventid_shorteventid
.insert(new_pdu.event_id.as_bytes(), &shorteventid.to_be_bytes())?;
self.shorteventid_eventid
.insert(&shorteventid.to_be_bytes(), new_pdu.event_id.as_bytes())?;
shorteventid.to_be_bytes().to_vec()
}
};
new_state.insert(shortstatekey, shorteventid);
let new_state_hash = self.calculate_hash( let new_state_hash = self.calculate_hash(
&new_state &new_state
@ -1516,11 +1492,7 @@ impl Rooms {
); );
// Generate short event id // Generate short event id
let shorteventid = db.globals.next_count()?; let _shorteventid = self.get_or_create_shorteventid(&pdu.event_id, &db.globals)?;
self.eventid_shorteventid
.insert(pdu.event_id.as_bytes(), &shorteventid.to_be_bytes())?;
self.shorteventid_eventid
.insert(&shorteventid.to_be_bytes(), pdu.event_id.as_bytes())?;
// We append to state before appending the pdu, so we don't have a moment in time with the // We append to state before appending the pdu, so we don't have a moment in time with the
// pdu without it's state. This is okay because append_pdu can't fail. // pdu without it's state. This is okay because append_pdu can't fail.
@ -2655,9 +2627,7 @@ impl Rooms {
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn auth_chain_cache( pub fn auth_chain_cache(&self) -> std::sync::MutexGuard<'_, LruCache<u64, HashSet<u64>>> {
&self,
) -> std::sync::MutexGuard<'_, LruCache<Vec<EventId>, HashSet<EventId>>> {
self.auth_chain_cache.lock().unwrap() self.auth_chain_cache.lock().unwrap()
} }
} }

View file

@ -1044,13 +1044,16 @@ pub fn handle_incoming_pdu<'a>(
if incoming_pdu.prev_events.len() == 1 { if incoming_pdu.prev_events.len() == 1 {
let prev_event = &incoming_pdu.prev_events[0]; let prev_event = &incoming_pdu.prev_events[0];
let state = db let prev_event_sstatehash = db
.rooms .rooms
.pdu_shortstatehash(prev_event) .pdu_shortstatehash(prev_event)
.map_err(|_| "Failed talking to db".to_owned())? .map_err(|_| "Failed talking to db".to_owned())?;
.map(|shortstatehash| db.rooms.state_full_ids(shortstatehash).ok())
.flatten(); let state =
if let Some(state) = state { prev_event_sstatehash.map(|shortstatehash| db.rooms.state_full_ids(shortstatehash));
if let Some(Ok(state)) = state {
warn!("Using cached state");
let mut state = fetch_and_handle_events( let mut state = fetch_and_handle_events(
db, db,
origin, origin,
@ -1088,6 +1091,7 @@ pub fn handle_incoming_pdu<'a>(
} }
if state_at_incoming_event.is_none() { if state_at_incoming_event.is_none() {
warn!("Calling /state_ids");
// Call /state_ids to find out what the state at this pdu is. We trust the server's // Call /state_ids to find out what the state at this pdu is. We trust the server's
// response to some extend, but we still do a lot of checks on the events // response to some extend, but we still do a lot of checks on the events
match db match db
@ -1755,35 +1759,50 @@ fn append_incoming_pdu(
fn get_auth_chain(starting_events: Vec<EventId>, db: &Database) -> Result<HashSet<EventId>> { fn get_auth_chain(starting_events: Vec<EventId>, db: &Database) -> Result<HashSet<EventId>> {
let mut full_auth_chain = HashSet::new(); let mut full_auth_chain = HashSet::new();
let starting_events = starting_events
.iter()
.map(|id| {
(db.rooms
.get_or_create_shorteventid(id, &db.globals)
.map(|s| (s, id)))
})
.collect::<Result<Vec<_>>>()?;
let mut cache = db.rooms.auth_chain_cache(); let mut cache = db.rooms.auth_chain_cache();
for event_id in &starting_events { for (sevent_id, event_id) in starting_events {
if let Some(cached) = cache.get_mut(&[event_id.clone()][..]) { if let Some(cached) = cache.get_mut(&sevent_id) {
full_auth_chain.extend(cached.iter().cloned()); full_auth_chain.extend(cached.iter().cloned());
} else { } else {
drop(cache); drop(cache);
let mut auth_chain = HashSet::new(); let mut auth_chain = HashSet::new();
get_auth_chain_recursive(&event_id, &mut auth_chain, db)?; get_auth_chain_recursive(&event_id, &mut auth_chain, db)?;
cache = db.rooms.auth_chain_cache(); cache = db.rooms.auth_chain_cache();
cache.insert(vec![event_id.clone()], auth_chain.clone()); cache.insert(sevent_id, auth_chain.clone());
full_auth_chain.extend(auth_chain); full_auth_chain.extend(auth_chain);
}; };
} }
Ok(full_auth_chain) full_auth_chain
.into_iter()
.map(|sid| db.rooms.get_eventid_from_short(sid))
.collect()
} }
fn get_auth_chain_recursive( fn get_auth_chain_recursive(
event_id: &EventId, event_id: &EventId,
found: &mut HashSet<EventId>, found: &mut HashSet<u64>,
db: &Database, db: &Database,
) -> Result<()> { ) -> Result<()> {
let r = db.rooms.get_pdu(&event_id); let r = db.rooms.get_pdu(&event_id);
match r { match r {
Ok(Some(pdu)) => { Ok(Some(pdu)) => {
for auth_event in &pdu.auth_events { for auth_event in &pdu.auth_events {
if !found.contains(auth_event) { let sauthevent = db
found.insert(auth_event.clone()); .rooms
.get_or_create_shorteventid(auth_event, &db.globals)?;
if !found.contains(&sauthevent) {
found.insert(sauthevent);
get_auth_chain_recursive(&auth_event, found, db)?; get_auth_chain_recursive(&auth_event, found, db)?;
} }
} }