Use sled::Tree::prefix_search for deviceids

This commit is contained in:
timokoesters 2020-03-30 13:46:18 +02:00
parent b508b4d1e7
commit dba6c46667
No known key found for this signature in database
GPG key ID: 356E705610F626D5
5 changed files with 207 additions and 102 deletions

1
rustfmt.toml Normal file
View file

@ -0,0 +1 @@
merge_imports = true

View file

@ -1,134 +1,115 @@
use crate::utils; use crate::{utils, Database};
use directories::ProjectDirs;
use log::debug;
use ruma_events::collections::all::Event; use ruma_events::collections::all::Event;
use ruma_identifiers::{EventId, RoomId, UserId}; use ruma_identifiers::{EventId, RoomId, UserId};
use std::convert::TryInto; use std::convert::TryInto;
const USERID_PASSWORD: &str = "userid_password"; pub struct Data {
const USERID_DEVICEIDS: &str = "userid_deviceids"; hostname: String,
const DEVICEID_TOKEN: &str = "deviceid_token"; db: Database,
const TOKEN_USERID: &str = "token_userid"; }
pub struct Data(sled::Db);
impl Data { impl Data {
/// Load an existing database or create a new one. /// Load an existing database or create a new one.
pub fn load_or_create() -> Self { pub fn load_or_create(hostname: &str) -> Self {
Data( Self {
sled::open( hostname: hostname.to_owned(),
ProjectDirs::from("xyz", "koesters", "matrixserver") db: Database::load_or_create(hostname),
.unwrap()
.data_dir(),
)
.unwrap(),
)
} }
/// Set the hostname of the server. Warning: Hostname changes will likely break things.
pub fn set_hostname(&self, hostname: &str) {
self.0.insert("hostname", hostname).unwrap();
} }
/// Get the hostname of the server. /// Get the hostname of the server.
pub fn hostname(&self) -> String { pub fn hostname(&self) -> &str {
utils::bytes_to_string(&self.0.get("hostname").unwrap().unwrap()) &self.hostname
} }
/// Check if a user has an account by looking for an assigned password. /// Check if a user has an account by looking for an assigned password.
pub fn user_exists(&self, user_id: &UserId) -> bool { pub fn user_exists(&self, user_id: &UserId) -> bool {
self.0 self.db
.open_tree(USERID_PASSWORD) .userid_password
.unwrap()
.contains_key(user_id.to_string()) .contains_key(user_id.to_string())
.unwrap() .unwrap()
} }
/// Create a new user account by assigning them a password. /// Create a new user account by assigning them a password.
pub fn user_add(&self, user_id: &UserId, password: Option<String>) { pub fn user_add(&self, user_id: &UserId, password: Option<String>) {
self.0 self.db
.open_tree(USERID_PASSWORD) .userid_password
.unwrap()
.insert(user_id.to_string(), &*password.unwrap_or_default()) .insert(user_id.to_string(), &*password.unwrap_or_default())
.unwrap(); .unwrap();
} }
/// Find out which user an access token belongs to. /// Find out which user an access token belongs to.
pub fn user_from_token(&self, token: &str) -> Option<UserId> { pub fn user_from_token(&self, token: &str) -> Option<UserId> {
self.0 self.db
.open_tree(TOKEN_USERID) .token_userid
.unwrap()
.get(token) .get(token)
.unwrap() .unwrap()
.and_then(|bytes| (*utils::bytes_to_string(&bytes)).try_into().ok()) .and_then(|bytes| (*utils::string_from_bytes(&bytes)).try_into().ok())
} }
/// Checks if the given password is equal to the one in the database. /// Checks if the given password is equal to the one in the database.
pub fn password_get(&self, user_id: &UserId) -> Option<String> { pub fn password_get(&self, user_id: &UserId) -> Option<String> {
self.0 self.db
.open_tree(USERID_PASSWORD) .userid_password
.unwrap()
.get(user_id.to_string()) .get(user_id.to_string())
.unwrap() .unwrap()
.map(|bytes| utils::bytes_to_string(&bytes)) .map(|bytes| utils::string_from_bytes(&bytes))
} }
/// Add a new device to a user. /// Add a new device to a user.
pub fn device_add(&self, user_id: &UserId, device_id: &str) { pub fn device_add(&self, user_id: &UserId, device_id: &str) {
self.0 if self
.open_tree(USERID_DEVICEIDS) .db
.unwrap() .userid_deviceids
.insert(user_id.to_string(), device_id) .get_iter(&user_id.to_string().as_bytes())
.unwrap(); .filter_map(|item| item.ok())
.map(|(_key, value)| value)
.all(|device| device != device_id)
{
self.db
.userid_deviceids
.add(user_id.to_string().as_bytes(), device_id.into());
}
} }
/// Replace the access token of one device. /// Replace the access token of one device.
pub fn token_replace(&self, user_id: &UserId, device_id: &String, token: String) { pub fn token_replace(&self, user_id: &UserId, device_id: &String, token: String) {
// Make sure the device id belongs to the user // Make sure the device id belongs to the user
debug_assert!(self debug_assert!(self
.0 .db
.open_tree(USERID_DEVICEIDS) .userid_deviceids
.unwrap() .get_iter(&user_id.to_string().as_bytes())
.get(&user_id.to_string()) // Does the user exist? .filter_map(|item| item.ok())
.unwrap() .map(|(_key, value)| value)
.map(|bytes| utils::bytes_to_vec(&bytes)) .any(|device| device == device_id.as_bytes())); // Does the user have that device?
.filter(|devices| devices.contains(device_id)) // Does the user have that device?
.is_some());
// Remove old token // Remove old token
if let Some(old_token) = self if let Some(old_token) = self.db.deviceid_token.get(device_id).unwrap() {
.0 self.db.token_userid.remove(old_token).unwrap();
.open_tree(DEVICEID_TOKEN) // It will be removed from deviceid_token by the insert later
.unwrap()
.get(device_id)
.unwrap()
{
self.0
.open_tree(TOKEN_USERID)
.unwrap()
.remove(old_token)
.unwrap();
// It will be removed from DEVICEID_TOKEN by the insert later
} }
// Assign token to device_id // Assign token to device_id
self.0 self.db.deviceid_token.insert(device_id, &*token).unwrap();
.open_tree(DEVICEID_TOKEN)
.unwrap()
.insert(device_id, &*token)
.unwrap();
// Assign token to user // Assign token to user
self.0 self.db
.open_tree(TOKEN_USERID) .token_userid
.unwrap()
.insert(token, &*user_id.to_string()) .insert(token, &*user_id.to_string())
.unwrap(); .unwrap();
} }
/// Create a new room event. /// Create a new room event.
pub fn event_add(&self, event: &Event, room_id: &RoomId, event_id: &EventId) { pub fn event_add(&self, room_id: &RoomId, event_id: &EventId, event: &Event) {
debug!("{}", serde_json::to_string(event).unwrap()); let mut key = room_id.to_string().as_bytes().to_vec();
todo!(); key.extend_from_slice(event_id.to_string().as_bytes());
self.db
.roomid_eventid_event
.insert(&key, &*serde_json::to_string(event).unwrap())
.unwrap();
}
pub fn debug(&self) {
self.db.debug();
} }
} }

117
src/database.rs Normal file
View file

@ -0,0 +1,117 @@
use crate::utils;
use directories::ProjectDirs;
use sled::IVec;
pub struct MultiValue(sled::Tree);
impl MultiValue {
/// Get an iterator over all values.
pub fn iter_all(&self) -> sled::Iter {
self.0.iter()
}
/// Get an iterator over all values of this id.
pub fn get_iter(&self, id: &[u8]) -> sled::Iter {
// Data keys start with d
let mut key = vec![b'd'];
key.extend_from_slice(id.as_ref());
key.push(0xff); // Add delimiter so we don't find usernames starting with the same id
self.0.scan_prefix(key)
}
/// Add another value to the id.
pub fn add(&self, id: &[u8], value: IVec) {
// The new value will need a new index. We store the last used index in 'n' + id
let mut count_key: Vec<u8> = vec![b'n'];
count_key.extend_from_slice(id.as_ref());
// Increment the last index and use that
let index = self
.0
.update_and_fetch(&count_key, utils::increment)
.unwrap()
.unwrap();
// Data keys start with d
let mut key = vec![b'd'];
key.extend_from_slice(id.as_ref());
key.push(0xff);
key.extend_from_slice(&index);
self.0.insert(key, value).unwrap();
}
}
pub struct Database {
pub userid_password: sled::Tree,
pub userid_deviceids: MultiValue,
pub deviceid_token: sled::Tree,
pub token_userid: sled::Tree,
pub roomid_eventid_event: sled::Tree,
_db: sled::Db,
}
impl Database {
/// Load an existing database or create a new one.
pub fn load_or_create(hostname: &str) -> Self {
let mut path = ProjectDirs::from("xyz", "koesters", "matrixserver")
.unwrap()
.data_dir()
.to_path_buf();
path.push(hostname);
let db = sled::open(&path).unwrap();
Self {
userid_password: db.open_tree("userid_password").unwrap(),
userid_deviceids: MultiValue(db.open_tree("userid_deviceids").unwrap()),
deviceid_token: db.open_tree("deviceid_token").unwrap(),
token_userid: db.open_tree("token_userid").unwrap(),
roomid_eventid_event: db.open_tree("roomid_eventid_event").unwrap(),
_db: db,
}
}
pub fn debug(&self) {
println!("# UserId -> Password:");
for (k, v) in self.userid_password.iter().map(|r| r.unwrap()) {
println!(
"{} -> {}",
String::from_utf8_lossy(&k),
String::from_utf8_lossy(&v),
);
}
println!("# UserId -> DeviceIds:");
for (k, v) in self.userid_deviceids.iter_all().map(|r| r.unwrap()) {
println!(
"{} -> {}",
String::from_utf8_lossy(&k),
String::from_utf8_lossy(&v),
);
}
println!("# DeviceId -> Token:");
for (k, v) in self.deviceid_token.iter().map(|r| r.unwrap()) {
println!(
"{} -> {}",
String::from_utf8_lossy(&k),
String::from_utf8_lossy(&v),
);
}
println!("# Token -> UserId:");
for (k, v) in self.token_userid.iter().map(|r| r.unwrap()) {
println!(
"{} -> {}",
String::from_utf8_lossy(&k),
String::from_utf8_lossy(&v),
);
}
println!("# RoomId + EventId -> Event:");
for (k, v) in self.roomid_eventid_event.iter().map(|r| r.unwrap()) {
println!(
"{} -> {}",
String::from_utf8_lossy(&k),
String::from_utf8_lossy(&v),
);
}
}
}

View file

@ -1,9 +1,12 @@
#![feature(proc_macro_hygiene, decl_macro)] #![feature(proc_macro_hygiene, decl_macro)]
mod data; mod data;
mod database;
mod ruma_wrapper; mod ruma_wrapper;
mod utils; mod utils;
pub use data::Data; pub use data::Data;
pub use database::Database;
use log::debug; use log::debug;
use rocket::{get, post, put, routes, State}; use rocket::{get, post, put, routes, State};
use ruma_client_api::{ use ruma_client_api::{
@ -14,13 +17,14 @@ use ruma_client_api::{
}, },
unversioned::get_supported_versions, unversioned::get_supported_versions,
}; };
use ruma_events::collections::all::Event; use ruma_events::{collections::all::Event, room::message::MessageEvent};
use ruma_events::room::message::MessageEvent;
use ruma_identifiers::{EventId, UserId}; use ruma_identifiers::{EventId, UserId};
use ruma_wrapper::{MatrixResult, Ruma}; use ruma_wrapper::{MatrixResult, Ruma};
use serde_json::map::Map; use serde_json::map::Map;
use std::convert::TryFrom; use std::{
use std::{collections::HashMap, convert::TryInto}; collections::HashMap,
convert::{TryFrom, TryInto},
};
#[get("/_matrix/client/versions")] #[get("/_matrix/client/versions")]
fn get_supported_versions_route() -> MatrixResult<get_supported_versions::Response> { fn get_supported_versions_route() -> MatrixResult<get_supported_versions::Response> {
@ -90,7 +94,7 @@ fn register_route(
MatrixResult(Ok(register::Response { MatrixResult(Ok(register::Response {
access_token: token, access_token: token,
home_server: data.hostname(), home_server: data.hostname().to_owned(),
user_id, user_id,
device_id, device_id,
})) }))
@ -153,7 +157,7 @@ fn login_route(data: State<Data>, body: Ruma<login::Request>) -> MatrixResult<lo
.clone() .clone()
.unwrap_or("TODO:randomdeviceid".to_owned()); .unwrap_or("TODO:randomdeviceid".to_owned());
// Add device (TODO: We might not want to call it when using an existing device) // Add device
data.device_add(&user_id, &device_id); data.device_add(&user_id, &device_id);
// Generate a new token for the device // Generate a new token for the device
@ -163,7 +167,7 @@ fn login_route(data: State<Data>, body: Ruma<login::Request>) -> MatrixResult<lo
return MatrixResult(Ok(login::Response { return MatrixResult(Ok(login::Response {
user_id, user_id,
access_token: token, access_token: token,
home_server: Some(data.hostname()), home_server: Some(data.hostname().to_owned()),
device_id, device_id,
well_known: None, well_known: None,
})); }));
@ -217,6 +221,8 @@ fn create_message_event_route(
// Generate event id // Generate event id
let event_id = EventId::try_from("$TODOrandomeventid:localhost").unwrap(); let event_id = EventId::try_from("$TODOrandomeventid:localhost").unwrap();
data.event_add( data.event_add(
&body.room_id,
&event_id,
&Event::RoomMessage(MessageEvent { &Event::RoomMessage(MessageEvent {
content: body.data.clone().into_result().unwrap(), content: body.data.clone().into_result().unwrap(),
event_id: event_id.clone(), event_id: event_id.clone(),
@ -225,8 +231,6 @@ fn create_message_event_route(
sender: body.user_id.clone().expect("user is authenticated"), sender: body.user_id.clone().expect("user is authenticated"),
unsigned: Map::default(), unsigned: Map::default(),
}), }),
&body.room_id,
&event_id,
); );
MatrixResult(Ok(create_message_event::Response { event_id })) MatrixResult(Ok(create_message_event::Response { event_id }))
@ -239,8 +243,8 @@ fn main() {
} }
pretty_env_logger::init(); pretty_env_logger::init();
let data = Data::load_or_create(); let data = Data::load_or_create("localhost");
data.set_hostname("localhost"); data.debug();
rocket::ignite() rocket::ignite()
.mount( .mount(

View file

@ -1,4 +1,7 @@
use std::time::{SystemTime, UNIX_EPOCH}; use std::{
convert::TryInto,
time::{SystemTime, UNIX_EPOCH},
};
pub fn millis_since_unix_epoch() -> js_int::UInt { pub fn millis_since_unix_epoch() -> js_int::UInt {
(SystemTime::now() (SystemTime::now()
@ -8,20 +11,19 @@ pub fn millis_since_unix_epoch() -> js_int::UInt {
.into() .into()
} }
pub fn bytes_to_string(bytes: &[u8]) -> String { pub fn increment(old: Option<&[u8]>) -> Option<Vec<u8>> {
String::from_utf8(bytes.to_vec()).expect("convert bytes to string") let number = match old {
Some(bytes) => {
let array: [u8; 8] = bytes.try_into().unwrap();
let number = u64::from_be_bytes(array);
number + 1
}
None => 0,
};
Some(number.to_be_bytes().to_vec())
} }
pub fn vec_to_bytes(vec: Vec<String>) -> Vec<u8> { pub fn string_from_bytes(bytes: &[u8]) -> String {
vec.into_iter() String::from_utf8(bytes.to_vec()).expect("bytes are valid utf8")
.map(|string| string.into_bytes())
.collect::<Vec<Vec<u8>>>()
.join(&0)
}
pub fn bytes_to_vec(bytes: &[u8]) -> Vec<String> {
bytes
.split(|&b| b == 0)
.map(|bytes_string| bytes_to_string(bytes_string))
.collect::<Vec<String>>()
} }