improvement: auth chain cache

This commit is contained in:
Timo Kösters 2021-07-18 20:43:39 +02:00
parent f5273f7eb1
commit cfaa900e83
No known key found for this signature in database
GPG key ID: 24DA7517711A2BA4
10 changed files with 201 additions and 176 deletions

36
Cargo.lock generated
View file

@ -2015,7 +2015,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma" name = "ruma"
version = "0.2.0" version = "0.2.0"
source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386"
dependencies = [ dependencies = [
"assign", "assign",
"js_int", "js_int",
@ -2036,7 +2036,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-api" name = "ruma-api"
version = "0.17.1" version = "0.17.1"
source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386"
dependencies = [ dependencies = [
"bytes", "bytes",
"http", "http",
@ -2052,7 +2052,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-api-macros" name = "ruma-api-macros"
version = "0.17.1" version = "0.17.1"
source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386"
dependencies = [ dependencies = [
"proc-macro-crate", "proc-macro-crate",
"proc-macro2", "proc-macro2",
@ -2063,7 +2063,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-appservice-api" name = "ruma-appservice-api"
version = "0.3.0" version = "0.3.0"
source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386"
dependencies = [ dependencies = [
"ruma-api", "ruma-api",
"ruma-common", "ruma-common",
@ -2077,7 +2077,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-client-api" name = "ruma-client-api"
version = "0.11.0" version = "0.11.0"
source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386"
dependencies = [ dependencies = [
"assign", "assign",
"bytes", "bytes",
@ -2097,7 +2097,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-common" name = "ruma-common"
version = "0.5.4" version = "0.5.4"
source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386"
dependencies = [ dependencies = [
"indexmap", "indexmap",
"js_int", "js_int",
@ -2112,7 +2112,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-events" name = "ruma-events"
version = "0.23.2" version = "0.23.2"
source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386"
dependencies = [ dependencies = [
"indoc", "indoc",
"js_int", "js_int",
@ -2128,7 +2128,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-events-macros" name = "ruma-events-macros"
version = "0.23.2" version = "0.23.2"
source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386"
dependencies = [ dependencies = [
"proc-macro-crate", "proc-macro-crate",
"proc-macro2", "proc-macro2",
@ -2139,7 +2139,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-federation-api" name = "ruma-federation-api"
version = "0.2.0" version = "0.2.0"
source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386"
dependencies = [ dependencies = [
"js_int", "js_int",
"ruma-api", "ruma-api",
@ -2154,7 +2154,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-identifiers" name = "ruma-identifiers"
version = "0.19.4" version = "0.19.4"
source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386"
dependencies = [ dependencies = [
"paste", "paste",
"rand 0.8.4", "rand 0.8.4",
@ -2168,7 +2168,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-identifiers-macros" name = "ruma-identifiers-macros"
version = "0.19.4" version = "0.19.4"
source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386"
dependencies = [ dependencies = [
"quote", "quote",
"ruma-identifiers-validation", "ruma-identifiers-validation",
@ -2178,12 +2178,12 @@ dependencies = [
[[package]] [[package]]
name = "ruma-identifiers-validation" name = "ruma-identifiers-validation"
version = "0.4.0" version = "0.4.0"
source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386"
[[package]] [[package]]
name = "ruma-identity-service-api" name = "ruma-identity-service-api"
version = "0.2.0" version = "0.2.0"
source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386"
dependencies = [ dependencies = [
"js_int", "js_int",
"ruma-api", "ruma-api",
@ -2196,7 +2196,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-push-gateway-api" name = "ruma-push-gateway-api"
version = "0.2.0" version = "0.2.0"
source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386"
dependencies = [ dependencies = [
"js_int", "js_int",
"ruma-api", "ruma-api",
@ -2211,7 +2211,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-serde" name = "ruma-serde"
version = "0.4.1" version = "0.4.1"
source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386"
dependencies = [ dependencies = [
"bytes", "bytes",
"form_urlencoded", "form_urlencoded",
@ -2225,7 +2225,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-serde-macros" name = "ruma-serde-macros"
version = "0.4.1" version = "0.4.1"
source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386"
dependencies = [ dependencies = [
"proc-macro-crate", "proc-macro-crate",
"proc-macro2", "proc-macro2",
@ -2236,7 +2236,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-signatures" name = "ruma-signatures"
version = "0.8.0" version = "0.8.0"
source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386"
dependencies = [ dependencies = [
"base64 0.13.0", "base64 0.13.0",
"ed25519-dalek", "ed25519-dalek",
@ -2253,7 +2253,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-state-res" name = "ruma-state-res"
version = "0.2.0" version = "0.2.0"
source = "git+https://github.com/ruma/ruma?rev=c29c2b16ec114fa655e2b70bdd53c82e35859005#c29c2b16ec114fa655e2b70bdd53c82e35859005" source = "git+https://github.com/timokoesters/ruma?rev=a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386#a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386"
dependencies = [ dependencies = [
"itertools 0.10.1", "itertools 0.10.1",
"js_int", "js_int",

View file

@ -18,7 +18,8 @@ edition = "2018"
rocket = { version = "0.5.0-rc.1", features = ["tls"] } # Used to handle requests rocket = { version = "0.5.0-rc.1", features = ["tls"] } # Used to handle requests
# Used for matrix spec type definitions and helpers # Used for matrix spec type definitions and helpers
ruma = { git = "https://github.com/ruma/ruma", rev = "c29c2b16ec114fa655e2b70bdd53c82e35859005", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } #ruma = { git = "https://github.com/ruma/ruma", rev = "c29c2b16ec114fa655e2b70bdd53c82e35859005", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] }
ruma = { git = "https://github.com/timokoesters/ruma", rev = "a3fd405d6b331c7bc4c6f366bc1b6ec303b3a386", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] }
#ruma = { path = "../ruma/crates/ruma", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } #ruma = { path = "../ruma/crates/ruma", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] }
# Used for long polling and federation sender, should be the same as rocket::tokio # Used for long polling and federation sender, should be the same as rocket::tokio
@ -119,5 +120,5 @@ maintainer-scripts = "debian/"
systemd-units = { unit-name = "matrix-conduit" } systemd-units = { unit-name = "matrix-conduit" }
# For flamegraphs: # For flamegraphs:
[profile.release] #[profile.release]
debug = true #debug = true

View file

@ -29,7 +29,7 @@ use ruma::{
uint, EventId, RoomId, RoomVersionId, ServerName, UserId, uint, EventId, RoomId, RoomVersionId, ServerName, UserId,
}; };
use std::{ use std::{
collections::{btree_map::Entry, BTreeMap, HashSet}, collections::{hash_map::Entry, BTreeMap, HashMap, HashSet},
convert::{TryFrom, TryInto}, convert::{TryFrom, TryInto},
sync::{Arc, RwLock}, sync::{Arc, RwLock},
time::{Duration, Instant}, time::{Duration, Instant},
@ -607,7 +607,7 @@ async fn join_room_by_id_helper(
let pdu = PduEvent::from_id_val(&event_id, join_event.clone()) let pdu = PduEvent::from_id_val(&event_id, join_event.clone())
.map_err(|_| Error::BadServerResponse("Invalid join event PDU."))?; .map_err(|_| Error::BadServerResponse("Invalid join event PDU."))?;
let mut state = BTreeMap::new(); let mut state = HashMap::new();
let pub_key_map = RwLock::new(BTreeMap::new()); let pub_key_map = RwLock::new(BTreeMap::new());
for result in futures::future::join_all( for result in futures::future::join_all(

View file

@ -7,7 +7,7 @@ use ruma::{
DeviceId, RoomId, UserId, DeviceId, RoomId, UserId,
}; };
use std::{ use std::{
collections::{btree_map::Entry, hash_map, BTreeMap, HashMap, HashSet}, collections::{hash_map::Entry, BTreeMap, HashMap, HashSet},
convert::{TryFrom, TryInto}, convert::{TryFrom, TryInto},
sync::Arc, sync::Arc,
time::Duration, time::Duration,
@ -622,10 +622,10 @@ async fn sync_helper(
.presence_since(&room_id, since, &db.rooms, &db.globals)? .presence_since(&room_id, since, &db.rooms, &db.globals)?
{ {
match presence_updates.entry(user_id) { match presence_updates.entry(user_id) {
hash_map::Entry::Vacant(v) => { Entry::Vacant(v) => {
v.insert(presence); v.insert(presence);
} }
hash_map::Entry::Occupied(mut o) => { Entry::Occupied(mut o) => {
let p = o.get_mut(); let p = o.get_mut();
// Update existing presence event with more info // Update existing presence event with more info

View file

@ -33,7 +33,7 @@ use std::{
io::Write, io::Write,
ops::Deref, ops::Deref,
path::Path, path::Path,
sync::{Arc, RwLock}, sync::{Arc, Mutex, RwLock},
}; };
use tokio::sync::{OwnedRwLockReadGuard, RwLock as TokioRwLock, Semaphore}; use tokio::sync::{OwnedRwLockReadGuard, RwLock as TokioRwLock, Semaphore};
@ -292,7 +292,8 @@ impl Database {
eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?, eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?,
prevevent_parent: builder.open_tree("prevevent_parent")?, prevevent_parent: builder.open_tree("prevevent_parent")?,
pdu_cache: RwLock::new(LruCache::new(10_000)), pdu_cache: Mutex::new(LruCache::new(100_000)),
auth_chain_cache: Mutex::new(LruCache::new(100_000)),
}, },
account_data: account_data::AccountData { account_data: account_data::AccountData {
roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?,

View file

@ -5,14 +5,14 @@ use std::{future::Future, pin::Pin, sync::Arc};
use super::{DatabaseEngine, Tree}; use super::{DatabaseEngine, Tree};
use std::{collections::BTreeMap, sync::RwLock}; use std::{collections::HashMap, sync::RwLock};
pub struct Engine(rocksdb::DBWithThreadMode<rocksdb::MultiThreaded>); pub struct Engine(rocksdb::DBWithThreadMode<rocksdb::MultiThreaded>);
pub struct RocksDbEngineTree<'a> { pub struct RocksDbEngineTree<'a> {
db: Arc<Engine>, db: Arc<Engine>,
name: &'a str, name: &'a str,
watchers: RwLock<BTreeMap<Vec<u8>, Vec<tokio::sync::oneshot::Sender<()>>>>, watchers: RwLock<HashMap<Vec<u8>, Vec<tokio::sync::oneshot::Sender<()>>>>,
} }
impl DatabaseEngine for Engine { impl DatabaseEngine for Engine {
@ -58,7 +58,7 @@ impl DatabaseEngine for Engine {
Ok(Arc::new(RocksDbEngineTree { Ok(Arc::new(RocksDbEngineTree {
name, name,
db: Arc::clone(self), db: Arc::clone(self),
watchers: RwLock::new(BTreeMap::new()), watchers: RwLock::new(HashMap::new()),
})) }))
} }
} }

View file

@ -7,7 +7,7 @@ use log::debug;
use parking_lot::{Mutex, MutexGuard, RwLock}; use parking_lot::{Mutex, MutexGuard, RwLock};
use rusqlite::{params, Connection, DatabaseName::Main, OptionalExtension}; use rusqlite::{params, Connection, DatabaseName::Main, OptionalExtension};
use std::{ use std::{
collections::BTreeMap, collections::HashMap,
future::Future, future::Future,
ops::Deref, ops::Deref,
path::{Path, PathBuf}, path::{Path, PathBuf},
@ -206,7 +206,7 @@ impl DatabaseEngine for Engine {
Ok(Arc::new(SqliteTable { Ok(Arc::new(SqliteTable {
engine: Arc::clone(self), engine: Arc::clone(self),
name: name.to_owned(), name: name.to_owned(),
watchers: RwLock::new(BTreeMap::new()), watchers: RwLock::new(HashMap::new()),
})) }))
} }
@ -266,7 +266,7 @@ impl Engine {
pub struct SqliteTable { pub struct SqliteTable {
engine: Arc<Engine>, engine: Arc<Engine>,
name: String, name: String,
watchers: RwLock<BTreeMap<Vec<u8>, Vec<Sender<()>>>>, watchers: RwLock<HashMap<Vec<u8>, Vec<Sender<()>>>>,
} }
type TupleOfBytes = (Vec<u8>, Vec<u8>); type TupleOfBytes = (Vec<u8>, Vec<u8>);

View file

@ -41,12 +41,12 @@ pub struct Globals {
dns_resolver: TokioAsyncResolver, dns_resolver: TokioAsyncResolver,
jwt_decoding_key: Option<jsonwebtoken::DecodingKey<'static>>, jwt_decoding_key: Option<jsonwebtoken::DecodingKey<'static>>,
pub(super) server_signingkeys: Arc<dyn Tree>, pub(super) server_signingkeys: Arc<dyn Tree>,
pub bad_event_ratelimiter: Arc<RwLock<BTreeMap<EventId, RateLimitState>>>, pub bad_event_ratelimiter: Arc<RwLock<HashMap<EventId, RateLimitState>>>,
pub bad_signature_ratelimiter: Arc<RwLock<BTreeMap<Vec<String>, RateLimitState>>>, pub bad_signature_ratelimiter: Arc<RwLock<HashMap<Vec<String>, RateLimitState>>>,
pub servername_ratelimiter: Arc<RwLock<BTreeMap<Box<ServerName>, Arc<Semaphore>>>>, pub servername_ratelimiter: Arc<RwLock<HashMap<Box<ServerName>, Arc<Semaphore>>>>,
pub sync_receivers: RwLock<BTreeMap<(UserId, Box<DeviceId>), SyncHandle>>, pub sync_receivers: RwLock<HashMap<(UserId, Box<DeviceId>), SyncHandle>>,
pub roomid_mutex: RwLock<BTreeMap<RoomId, Arc<Mutex<()>>>>, pub roomid_mutex: RwLock<HashMap<RoomId, Arc<Mutex<()>>>>,
pub roomid_mutex_federation: RwLock<BTreeMap<RoomId, Arc<Mutex<()>>>>, // this lock will be held longer pub roomid_mutex_federation: RwLock<HashMap<RoomId, Arc<Mutex<()>>>>, // this lock will be held longer
pub rotate: RotationHandler, pub rotate: RotationHandler,
} }
@ -196,12 +196,12 @@ impl Globals {
tls_name_override, tls_name_override,
server_signingkeys, server_signingkeys,
jwt_decoding_key, jwt_decoding_key,
bad_event_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), bad_event_ratelimiter: Arc::new(RwLock::new(HashMap::new())),
bad_signature_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), bad_signature_ratelimiter: Arc::new(RwLock::new(HashMap::new())),
servername_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), servername_ratelimiter: Arc::new(RwLock::new(HashMap::new())),
roomid_mutex: RwLock::new(BTreeMap::new()), roomid_mutex: RwLock::new(HashMap::new()),
roomid_mutex_federation: RwLock::new(BTreeMap::new()), roomid_mutex_federation: RwLock::new(HashMap::new()),
sync_receivers: RwLock::new(BTreeMap::new()), sync_receivers: RwLock::new(HashMap::new()),
rotate: RotationHandler::new(), rotate: RotationHandler::new(),
}; };

View file

@ -25,7 +25,7 @@ use std::{
collections::{BTreeMap, BTreeSet, HashMap, HashSet}, collections::{BTreeMap, BTreeSet, HashMap, HashSet},
convert::{TryFrom, TryInto}, convert::{TryFrom, TryInto},
mem, mem,
sync::{Arc, RwLock}, sync::{Arc, Mutex},
}; };
use super::{abstraction::Tree, admin::AdminCommand, pusher}; use super::{abstraction::Tree, admin::AdminCommand, pusher};
@ -84,7 +84,8 @@ pub struct Rooms {
/// RoomId + EventId -> Parent PDU EventId. /// RoomId + EventId -> Parent PDU EventId.
pub(super) prevevent_parent: Arc<dyn Tree>, pub(super) prevevent_parent: Arc<dyn Tree>,
pub(super) pdu_cache: RwLock<LruCache<EventId, Arc<PduEvent>>>, pub(super) pdu_cache: Mutex<LruCache<EventId, Arc<PduEvent>>>,
pub(super) auth_chain_cache: Mutex<LruCache<EventId, HashSet<EventId>>>,
} }
impl Rooms { impl Rooms {
@ -109,7 +110,7 @@ impl Rooms {
pub fn state_full( pub fn state_full(
&self, &self,
shortstatehash: u64, shortstatehash: u64,
) -> Result<BTreeMap<(EventType, String), Arc<PduEvent>>> { ) -> Result<HashMap<(EventType, String), Arc<PduEvent>>> {
let state = self let state = self
.stateid_shorteventid .stateid_shorteventid
.scan_prefix(shortstatehash.to_be_bytes().to_vec()) .scan_prefix(shortstatehash.to_be_bytes().to_vec())
@ -282,7 +283,7 @@ impl Rooms {
pub fn force_state( pub fn force_state(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
state: BTreeMap<(EventType, String), EventId>, state: HashMap<(EventType, String), EventId>,
db: &Database, db: &Database,
) -> Result<()> { ) -> Result<()> {
let state_hash = self.calculate_hash( let state_hash = self.calculate_hash(
@ -402,11 +403,11 @@ impl Rooms {
pub fn room_state_full( pub fn room_state_full(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
) -> Result<BTreeMap<(EventType, String), Arc<PduEvent>>> { ) -> Result<HashMap<(EventType, String), Arc<PduEvent>>> {
if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? {
self.state_full(current_shortstatehash) self.state_full(current_shortstatehash)
} else { } else {
Ok(BTreeMap::new()) Ok(HashMap::new())
} }
} }
@ -542,7 +543,7 @@ impl Rooms {
/// ///
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline. /// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
pub fn get_pdu(&self, event_id: &EventId) -> Result<Option<Arc<PduEvent>>> { pub fn get_pdu(&self, event_id: &EventId) -> Result<Option<Arc<PduEvent>>> {
if let Some(p) = self.pdu_cache.write().unwrap().get_mut(&event_id) { if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(&event_id) {
return Ok(Some(Arc::clone(p))); return Ok(Some(Arc::clone(p)));
} }
@ -568,7 +569,7 @@ impl Rooms {
.transpose()? .transpose()?
{ {
self.pdu_cache self.pdu_cache
.write() .lock()
.unwrap() .unwrap()
.insert(event_id.clone(), Arc::clone(&pdu)); .insert(event_id.clone(), Arc::clone(&pdu));
Ok(Some(pdu)) Ok(Some(pdu))
@ -2520,4 +2521,10 @@ impl Rooms {
Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some()) Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some())
} }
pub fn auth_chain_cache(
&self,
) -> std::sync::MutexGuard<'_, LruCache<EventId, HashSet<EventId>>> {
self.auth_chain_cache.lock().unwrap()
}
} }

View file

@ -6,6 +6,7 @@ use crate::{
use get_profile_information::v1::ProfileField; use get_profile_information::v1::ProfileField;
use http::header::{HeaderValue, AUTHORIZATION, HOST}; use http::header::{HeaderValue, AUTHORIZATION, HOST};
use log::{debug, error, info, trace, warn}; use log::{debug, error, info, trace, warn};
use lru_cache::LruCache;
use regex::Regex; use regex::Regex;
use rocket::response::content::Json; use rocket::response::content::Json;
use ruma::{ use ruma::{
@ -52,7 +53,7 @@ use ruma::{
ServerSigningKeyId, UserId, ServerSigningKeyId, UserId,
}; };
use std::{ use std::{
collections::{btree_map::Entry, BTreeMap, BTreeSet, HashSet}, collections::{hash_map::Entry, BTreeMap, HashMap, HashSet},
convert::{TryFrom, TryInto}, convert::{TryFrom, TryInto},
fmt::Debug, fmt::Debug,
future::Future, future::Future,
@ -931,7 +932,7 @@ pub fn handle_incoming_pdu<'a>(
); );
// Build map of auth events // Build map of auth events
let mut auth_events = BTreeMap::new(); let mut auth_events = HashMap::new();
for id in &incoming_pdu.auth_events { for id in &incoming_pdu.auth_events {
let auth_event = db let auth_event = db
.rooms .rooms
@ -1097,7 +1098,7 @@ pub fn handle_incoming_pdu<'a>(
Err(_) => return Err("Failed to fetch state events.".to_owned()), Err(_) => return Err("Failed to fetch state events.".to_owned()),
}; };
let mut state = BTreeMap::new(); let mut state = HashMap::new();
for pdu in state_vec { for pdu in state_vec {
match state.entry((pdu.kind.clone(), pdu.state_key.clone().ok_or_else(|| "Found non-state pdu in state events.".to_owned())?)) { match state.entry((pdu.kind.clone(), pdu.state_key.clone().ok_or_else(|| "Found non-state pdu in state events.".to_owned())?)) {
Entry::Vacant(v) => { Entry::Vacant(v) => {
@ -1173,7 +1174,8 @@ pub fn handle_incoming_pdu<'a>(
} }
} }
let mut fork_states = BTreeSet::new(); let mut extremity_statehashes = Vec::new();
for id in &extremities { for id in &extremities {
match db match db
.rooms .rooms
@ -1181,8 +1183,8 @@ pub fn handle_incoming_pdu<'a>(
.map_err(|_| "Failed to ask db for pdu.".to_owned())? .map_err(|_| "Failed to ask db for pdu.".to_owned())?
{ {
Some(leaf_pdu) => { Some(leaf_pdu) => {
let pdu_shortstatehash = db extremity_statehashes.push((
.rooms db.rooms
.pdu_shortstatehash(&leaf_pdu.event_id) .pdu_shortstatehash(&leaf_pdu.event_id)
.map_err(|_| "Failed to ask db for pdu state hash.".to_owned())? .map_err(|_| "Failed to ask db for pdu state hash.".to_owned())?
.ok_or_else(|| { .ok_or_else(|| {
@ -1191,20 +1193,9 @@ pub fn handle_incoming_pdu<'a>(
leaf_pdu leaf_pdu
); );
"Found pdu with no statehash in db.".to_owned() "Found pdu with no statehash in db.".to_owned()
})?; })?,
Some(leaf_pdu),
let mut leaf_state = db ));
.rooms
.state_full(pdu_shortstatehash)
.map_err(|_| "Failed to ask db for room state.".to_owned())?;
if let Some(state_key) = &leaf_pdu.state_key {
// Now it's the state after
let key = (leaf_pdu.kind.clone(), state_key.clone());
leaf_state.insert(key, leaf_pdu);
}
fork_states.insert(leaf_state);
} }
_ => { _ => {
error!("Missing state snapshot for {:?}", id); error!("Missing state snapshot for {:?}", id);
@ -1218,12 +1209,36 @@ pub fn handle_incoming_pdu<'a>(
// don't just trust a set of state we got from a remote). // don't just trust a set of state we got from a remote).
// We do this by adding the current state to the list of fork states // We do this by adding the current state to the list of fork states
let current_statehash = db
.rooms
.current_shortstatehash(&room_id)
.map_err(|_| "Failed to load current state hash.".to_owned())?
.expect("every room has state");
let current_state = db let current_state = db
.rooms .rooms
.room_state_full(&room_id) .state_full(current_statehash)
.map_err(|_| "Failed to load room state.".to_owned())?; .map_err(|_| "Failed to load room state.")?;
fork_states.insert(current_state.clone()); extremity_statehashes.push((current_statehash.clone(), None));
let mut fork_states = Vec::new();
for (statehash, leaf_pdu) in extremity_statehashes {
let mut leaf_state = db
.rooms
.state_full(statehash)
.map_err(|_| "Failed to ask db for room state.".to_owned())?;
if let Some(leaf_pdu) = leaf_pdu {
if let Some(state_key) = &leaf_pdu.state_key {
// Now it's the state after
let key = (leaf_pdu.kind.clone(), state_key.clone());
leaf_state.insert(key, leaf_pdu);
}
}
fork_states.push(leaf_state);
}
// We also add state after incoming event to the fork states // We also add state after incoming event to the fork states
extremities.insert(incoming_pdu.event_id.clone()); extremities.insert(incoming_pdu.event_id.clone());
@ -1234,9 +1249,7 @@ pub fn handle_incoming_pdu<'a>(
incoming_pdu.clone(), incoming_pdu.clone(),
); );
} }
fork_states.insert(state_after.clone()); fork_states.push(state_after.clone());
let fork_states = fork_states.into_iter().collect::<Vec<_>>();
let mut update_state = false; let mut update_state = false;
// 14. Use state resolution to find new room state // 14. Use state resolution to find new room state
@ -1254,17 +1267,31 @@ pub fn handle_incoming_pdu<'a>(
// We do need to force an update to this room's state // We do need to force an update to this room's state
update_state = true; update_state = true;
match state_res::StateResolution::resolve( let fork_states = &fork_states
&room_id,
room_version_id,
&fork_states
.into_iter() .into_iter()
.map(|map| { .map(|map| {
map.into_iter() map.into_iter()
.map(|(k, v)| (k, v.event_id.clone())) .map(|(k, v)| (k, v.event_id.clone()))
.collect::<StateMap<_>>() .collect::<StateMap<_>>()
}) })
.collect::<Vec<_>>(), .collect::<Vec<_>>();
let auth_chain_t = Instant::now();
let mut auth_chain_sets = Vec::new();
for state in fork_states {
auth_chain_sets.push(
get_auth_chain(state.iter().map(|(_, id)| id.clone()).collect(), db)
.map_err(|_| "Failed to load auth chain.".to_owned())?,
);
}
dbg!(auth_chain_t.elapsed());
let state_res_t = Instant::now();
let state = match state_res::StateResolution::resolve(
&room_id,
room_version_id,
fork_states,
auth_chain_sets,
|id| { |id| {
let res = db.rooms.get_pdu(id); let res = db.rooms.get_pdu(id);
if let Err(e) = &res { if let Err(e) = &res {
@ -1277,7 +1304,9 @@ pub fn handle_incoming_pdu<'a>(
Err(_) => { Err(_) => {
return Err("State resolution failed, either an event could not be found or deserialization".into()); return Err("State resolution failed, either an event could not be found or deserialization".into());
} }
} };
dbg!(state_res_t.elapsed());
state
}; };
// 13. Check if the event passes auth based on the "current state" of the room, if not "soft fail" it // 13. Check if the event passes auth based on the "current state" of the room, if not "soft fail" it
@ -1696,6 +1725,42 @@ async fn append_incoming_pdu(
Ok(pdu_id) Ok(pdu_id)
} }
fn get_auth_chain(starting_events: Vec<EventId>, db: &Database) -> Result<HashSet<EventId>> {
let mut auth_chain_cache = db.rooms.auth_chain_cache();
let mut auth_chain = HashSet::new();
for event in starting_events {
auth_chain.extend(get_auth_chain_recursive(&event, &mut auth_chain_cache, db)?);
}
Ok(auth_chain)
}
fn get_auth_chain_recursive(
event_id: &EventId,
auth_chain_cache: &mut std::sync::MutexGuard<'_, LruCache<EventId, HashSet<EventId>>>,
db: &Database,
) -> Result<HashSet<EventId>> {
if let Some(cached) = auth_chain_cache.get_mut(event_id) {
return Ok(cached.clone());
}
let mut auth_chain = HashSet::new();
if let Some(pdu) = db.rooms.get_pdu(&event_id)? {
for auth_event in &pdu.auth_events {
auth_chain.extend(get_auth_chain_recursive(&auth_event, auth_chain_cache, db)?);
}
} else {
warn!("Could not find pdu mentioned in auth events.");
}
auth_chain_cache.insert(event_id.clone(), auth_chain.clone());
Ok(auth_chain)
}
#[cfg_attr( #[cfg_attr(
feature = "conduit_bin", feature = "conduit_bin",
get("/_matrix/federation/v1/event/<_>", data = "<body>") get("/_matrix/federation/v1/event/<_>", data = "<body>")
@ -1783,35 +1848,20 @@ pub fn get_event_authorization_route(
return Err(Error::bad_config("Federation is disabled.")); return Err(Error::bad_config("Federation is disabled."));
} }
let mut auth_chain = Vec::new(); let auth_chain_ids = get_auth_chain(vec![body.event_id.clone()], &db)?;
let mut auth_chain_ids = BTreeSet::<EventId>::new();
let mut todo = BTreeSet::new();
todo.insert(body.event_id.clone());
while let Some(event_id) = todo.iter().next().cloned() { Ok(get_event_authorization::v1::Response {
if let Some(pdu) = db.rooms.get_pdu(&event_id)? { auth_chain: auth_chain_ids
todo.extend(
pdu.auth_events
.clone()
.into_iter() .into_iter()
.collect::<BTreeSet<_>>() .map(|id| {
.difference(&auth_chain_ids) Ok::<_, Error>(PduEvent::convert_to_outgoing_federation_event(
.cloned(), db.rooms.get_pdu_json(&id)?.unwrap(),
); ))
auth_chain_ids.extend(pdu.auth_events.clone().into_iter()); })
.filter_map(|r| r.ok())
let pdu_json = PduEvent::convert_to_outgoing_federation_event( .collect(),
db.rooms.get_pdu_json(&event_id)?.unwrap(),
);
auth_chain.push(pdu_json);
} else {
warn!("Could not find pdu mentioned in auth events.");
} }
.into())
todo.remove(&event_id);
}
Ok(get_event_authorization::v1::Response { auth_chain }.into())
} }
#[cfg_attr( #[cfg_attr(
@ -1846,35 +1896,21 @@ pub fn get_room_state_route(
}) })
.collect(); .collect();
let mut auth_chain = Vec::new(); let auth_chain_ids = get_auth_chain(vec![body.event_id.clone()], &db)?;
let mut auth_chain_ids = BTreeSet::<EventId>::new();
let mut todo = BTreeSet::new();
todo.insert(body.event_id.clone());
while let Some(event_id) = todo.iter().next().cloned() { Ok(get_room_state::v1::Response {
if let Some(pdu) = db.rooms.get_pdu(&event_id)? { auth_chain: auth_chain_ids
todo.extend(
pdu.auth_events
.clone()
.into_iter() .into_iter()
.collect::<BTreeSet<_>>() .map(|id| {
.difference(&auth_chain_ids) Ok::<_, Error>(PduEvent::convert_to_outgoing_federation_event(
.cloned(), db.rooms.get_pdu_json(&id)?.unwrap(),
); ))
auth_chain_ids.extend(pdu.auth_events.clone().into_iter()); })
.filter_map(|r| r.ok())
let pdu_json = PduEvent::convert_to_outgoing_federation_event( .collect(),
db.rooms.get_pdu_json(&event_id)?.unwrap(), pdus,
);
auth_chain.push(pdu_json);
} else {
warn!("Could not find pdu mentioned in auth events.");
} }
.into())
todo.remove(&event_id);
}
Ok(get_room_state::v1::Response { auth_chain, pdus }.into())
} }
#[cfg_attr( #[cfg_attr(
@ -1904,27 +1940,7 @@ pub fn get_room_state_ids_route(
.into_iter() .into_iter()
.collect(); .collect();
let mut auth_chain_ids = BTreeSet::<EventId>::new(); let auth_chain_ids = get_auth_chain(vec![body.event_id.clone()], &db)?;
let mut todo = BTreeSet::new();
todo.insert(body.event_id.clone());
while let Some(event_id) = todo.iter().next().cloned() {
if let Some(pdu) = db.rooms.get_pdu(&event_id)? {
todo.extend(
pdu.auth_events
.clone()
.into_iter()
.collect::<BTreeSet<_>>()
.difference(&auth_chain_ids)
.cloned(),
);
auth_chain_ids.extend(pdu.auth_events.clone().into_iter());
} else {
warn!("Could not find pdu mentioned in auth events.");
}
todo.remove(&event_id);
}
Ok(get_room_state_ids::v1::Response { Ok(get_room_state_ids::v1::Response {
auth_chain_ids: auth_chain_ids.into_iter().collect(), auth_chain_ids: auth_chain_ids.into_iter().collect(),
@ -2182,8 +2198,8 @@ pub async fn create_join_event_route(
let state_ids = db.rooms.state_full_ids(shortstatehash)?; let state_ids = db.rooms.state_full_ids(shortstatehash)?;
let mut auth_chain_ids = BTreeSet::<EventId>::new(); let mut auth_chain_ids = HashSet::<EventId>::new();
let mut todo = state_ids.iter().cloned().collect::<BTreeSet<_>>(); let mut todo = state_ids.iter().cloned().collect::<HashSet<_>>();
while let Some(event_id) = todo.iter().next().cloned() { while let Some(event_id) = todo.iter().next().cloned() {
if let Some(pdu) = db.rooms.get_pdu(&event_id)? { if let Some(pdu) = db.rooms.get_pdu(&event_id)? {
@ -2191,7 +2207,7 @@ pub async fn create_join_event_route(
pdu.auth_events pdu.auth_events
.clone() .clone()
.into_iter() .into_iter()
.collect::<BTreeSet<_>>() .collect::<HashSet<_>>()
.difference(&auth_chain_ids) .difference(&auth_chain_ids)
.cloned(), .cloned(),
); );