From a8d181e00ee4631bdcd5c6b1a73fe80a308c56c3 Mon Sep 17 00:00:00 2001 From: Kurt Roeckx Date: Sun, 29 Aug 2021 13:25:20 +0200 Subject: [PATCH] fixup! Get required keys in batch when joining a room --- src/database/globals.rs | 15 +++++++++++++-- src/server_server.rs | 42 ++++++++++++++++++++--------------------- 2 files changed, 33 insertions(+), 24 deletions(-) diff --git a/src/database/globals.rs b/src/database/globals.rs index 6d11f496..048b9b89 100644 --- a/src/database/globals.rs +++ b/src/database/globals.rs @@ -227,7 +227,11 @@ impl Globals { /// Remove the outdated keys and insert the new ones. /// /// This doesn't actually check that the keys provided are newer than the old set. - pub fn add_signing_key(&self, origin: &ServerName, new_keys: ServerSigningKeys) -> Result<()> { + pub fn add_signing_key( + &self, + origin: &ServerName, + new_keys: ServerSigningKeys, + ) -> Result> { // Not atomic, but this is not critical let signingkeys = self.server_signingkeys.get(origin.as_bytes())?; @@ -252,7 +256,14 @@ impl Globals { &serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"), )?; - Ok(()) + let mut tree = keys.verify_keys; + tree.extend( + keys.old_verify_keys + .into_iter() + .map(|old| (old.0, VerifyKey::new(old.1.key))), + ); + + Ok(tree) } /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server. diff --git a/src/server_server.rs b/src/server_server.rs index 129f5958..2a3665f1 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -3279,13 +3279,18 @@ pub(crate) async fn fetch_required_signing_keys( Ok(()) } -pub fn get_missing_signing_keys_for_pdus( +// Gets a list of servers for which we don't have the signing key yet. We go over +// the PDUs and either cache the key or add it to the list that needs to be retrieved. +fn get_missing_servers_for_pdus( pdus: &Vec>, servers: &mut BTreeMap, BTreeMap>, room_version: &RoomVersionId, pub_key_map: &RwLock>>, db: &Database, ) -> Result<()> { + let mut pkm = pub_key_map + .write() + .map_err(|_| Error::bad_database("RwLock is poisoned."))?; for pdu in pdus { let value = serde_json::from_str::(pdu.json().get()).map_err(|e| { error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); @@ -3342,6 +3347,10 @@ pub fn get_missing_signing_keys_for_pdus( Error::BadServerResponse("Invalid servername in signatures of server response pdu.") })?; + if servers.contains_key(origin) { + continue; + } + trace!("Loading signing keys for {}", origin); let result = db @@ -3359,10 +3368,7 @@ pub fn get_missing_signing_keys_for_pdus( ); } - pub_key_map - .write() - .map_err(|_| Error::bad_database("RwLock is poisoned."))? - .insert(origin.to_string(), result); + pkm.insert(origin.to_string(), result); } } @@ -3378,14 +3384,14 @@ pub async fn fetch_join_signing_keys( let mut servers = BTreeMap::, BTreeMap>::new(); - get_missing_signing_keys_for_pdus( + get_missing_servers_for_pdus( &event.room_state.state, &mut servers, &room_version, &pub_key_map, &db, )?; - get_missing_signing_keys_for_pdus( + get_missing_servers_for_pdus( &event.room_state.auth_chain, &mut servers, &room_version, @@ -3424,23 +3430,19 @@ pub async fn fetch_join_signing_keys( .await { trace!("Got signing keys: {:?}", keys); + let mut pkm = pub_key_map + .write() + .map_err(|_| Error::bad_database("RwLock is poisoned."))?; for k in keys.server_keys { // TODO: Check signature servers.remove(&k.server_name); - db.globals.add_signing_key(&k.server_name, k.clone())?; - - let result = db - .globals - .signing_keys_for(&k.server_name)? + let result = db.globals.add_signing_key(&k.server_name, k.clone())? .into_iter() .map(|(k, v)| (k.to_string(), v.key)) .collect::>(); - pub_key_map - .write() - .map_err(|_| Error::bad_database("RwLock is poisoned."))? - .insert(k.server_name.to_string(), result); + pkm.insert(k.server_name.to_string(), result); } } if servers.is_empty() { @@ -3457,12 +3459,8 @@ pub async fn fetch_join_signing_keys( if let Ok(get_keys_response) = result { // TODO: We should probably not trust the server_name in the response. let server = &get_keys_response.server_key.server_name; - db.globals - .add_signing_key(server, get_keys_response.server_key.clone())?; - - let result = db - .globals - .signing_keys_for(server)? + let result = db.globals + .add_signing_key(server, get_keys_response.server_key.clone())? .into_iter() .map(|(k, v)| (k.to_string(), v.key)) .collect::>();