Get required keys in batch when joining a room

We now ask the trusted server for all keys in 1 request, instead of
asking each server individual for it's own keys.
This commit is contained in:
Kurt Roeckx 2021-08-25 16:02:01 +02:00 committed by Timo Kösters
parent 9c3f1a9272
commit a87519fb71
No known key found for this signature in database
GPG key ID: 24DA7517711A2BA4
2 changed files with 221 additions and 21 deletions

View file

@ -5,7 +5,6 @@ use crate::{
server_server, utils, ConduitResult, Database, Error, Result, Ruma, server_server, utils, ConduitResult, Database, Error, Result, Ruma,
}; };
use member::{MemberEventContent, MembershipState}; use member::{MemberEventContent, MembershipState};
use rocket::futures;
use ruma::{ use ruma::{
api::{ api::{
client::{ client::{
@ -667,14 +666,19 @@ async fn join_room_by_id_helper(
let mut state = HashMap::new(); let mut state = HashMap::new();
let pub_key_map = RwLock::new(BTreeMap::new()); let pub_key_map = RwLock::new(BTreeMap::new());
for result in futures::future::join_all( server_server::fetch_join_signing_keys(
send_join_response &send_join_response,
&room_version,
&pub_key_map,
&db,
)
.await?;
for result in send_join_response
.room_state .room_state
.state .state
.iter() .iter()
.map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map, &db)), .map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map, &db))
)
.await
{ {
let (event_id, value) = match result { let (event_id, value) = match result {
Ok(t) => t, Ok(t) => t,
@ -723,14 +727,11 @@ async fn join_room_by_id_helper(
&db, &db,
)?; )?;
for result in futures::future::join_all( for result in send_join_response
send_join_response
.room_state .room_state
.auth_chain .auth_chain
.iter() .iter()
.map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map, &db)), .map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map, &db))
)
.await
{ {
let (event_id, value) = match result { let (event_id, value) = match result {
Ok(t) => t, Ok(t) => t,
@ -787,7 +788,7 @@ async fn join_room_by_id_helper(
Ok(join_room_by_id::Response::new(room_id.clone()).into()) Ok(join_room_by_id::Response::new(room_id.clone()).into())
} }
async fn validate_and_add_event_id( fn validate_and_add_event_id(
pdu: &Raw<Pdu>, pdu: &Raw<Pdu>,
room_version: &RoomVersionId, room_version: &RoomVersionId,
pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, String>>>, pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, String>>>,
@ -830,7 +831,6 @@ async fn validate_and_add_event_id(
} }
} }
server_server::fetch_required_signing_keys(&value, pub_key_map, db).await?;
if let Err(e) = ruma::signatures::verify_event( if let Err(e) = ruma::signatures::verify_event(
&*pub_key_map &*pub_key_map
.read() .read()

View file

@ -6,7 +6,7 @@ use crate::{
use get_profile_information::v1::ProfileField; use get_profile_information::v1::ProfileField;
use http::header::{HeaderValue, AUTHORIZATION}; use http::header::{HeaderValue, AUTHORIZATION};
use regex::Regex; use regex::Regex;
use rocket::response::content::Json; use rocket::{futures, response::content::Json};
use ruma::{ use ruma::{
api::{ api::{
client::error::{Error as RumaError, ErrorKind}, client::error::{Error as RumaError, ErrorKind},
@ -15,8 +15,9 @@ use ruma::{
device::get_devices::{self, v1::UserDevice}, device::get_devices::{self, v1::UserDevice},
directory::{get_public_rooms, get_public_rooms_filtered}, directory::{get_public_rooms, get_public_rooms_filtered},
discovery::{ discovery::{
get_remote_server_keys, get_server_keys, get_server_version, ServerSigningKeys, get_remote_server_keys, get_remote_server_keys_batch,
VerifyKey, get_remote_server_keys_batch::v2::QueryCriteria, get_server_keys,
get_server_version, ServerSigningKeys, VerifyKey,
}, },
event::{get_event, get_missing_events, get_room_state, get_room_state_ids}, event::{get_event, get_missing_events, get_room_state, get_room_state_ids},
keys::{claim_keys, get_keys}, keys::{claim_keys, get_keys},
@ -35,6 +36,7 @@ use ruma::{
}, },
directory::{IncomingFilter, IncomingRoomNetwork}, directory::{IncomingFilter, IncomingRoomNetwork},
events::{ events::{
pdu::Pdu,
receipt::{ReceiptEvent, ReceiptEventContent}, receipt::{ReceiptEvent, ReceiptEventContent},
room::{ room::{
create::CreateEventContent, create::CreateEventContent,
@ -3277,6 +3279,204 @@ pub(crate) async fn fetch_required_signing_keys(
Ok(()) Ok(())
} }
pub fn get_missing_signing_keys_for_pdus(
pdus: &Vec<Raw<Pdu>>,
servers: &mut BTreeMap<Box<ServerName>, BTreeMap<ServerSigningKeyId, QueryCriteria>>,
room_version: &RoomVersionId,
pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, String>>>,
db: &Database,
) -> Result<()> {
for pdu in pdus {
let value = serde_json::from_str::<CanonicalJsonObject>(pdu.json().get()).map_err(|e| {
error!("Invalid PDU in server response: {:?}: {:?}", pdu, e);
Error::BadServerResponse("Invalid PDU in server response")
})?;
let event_id = EventId::try_from(&*format!(
"${}",
ruma::signatures::reference_hash(&value, &room_version)
.expect("ruma can calculate reference hashes")
))
.expect("ruma's reference hashes are valid event ids");
if let Some((time, tries)) = db
.globals
.bad_event_ratelimiter
.read()
.unwrap()
.get(&event_id)
{
// Exponential backoff
let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
min_elapsed_duration = Duration::from_secs(60 * 60 * 24);
}
if time.elapsed() < min_elapsed_duration {
debug!("Backing off from {}", event_id);
return Err(Error::BadServerResponse("bad event, still backing off"));
}
}
let signatures = value
.get("signatures")
.ok_or(Error::BadServerResponse(
"No signatures in server response pdu.",
))?
.as_object()
.ok_or(Error::BadServerResponse(
"Invalid signatures object in server response pdu.",
))?;
for (signature_server, signature) in signatures {
let signature_object = signature.as_object().ok_or(Error::BadServerResponse(
"Invalid signatures content object in server response pdu.",
))?;
let signature_ids = signature_object.keys().cloned().collect::<Vec<_>>();
let contains_all_ids = |keys: &BTreeMap<String, String>| {
signature_ids.iter().all(|id| keys.contains_key(id))
};
let origin = &Box::<ServerName>::try_from(&**signature_server).map_err(|_| {
Error::BadServerResponse("Invalid servername in signatures of server response pdu.")
})?;
trace!("Loading signing keys for {}", origin);
let result = db
.globals
.signing_keys_for(origin)?
.into_iter()
.map(|(k, v)| (k.to_string(), v.key))
.collect::<BTreeMap<_, _>>();
if !contains_all_ids(&result) {
trace!("Signing key not loaded for {}", origin);
servers.insert(
origin.clone(),
BTreeMap::<ServerSigningKeyId, QueryCriteria>::new(),
);
}
pub_key_map
.write()
.map_err(|_| Error::bad_database("RwLock is poisoned."))?
.insert(origin.to_string(), result);
}
}
Ok(())
}
pub async fn fetch_join_signing_keys(
event: &create_join_event::v2::Response,
room_version: &RoomVersionId,
pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, String>>>,
db: &Database,
) -> Result<()> {
let mut servers =
BTreeMap::<Box<ServerName>, BTreeMap<ServerSigningKeyId, QueryCriteria>>::new();
get_missing_signing_keys_for_pdus(
&event.room_state.state,
&mut servers,
&room_version,
&pub_key_map,
&db,
)?;
get_missing_signing_keys_for_pdus(
&event.room_state.auth_chain,
&mut servers,
&room_version,
&pub_key_map,
&db,
)?;
if servers.is_empty() {
return Ok(());
}
for server in db.globals.trusted_servers() {
if db.globals.signing_keys_for(server)?.is_empty() {
servers.insert(
server.clone(),
BTreeMap::<ServerSigningKeyId, QueryCriteria>::new(),
);
}
}
for server in db.globals.trusted_servers() {
trace!("Asking batch signing keys from trusted server {}", server);
if let Ok(keys) = db
.sending
.send_federation_request(
&db.globals,
server,
get_remote_server_keys_batch::v2::Request {
server_keys: servers.clone(),
minimum_valid_until_ts: MilliSecondsSinceUnixEpoch::from_system_time(
SystemTime::now() + Duration::from_secs(60),
)
.expect("time is valid"),
},
)
.await
{
trace!("Got signing keys: {:?}", keys);
for k in keys.server_keys {
// TODO: Check signature
servers.remove(&k.server_name);
db.globals.add_signing_key(&k.server_name, k.clone())?;
let result = db
.globals
.signing_keys_for(&k.server_name)?
.into_iter()
.map(|(k, v)| (k.to_string(), v.key))
.collect::<BTreeMap<_, _>>();
pub_key_map
.write()
.map_err(|_| Error::bad_database("RwLock is poisoned."))?
.insert(k.server_name.to_string(), result);
}
}
if servers.is_empty() {
return Ok(());
}
}
for result in futures::future::join_all(servers.iter().map(|(server, _)| {
db.sending
.send_federation_request(&db.globals, server, get_server_keys::v2::Request::new())
}))
.await
{
if let Ok(get_keys_response) = result {
// TODO: We should probably not trust the server_name in the response.
let server = &get_keys_response.server_key.server_name;
db.globals
.add_signing_key(server, get_keys_response.server_key.clone())?;
let result = db
.globals
.signing_keys_for(server)?
.into_iter()
.map(|(k, v)| (k.to_string(), v.key))
.collect::<BTreeMap<_, _>>();
pub_key_map
.write()
.map_err(|_| Error::bad_database("RwLock is poisoned."))?
.insert(server.to_string(), result);
}
}
Ok(())
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{add_port_to_hostname, get_ip_with_port, FedDest}; use super::{add_port_to_hostname, get_ip_with_port, FedDest};