fix: improve key fetching

This commit is contained in:
Timo Kösters 2021-08-26 18:59:10 +02:00 committed by Jonas Zohren
parent f22ad5dfba
commit 3a588c4561

View file

@ -1,5 +1,6 @@
use super::SESSION_ID_LENGTH; use super::SESSION_ID_LENGTH;
use crate::{database::DatabaseGuard, utils, ConduitResult, Database, Error, Result, Ruma}; use crate::{database::DatabaseGuard, utils, ConduitResult, Database, Error, Result, Ruma};
use rocket::futures::{prelude::*, stream::FuturesUnordered};
use ruma::{ use ruma::{
api::{ api::{
client::{ client::{
@ -18,7 +19,10 @@ use ruma::{
DeviceId, DeviceKeyAlgorithm, UserId, DeviceId, DeviceKeyAlgorithm, UserId,
}; };
use serde_json::json; use serde_json::json;
use std::collections::{BTreeMap, HashSet}; use std::{
collections::{BTreeMap, HashMap, HashSet},
time::{Duration, Instant},
};
#[cfg(feature = "conduit_bin")] #[cfg(feature = "conduit_bin")]
use rocket::{get, post}; use rocket::{get, post};
@ -294,7 +298,7 @@ pub async fn get_keys_helper<F: Fn(&UserId) -> bool>(
let mut user_signing_keys = BTreeMap::new(); let mut user_signing_keys = BTreeMap::new();
let mut device_keys = BTreeMap::new(); let mut device_keys = BTreeMap::new();
let mut get_over_federation = BTreeMap::new(); let mut get_over_federation = HashMap::new();
for (user_id, device_ids) in device_keys_input { for (user_id, device_ids) in device_keys_input {
if user_id.server_name() != db.globals.server_name() { if user_id.server_name() != db.globals.server_name() {
@ -364,22 +368,30 @@ pub async fn get_keys_helper<F: Fn(&UserId) -> bool>(
let mut failures = BTreeMap::new(); let mut failures = BTreeMap::new();
for (server, vec) in get_over_federation { let mut futures = get_over_federation
let mut device_keys_input_fed = BTreeMap::new(); .into_iter()
for (user_id, keys) in vec { .map(|(server, vec)| async move {
device_keys_input_fed.insert(user_id.clone(), keys.clone()); let mut device_keys_input_fed = BTreeMap::new();
} for (user_id, keys) in vec {
match db device_keys_input_fed.insert(user_id.clone(), keys.clone());
.sending }
.send_federation_request( (
&db.globals,
server, server,
federation::keys::get_keys::v1::Request { db.sending
device_keys: device_keys_input_fed, .send_federation_request(
}, &db.globals,
server,
federation::keys::get_keys::v1::Request {
device_keys: device_keys_input_fed,
},
)
.await,
) )
.await })
{ .collect::<FuturesUnordered<_>>();
while let Some((server, response)) = futures.next().await {
match response {
Ok(response) => { Ok(response) => {
master_keys.extend(response.master_keys); master_keys.extend(response.master_keys);
self_signing_keys.extend(response.self_signing_keys); self_signing_keys.extend(response.self_signing_keys);
@ -430,13 +442,15 @@ pub async fn claim_keys_helper(
one_time_keys.insert(user_id.clone(), container); one_time_keys.insert(user_id.clone(), container);
} }
let mut failures = BTreeMap::new();
for (server, vec) in get_over_federation { for (server, vec) in get_over_federation {
let mut one_time_keys_input_fed = BTreeMap::new(); let mut one_time_keys_input_fed = BTreeMap::new();
for (user_id, keys) in vec { for (user_id, keys) in vec {
one_time_keys_input_fed.insert(user_id.clone(), keys.clone()); one_time_keys_input_fed.insert(user_id.clone(), keys.clone());
} }
// Ignore failures // Ignore failures
let keys = db if let Ok(keys) = db
.sending .sending
.send_federation_request( .send_federation_request(
&db.globals, &db.globals,
@ -445,13 +459,16 @@ pub async fn claim_keys_helper(
one_time_keys: one_time_keys_input_fed, one_time_keys: one_time_keys_input_fed,
}, },
) )
.await?; .await
{
one_time_keys.extend(keys.one_time_keys); one_time_keys.extend(keys.one_time_keys);
} else {
failures.insert(server.to_string(), json!({}));
}
} }
Ok(claim_keys::Response { Ok(claim_keys::Response {
failures: BTreeMap::new(), failures,
one_time_keys, one_time_keys,
}) })
} }