pub mod abstraction; pub mod account_data; pub mod admin; pub mod appservice; pub mod globals; pub mod key_backups; pub mod media; pub mod pusher; pub mod rooms; pub mod sending; pub mod transaction_ids; pub mod uiaa; pub mod users; use crate::{utils, Error, Result}; use abstraction::DatabaseEngine; use directories::ProjectDirs; use log::error; use rocket::futures::{channel::mpsc, stream::FuturesUnordered, StreamExt}; use ruma::{DeviceId, ServerName, UserId}; use serde::Deserialize; use std::{ collections::HashMap, fs::{self, remove_dir_all}, io::Write, sync::{Arc, RwLock}, }; use tokio::sync::Semaphore; #[derive(Clone, Debug, Deserialize)] pub struct Config { server_name: Box, database_path: String, #[serde(default = "default_cache_capacity")] cache_capacity: u32, #[serde(default = "default_max_request_size")] max_request_size: u32, #[serde(default = "default_max_concurrent_requests")] max_concurrent_requests: u16, #[serde(default = "true_fn")] allow_registration: bool, #[serde(default = "true_fn")] allow_encryption: bool, #[serde(default = "false_fn")] allow_federation: bool, #[serde(default = "false_fn")] pub allow_jaeger: bool, #[serde(default)] proxy: ProxyConfig, jwt_secret: Option, #[serde(default = "Vec::new")] trusted_servers: Vec>, #[serde(default = "default_log")] pub log: String, } fn false_fn() -> bool { false } fn true_fn() -> bool { true } fn default_cache_capacity() -> u32 { 1024 * 1024 * 1024 } fn default_max_request_size() -> u32 { 20 * 1024 * 1024 // Default to 20 MB } fn default_max_concurrent_requests() -> u16 { 100 } fn default_log() -> String { "info,state_res=warn,rocket=off,_=off,sled=off".to_owned() } #[cfg(feature = "sled")] pub type Engine = abstraction::SledEngine; #[cfg(feature = "rocksdb")] pub type Engine = abstraction::RocksDbEngine; #[derive(Clone, Debug, Deserialize)] #[serde(rename_all = "snake_case")] pub enum ProxyConfig { None, Global { #[serde(deserialize_with = "crate::utils::deserialize_from_str")] url: reqwest::Url, }, ByDomain(Vec), } impl ProxyConfig { pub fn to_proxy(&self) -> Result> { Ok(match self.clone() { ProxyConfig::None => None, ProxyConfig::Global { url } => Some(reqwest::Proxy::all(url)?), ProxyConfig::ByDomain(proxies) => Some(reqwest::Proxy::custom(move |url| { proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() // first matching proxy })), }) } } impl Default for ProxyConfig { fn default() -> Self { ProxyConfig::None } } #[derive(Clone, Debug, Deserialize)] pub struct PartialProxyConfig { #[serde(deserialize_with = "crate::utils::deserialize_from_str")] url: reqwest::Url, #[serde(default)] include: Vec, #[serde(default)] exclude: Vec, } impl PartialProxyConfig { pub fn for_url(&self, url: &reqwest::Url) -> Option<&reqwest::Url> { let domain = url.domain()?; let mut included_because = None; // most specific reason it was included let mut excluded_because = None; // most specific reason it was excluded if self.include.is_empty() { // treat empty include list as `*` included_because = Some(&WildCardedDomain::WildCard) } for wc_domain in &self.include { if wc_domain.matches(domain) { match included_because { Some(prev) if !wc_domain.more_specific_than(prev) => (), _ => included_because = Some(wc_domain), } } } for wc_domain in &self.exclude { if wc_domain.matches(domain) { match excluded_because { Some(prev) if !wc_domain.more_specific_than(prev) => (), _ => excluded_because = Some(wc_domain), } } } match (included_because, excluded_because) { (Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), // included for a more specific reason than excluded (Some(_), None) => Some(&self.url), _ => None, } } } /// A domain name, that optionally allows a * as its first subdomain. #[derive(Clone, Debug)] pub enum WildCardedDomain { WildCard, WildCarded(String), Exact(String), } impl WildCardedDomain { pub fn matches(&self, domain: &str) -> bool { match self { WildCardedDomain::WildCard => true, WildCardedDomain::WildCarded(d) => domain.ends_with(d), WildCardedDomain::Exact(d) => domain == d, } } pub fn more_specific_than(&self, other: &Self) -> bool { match (self, other) { (WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false, (_, WildCardedDomain::WildCard) => true, (WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a), (WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => { a != b && a.ends_with(b) } _ => false, } } } impl std::str::FromStr for WildCardedDomain { type Err = std::convert::Infallible; fn from_str(s: &str) -> std::result::Result { // maybe do some domain validation? Ok(if s.starts_with("*.") { WildCardedDomain::WildCarded(s[1..].to_owned()) } else if s == "*" { WildCardedDomain::WildCarded("".to_owned()) } else { WildCardedDomain::Exact(s.to_owned()) }) } } impl<'de> serde::de::Deserialize<'de> for WildCardedDomain { fn deserialize(deserializer: D) -> std::result::Result where D: serde::de::Deserializer<'de>, { crate::utils::deserialize_from_str(deserializer) } } #[derive(Clone)] pub struct Database { pub globals: globals::Globals, pub users: users::Users, pub uiaa: uiaa::Uiaa, pub rooms: rooms::Rooms, pub account_data: account_data::AccountData, pub media: media::Media, pub key_backups: key_backups::KeyBackups, pub transaction_ids: transaction_ids::TransactionIds, pub sending: sending::Sending, pub admin: admin::Admin, pub appservice: appservice::Appservice, pub pusher: pusher::PushData, } impl Database { /// Tries to remove the old database but ignores all errors. pub fn try_remove(server_name: &str) -> Result<()> { let mut path = ProjectDirs::from("xyz", "koesters", "conduit") .ok_or_else(|| Error::bad_config("The OS didn't return a valid home directory path."))? .data_dir() .to_path_buf(); path.push(server_name); let _ = remove_dir_all(path); Ok(()) } /// Load an existing database or create a new one. pub async fn load_or_create(config: Config) -> Result> { let builder = Engine::open(&config)?; if config.max_request_size < 1024 { eprintln!("ERROR: Max request size is less than 1KB. Please increase it."); } let (admin_sender, admin_receiver) = mpsc::unbounded(); let (sending_sender, sending_receiver) = mpsc::unbounded(); let db = Arc::new(Self { users: users::Users { userid_password: builder.open_tree("userid_password")?, userid_displayname: builder.open_tree("userid_displayname")?, userid_avatarurl: builder.open_tree("userid_avatarurl")?, userdeviceid_token: builder.open_tree("userdeviceid_token")?, userdeviceid_metadata: builder.open_tree("userdeviceid_metadata")?, userid_devicelistversion: builder.open_tree("userid_devicelistversion")?, token_userdeviceid: builder.open_tree("token_userdeviceid")?, onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys")?, userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate")?, keychangeid_userid: builder.open_tree("keychangeid_userid")?, keyid_key: builder.open_tree("keyid_key")?, userid_masterkeyid: builder.open_tree("userid_masterkeyid")?, userid_selfsigningkeyid: builder.open_tree("userid_selfsigningkeyid")?, userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?, todeviceid_events: builder.open_tree("todeviceid_events")?, }, uiaa: uiaa::Uiaa { userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?, userdevicesessionid_uiaarequest: builder .open_tree("userdevicesessionid_uiaarequest")?, }, rooms: rooms::Rooms { edus: rooms::RoomEdus { readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?, roomuserid_privateread: builder.open_tree("roomuserid_privateread")?, // "Private" read receipt roomuserid_lastprivatereadupdate: builder .open_tree("roomuserid_lastprivatereadupdate")?, typingid_userid: builder.open_tree("typingid_userid")?, roomid_lasttypingupdate: builder.open_tree("roomid_lasttypingupdate")?, presenceid_presence: builder.open_tree("presenceid_presence")?, userid_lastpresenceupdate: builder.open_tree("userid_lastpresenceupdate")?, }, pduid_pdu: builder.open_tree("pduid_pdu")?, eventid_pduid: builder.open_tree("eventid_pduid")?, roomid_pduleaves: builder.open_tree("roomid_pduleaves")?, alias_roomid: builder.open_tree("alias_roomid")?, aliasid_alias: builder.open_tree("aliasid_alias")?, publicroomids: builder.open_tree("publicroomids")?, tokenids: builder.open_tree("tokenids")?, roomserverids: builder.open_tree("roomserverids")?, serverroomids: builder.open_tree("serverroomids")?, userroomid_joined: builder.open_tree("userroomid_joined")?, roomuserid_joined: builder.open_tree("roomuserid_joined")?, roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids")?, userroomid_invitestate: builder.open_tree("userroomid_invitestate")?, roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?, userroomid_leftstate: builder.open_tree("userroomid_leftstate")?, roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?, userroomid_notificationcount: builder.open_tree("userroomid_notificationcount")?, userroomid_highlightcount: builder.open_tree("userroomid_highlightcount")?, statekey_shortstatekey: builder.open_tree("statekey_shortstatekey")?, stateid_shorteventid: builder.open_tree("stateid_shorteventid")?, eventid_shorteventid: builder.open_tree("eventid_shorteventid")?, shorteventid_eventid: builder.open_tree("shorteventid_eventid")?, shorteventid_shortstatehash: builder.open_tree("shorteventid_shortstatehash")?, roomid_shortstatehash: builder.open_tree("roomid_shortstatehash")?, statehash_shortstatehash: builder.open_tree("statehash_shortstatehash")?, eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?, prevevent_parent: builder.open_tree("prevevent_parent")?, }, account_data: account_data::AccountData { roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, }, media: media::Media { mediaid_file: builder.open_tree("mediaid_file")?, }, key_backups: key_backups::KeyBackups { backupid_algorithm: builder.open_tree("backupid_algorithm")?, backupid_etag: builder.open_tree("backupid_etag")?, backupkeyid_backup: builder.open_tree("backupkeyid_backup")?, }, transaction_ids: transaction_ids::TransactionIds { userdevicetxnid_response: builder.open_tree("userdevicetxnid_response")?, }, sending: sending::Sending { servername_educount: builder.open_tree("servername_educount")?, servernamepduids: builder.open_tree("servernamepduids")?, servercurrentevents: builder.open_tree("servercurrentevents")?, maximum_requests: Arc::new(Semaphore::new(config.max_concurrent_requests as usize)), sender: sending_sender, }, admin: admin::Admin { sender: admin_sender, }, appservice: appservice::Appservice { cached_registrations: Arc::new(RwLock::new(HashMap::new())), id_appserviceregistrations: builder.open_tree("id_appserviceregistrations")?, }, pusher: pusher::PushData { senderkey_pusher: builder.open_tree("senderkey_pusher")?, }, globals: globals::Globals::load( builder.open_tree("global")?, builder.open_tree("server_signingkeys")?, config, )?, }); // MIGRATIONS // TODO: database versions of new dbs should probably not be 0 if db.globals.database_version()? < 1 { for (roomserverid, _) in db.rooms.roomserverids.iter() { let mut parts = roomserverid.split(|&b| b == 0xff); let room_id = parts.next().expect("split always returns one element"); let servername = match parts.next() { Some(s) => s, None => { error!("Migration: Invalid roomserverid in db."); continue; } }; let mut serverroomid = servername.to_vec(); serverroomid.push(0xff); serverroomid.extend_from_slice(room_id); db.rooms.serverroomids.insert(&serverroomid, &[])?; } db.globals.bump_database_version(1)?; println!("Migration: 0 -> 1 finished"); } if db.globals.database_version()? < 2 { // We accidentally inserted hashed versions of "" into the db instead of just "" for (userid, password) in db.users.userid_password.iter() { let password = utils::string_from_bytes(&password); let empty_hashed_password = password.map_or(false, |password| { argon2::verify_encoded(&password, b"").unwrap_or(false) }); if empty_hashed_password { db.users.userid_password.insert(&userid, b"")?; } } db.globals.bump_database_version(2)?; println!("Migration: 1 -> 2 finished"); } if db.globals.database_version()? < 3 { // Move media to filesystem for (key, content) in db.media.mediaid_file.iter() { if content.len() == 0 { continue; } let path = db.globals.get_media_file(&key); let mut file = fs::File::create(path)?; file.write_all(&content)?; db.media.mediaid_file.insert(&key, &[])?; } db.globals.bump_database_version(3)?; println!("Migration: 2 -> 3 finished"); } if db.globals.database_version()? < 4 { // Add federated users to db as deactivated for our_user in db.users.iter() { let our_user = our_user?; if db.users.is_deactivated(&our_user)? { continue; } for room in db.rooms.rooms_joined(&our_user) { for user in db.rooms.room_members(&room?) { let user = user?; if user.server_name() != db.globals.server_name() { println!("Migration: Creating user {}", user); db.users.create(&user, None)?; } } } } db.globals.bump_database_version(4)?; println!("Migration: 3 -> 4 finished"); } // This data is probably outdated db.rooms.edus.presenceid_presence.clear()?; db.admin.start_handler(Arc::clone(&db), admin_receiver); db.sending.start_handler(Arc::clone(&db), sending_receiver); Ok(db) } pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) { let userid_bytes = user_id.as_bytes().to_vec(); let mut userid_prefix = userid_bytes.clone(); userid_prefix.push(0xff); let mut userdeviceid_prefix = userid_prefix.clone(); userdeviceid_prefix.extend_from_slice(device_id.as_bytes()); userdeviceid_prefix.push(0xff); let mut futures = FuturesUnordered::new(); // Return when *any* user changed his key // TODO: only send for user they share a room with futures.push( self.users .todeviceid_events .watch_prefix(&userdeviceid_prefix), ); futures.push(self.rooms.userroomid_joined.watch_prefix(&userid_prefix)); futures.push( self.rooms .userroomid_invitestate .watch_prefix(&userid_prefix), ); futures.push(self.rooms.userroomid_leftstate.watch_prefix(&userid_prefix)); // Events for rooms we are in for room_id in self.rooms.rooms_joined(user_id).filter_map(|r| r.ok()) { let roomid_bytes = room_id.as_bytes().to_vec(); let mut roomid_prefix = roomid_bytes.clone(); roomid_prefix.push(0xff); // PDUs futures.push(self.rooms.pduid_pdu.watch_prefix(&roomid_prefix)); // EDUs futures.push( self.rooms .edus .roomid_lasttypingupdate .watch_prefix(&roomid_bytes), ); futures.push( self.rooms .edus .readreceiptid_readreceipt .watch_prefix(&roomid_prefix), ); // Key changes futures.push(self.users.keychangeid_userid.watch_prefix(&roomid_prefix)); // Room account data let mut roomuser_prefix = roomid_prefix.clone(); roomuser_prefix.extend_from_slice(&userid_prefix); futures.push( self.account_data .roomuserdataid_accountdata .watch_prefix(&roomuser_prefix), ); } let mut globaluserdata_prefix = vec![0xff]; globaluserdata_prefix.extend_from_slice(&userid_prefix); futures.push( self.account_data .roomuserdataid_accountdata .watch_prefix(&globaluserdata_prefix), ); // More key changes (used when user is not joined to any rooms) futures.push(self.users.keychangeid_userid.watch_prefix(&userid_prefix)); // One time keys futures.push( self.users .userid_lastonetimekeyupdate .watch_prefix(&userid_bytes), ); // Wait until one of them finds something futures.next().await; } pub async fn flush(&self) -> Result<()> { // noop while we don't use sled 1.0 //self._db.flush_async().await?; Ok(()) } }