fixup! Get required keys in batch when joining a room

This commit is contained in:
Kurt Roeckx 2021-08-29 13:25:20 +02:00 committed by Jonas Zohren
parent 61adef8f2c
commit a8d181e00e
2 changed files with 33 additions and 24 deletions

View file

@ -227,7 +227,11 @@ impl Globals {
/// Remove the outdated keys and insert the new ones. /// Remove the outdated keys and insert the new ones.
/// ///
/// This doesn't actually check that the keys provided are newer than the old set. /// This doesn't actually check that the keys provided are newer than the old set.
pub fn add_signing_key(&self, origin: &ServerName, new_keys: ServerSigningKeys) -> Result<()> { pub fn add_signing_key(
&self,
origin: &ServerName,
new_keys: ServerSigningKeys,
) -> Result<BTreeMap<ServerSigningKeyId, VerifyKey>> {
// Not atomic, but this is not critical // Not atomic, but this is not critical
let signingkeys = self.server_signingkeys.get(origin.as_bytes())?; let signingkeys = self.server_signingkeys.get(origin.as_bytes())?;
@ -252,7 +256,14 @@ impl Globals {
&serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"), &serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"),
)?; )?;
Ok(()) let mut tree = keys.verify_keys;
tree.extend(
keys.old_verify_keys
.into_iter()
.map(|old| (old.0, VerifyKey::new(old.1.key))),
);
Ok(tree)
} }
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server. /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server.

View file

@ -3279,13 +3279,18 @@ pub(crate) async fn fetch_required_signing_keys(
Ok(()) Ok(())
} }
pub fn get_missing_signing_keys_for_pdus( // Gets a list of servers for which we don't have the signing key yet. We go over
// the PDUs and either cache the key or add it to the list that needs to be retrieved.
fn get_missing_servers_for_pdus(
pdus: &Vec<Raw<Pdu>>, pdus: &Vec<Raw<Pdu>>,
servers: &mut BTreeMap<Box<ServerName>, BTreeMap<ServerSigningKeyId, QueryCriteria>>, servers: &mut BTreeMap<Box<ServerName>, BTreeMap<ServerSigningKeyId, QueryCriteria>>,
room_version: &RoomVersionId, room_version: &RoomVersionId,
pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, String>>>, pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, String>>>,
db: &Database, db: &Database,
) -> Result<()> { ) -> Result<()> {
let mut pkm = pub_key_map
.write()
.map_err(|_| Error::bad_database("RwLock is poisoned."))?;
for pdu in pdus { for pdu in pdus {
let value = serde_json::from_str::<CanonicalJsonObject>(pdu.json().get()).map_err(|e| { let value = serde_json::from_str::<CanonicalJsonObject>(pdu.json().get()).map_err(|e| {
error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); error!("Invalid PDU in server response: {:?}: {:?}", pdu, e);
@ -3342,6 +3347,10 @@ pub fn get_missing_signing_keys_for_pdus(
Error::BadServerResponse("Invalid servername in signatures of server response pdu.") Error::BadServerResponse("Invalid servername in signatures of server response pdu.")
})?; })?;
if servers.contains_key(origin) {
continue;
}
trace!("Loading signing keys for {}", origin); trace!("Loading signing keys for {}", origin);
let result = db let result = db
@ -3359,10 +3368,7 @@ pub fn get_missing_signing_keys_for_pdus(
); );
} }
pub_key_map pkm.insert(origin.to_string(), result);
.write()
.map_err(|_| Error::bad_database("RwLock is poisoned."))?
.insert(origin.to_string(), result);
} }
} }
@ -3378,14 +3384,14 @@ pub async fn fetch_join_signing_keys(
let mut servers = let mut servers =
BTreeMap::<Box<ServerName>, BTreeMap<ServerSigningKeyId, QueryCriteria>>::new(); BTreeMap::<Box<ServerName>, BTreeMap<ServerSigningKeyId, QueryCriteria>>::new();
get_missing_signing_keys_for_pdus( get_missing_servers_for_pdus(
&event.room_state.state, &event.room_state.state,
&mut servers, &mut servers,
&room_version, &room_version,
&pub_key_map, &pub_key_map,
&db, &db,
)?; )?;
get_missing_signing_keys_for_pdus( get_missing_servers_for_pdus(
&event.room_state.auth_chain, &event.room_state.auth_chain,
&mut servers, &mut servers,
&room_version, &room_version,
@ -3424,23 +3430,19 @@ pub async fn fetch_join_signing_keys(
.await .await
{ {
trace!("Got signing keys: {:?}", keys); trace!("Got signing keys: {:?}", keys);
let mut pkm = pub_key_map
.write()
.map_err(|_| Error::bad_database("RwLock is poisoned."))?;
for k in keys.server_keys { for k in keys.server_keys {
// TODO: Check signature // TODO: Check signature
servers.remove(&k.server_name); servers.remove(&k.server_name);
db.globals.add_signing_key(&k.server_name, k.clone())?; let result = db.globals.add_signing_key(&k.server_name, k.clone())?
let result = db
.globals
.signing_keys_for(&k.server_name)?
.into_iter() .into_iter()
.map(|(k, v)| (k.to_string(), v.key)) .map(|(k, v)| (k.to_string(), v.key))
.collect::<BTreeMap<_, _>>(); .collect::<BTreeMap<_, _>>();
pub_key_map pkm.insert(k.server_name.to_string(), result);
.write()
.map_err(|_| Error::bad_database("RwLock is poisoned."))?
.insert(k.server_name.to_string(), result);
} }
} }
if servers.is_empty() { if servers.is_empty() {
@ -3457,12 +3459,8 @@ pub async fn fetch_join_signing_keys(
if let Ok(get_keys_response) = result { if let Ok(get_keys_response) = result {
// TODO: We should probably not trust the server_name in the response. // TODO: We should probably not trust the server_name in the response.
let server = &get_keys_response.server_key.server_name; let server = &get_keys_response.server_key.server_name;
db.globals let result = db.globals
.add_signing_key(server, get_keys_response.server_key.clone())?; .add_signing_key(server, get_keys_response.server_key.clone())?
let result = db
.globals
.signing_keys_for(server)?
.into_iter() .into_iter()
.map(|(k, v)| (k.to_string(), v.key)) .map(|(k, v)| (k.to_string(), v.key))
.collect::<BTreeMap<_, _>>(); .collect::<BTreeMap<_, _>>();