From 34a53ce20a49d9a9d6ffd98a04d0164a73a16a79 Mon Sep 17 00:00:00 2001 From: timokoesters Date: Sat, 28 Mar 2020 18:50:02 +0100 Subject: [PATCH] Better database structure --- Cargo.lock | 69 +++++++++++++++++++++++++++++- Cargo.toml | 2 +- Rocket.toml | 3 ++ rust-toolchain | 1 + src/data.rs | 39 +++++++++++++++++ src/main.rs | 102 +++++++++++++++++++++++--------------------- src/ruma_wrapper.rs | 58 +++++++++++++++++-------- 7 files changed, 206 insertions(+), 68 deletions(-) create mode 100644 Rocket.toml create mode 100644 rust-toolchain create mode 100644 src/data.rs diff --git a/Cargo.lock b/Cargo.lock index 4e7af27b..d024f974 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -54,6 +54,15 @@ dependencies = [ "safemem", ] +[[package]] +name = "base64" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b25d992356d2eb0ed82172f5248873db5560c4721f564b13cb5193bda5e668e" +dependencies = [ + "byteorder", +] + [[package]] name = "base64" version = "0.11.0" @@ -356,6 +365,18 @@ dependencies = [ "url 1.7.2", ] +[[package]] +name = "hyper-sync-rustls" +version = "0.3.0-rc.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53be239c980459955c0f0af3f13190ead511d7d4bdaeab8127c011b94d8558de" +dependencies = [ + "hyper", + "rustls", + "webpki", + "webpki-roots", +] + [[package]] name = "idna" version = "0.1.5" @@ -770,9 +791,11 @@ source = "git+https://github.com/SergioBenitez/Rocket.git#06e146e7d18d7c4aab423d dependencies = [ "cookie", "hyper", + "hyper-sync-rustls", "indexmap", "pear", "percent-encoding 1.0.1", + "rustls", "smallvec", "state", "time 0.2.9", @@ -879,6 +902,19 @@ dependencies = [ "semver", ] +[[package]] +name = "rustls" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b25a18b1bf7387f0145e7f8324e700805aade3842dd3db2e74e4cdeb4677c09e" +dependencies = [ + "base64 0.10.1", + "log 0.4.8", + "ring", + "sct", + "webpki", +] + [[package]] name = "rustversion" version = "1.0.2" @@ -908,6 +944,16 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +[[package]] +name = "sct" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3042af939fca8c3453b7af0f1c66e533a15a86169e39de2657310ade8f98d3c" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "semver" version = "0.9.0" @@ -945,9 +991,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.48" +version = "1.0.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9371ade75d4c2d6cb154141b9752cf3781ec9c05e0e5cf35060e1e70ee7b9c25" +checksum = "02044a6a92866fd61624b3db4d2c9dccc2feabbc6be490b87611bf285edbac55" dependencies = [ "itoa", "ryu", @@ -1347,6 +1393,25 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webpki" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1f50e1972865d6b1adb54167d1c8ed48606004c2c9d0ea5f1eeb34d95e863ef" +dependencies = [ + "ring", + "untrusted", +] + +[[package]] +name = "webpki-roots" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91cd5736df7f12a964a5067a12c62fa38e1bd8080aff1f80bc29be7c80d19ab4" +dependencies = [ + "webpki", +] + [[package]] name = "winapi" version = "0.3.8" diff --git a/Cargo.toml b/Cargo.toml index 31cfd301..32c71a25 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -rocket = { git = "https://github.com/SergioBenitez/Rocket.git" } +rocket = { git = "https://github.com/SergioBenitez/Rocket.git", features = ["tls"] } http = "0.2.1" ruma-client-api = { git = "https://github.com/ruma/ruma-client-api" } pretty_env_logger = "0.4.0" diff --git a/Rocket.toml b/Rocket.toml new file mode 100644 index 00000000..d18ee979 --- /dev/null +++ b/Rocket.toml @@ -0,0 +1,3 @@ +#[global.tls] +#certs = "/etc/ssl/certs/ssl-cert-snakeoil.pem" +#key = "/etc/ssl/private/ssl-cert-snakeoil.key" diff --git a/rust-toolchain b/rust-toolchain new file mode 100644 index 00000000..bf867e0a --- /dev/null +++ b/rust-toolchain @@ -0,0 +1 @@ +nightly diff --git a/src/data.rs b/src/data.rs new file mode 100644 index 00000000..52fe7af9 --- /dev/null +++ b/src/data.rs @@ -0,0 +1,39 @@ +use directories::ProjectDirs; +use ruma_identifiers::UserId; + +pub struct Data(sled::Db); + +impl Data { + pub fn set_hostname(&self, hostname: &str) { + self.0.insert("hostname", hostname).unwrap(); + } + pub fn hostname(&self) -> String { + String::from_utf8(self.0.get("hostname").unwrap().unwrap().to_vec()).unwrap() + } + pub fn load_or_create() -> Self { + Data( + sled::open( + ProjectDirs::from("xyz", "koesters", "matrixserver") + .unwrap() + .data_dir(), + ) + .unwrap(), + ) + } + + pub fn user_exists(&self, user_id: &UserId) -> bool { + self.0 + .open_tree("username_password") + .unwrap() + .contains_key(user_id.to_string()) + .unwrap() + } + + pub fn user_add(&self, user_id: UserId, password: Option) { + self.0 + .open_tree("username_password") + .unwrap() + .insert(user_id.to_string(), &*password.unwrap_or_default()) + .unwrap(); + } +} diff --git a/src/main.rs b/src/main.rs index fa6cb143..419a01c1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,42 +1,47 @@ #![feature(proc_macro_hygiene, decl_macro)] +mod data; mod ruma_wrapper; -use { - directories::ProjectDirs, - log::debug, - rocket::{get, post, put, routes, State}, - ruma_client_api::{ - error::{Error, ErrorKind}, - r0::{ - account::register, alias::get_alias, membership::join_room_by_id, - message::create_message_event, session::login, - }, - unversioned::get_supported_versions, +use data::Data; +use log::debug; +use rocket::{get, post, put, routes, State}; +use ruma_client_api::{ + error::{Error, ErrorKind}, + r0::{ + account::register, alias::get_alias, membership::join_room_by_id, + message::create_message_event, session::login, }, - ruma_identifiers::UserId, - ruma_wrapper::{MatrixResult, Ruma}, - sled::Db, - std::{collections::HashMap, convert::TryInto}, + unversioned::get_supported_versions, }; +use ruma_identifiers::UserId; +use ruma_wrapper::{MatrixResult, Ruma}; +use std::{collections::HashMap, convert::TryInto}; #[get("/_matrix/client/versions")] fn get_supported_versions_route() -> MatrixResult { MatrixResult(Ok(get_supported_versions::Response { - versions: vec!["r0.6.0".to_owned()], + versions: vec![ + "r0.0.1".to_owned(), + "r0.1.0".to_owned(), + "r0.2.0".to_owned(), + "r0.3.0".to_owned(), + "r0.4.0".to_owned(), + "r0.5.0".to_owned(), + "r0.6.0".to_owned(), + ], unstable_features: HashMap::new(), })) } #[post("/_matrix/client/r0/register", data = "")] fn register_route( - db: State, + data: State, body: Ruma, ) -> MatrixResult { - let users = db.open_tree("users").unwrap(); - let user_id: UserId = match (*format!( - "@{}:localhost", - body.username.clone().unwrap_or("randomname".to_owned()) + "@{}:{}", + body.username.clone().unwrap_or("randomname".to_owned()), + data.hostname() )) .try_into() { @@ -51,7 +56,7 @@ fn register_route( Ok(user_id) => user_id, }; - if users.contains_key(user_id.to_string()).unwrap() { + if data.user_exists(&user_id) { debug!("ID already taken"); return MatrixResult(Err(Error { kind: ErrorKind::UserInUse, @@ -60,37 +65,42 @@ fn register_route( })); } - users - .insert( - user_id.to_string(), - &*body.password.clone().unwrap_or_default(), - ) - .unwrap(); + data.user_add(user_id.clone(), body.password.clone()); MatrixResult(Ok(register::Response { access_token: "randomtoken".to_owned(), - home_server: "localhost".to_owned(), + home_server: data.hostname(), user_id, device_id: body.device_id.clone().unwrap_or("randomid".to_owned()), })) } #[post("/_matrix/client/r0/login", data = "")] -fn login_route(db: State, body: Ruma) -> MatrixResult { - let user_id = if let login::UserInfo::MatrixId(username) = &body.user { - let user_id = format!("@{}:localhost", username); - let users = db.open_tree("users").unwrap(); - if !users.contains_key(user_id.clone()).unwrap() { - dbg!(); +fn login_route(data: State, body: Ruma) -> MatrixResult { + let username = if let login::UserInfo::MatrixId(mut username) = body.user.clone() { + if !username.contains(':') { + username = format!("@{}:{}", username, data.hostname()); + } + if let Ok(user_id) = (*username).try_into() { + if !data.user_exists(&user_id) { + debug!("Userid does not exist. Can't log in."); + return MatrixResult(Err(Error { + kind: ErrorKind::Forbidden, + message: "UserId not found.".to_owned(), + status_code: http::StatusCode::BAD_REQUEST, + })); + } + user_id + } else { + debug!("Invalid UserId."); return MatrixResult(Err(Error { - kind: ErrorKind::Forbidden, - message: "UserId not found.".to_owned(), + kind: ErrorKind::Unknown, + message: "Bad login type.".to_owned(), status_code: http::StatusCode::BAD_REQUEST, })); } - user_id } else { - dbg!(); + debug!("Bad login type"); return MatrixResult(Err(Error { kind: ErrorKind::Unknown, message: "Bad login type.".to_owned(), @@ -99,7 +109,7 @@ fn login_route(db: State, body: Ruma) -> MatrixResult, ) -> MatrixResult { - dbg!(body.0); + dbg!(body); MatrixResult(Ok(create_message_event::Response { event_id: "$randomeventid".try_into().unwrap(), })) @@ -161,12 +171,8 @@ fn main() { } pretty_env_logger::init(); - let db = sled::open( - ProjectDirs::from("xyz", "koesters", "matrixserver") - .unwrap() - .data_dir(), - ) - .unwrap(); + let data = Data::load_or_create(); + data.set_hostname("localhost"); rocket::ignite() .mount( @@ -180,6 +186,6 @@ fn main() { create_message_event_route, ], ) - .manage(db) + .manage(data) .launch(); } diff --git a/src/ruma_wrapper.rs b/src/ruma_wrapper.rs index dda584f0..1f294507 100644 --- a/src/ruma_wrapper.rs +++ b/src/ruma_wrapper.rs @@ -1,24 +1,28 @@ -use { - rocket::data::{FromDataSimple, Outcome}, - rocket::http::Status, - rocket::response::Responder, - rocket::Request, - rocket::{Data, Outcome::*}, - ruma_client_api::error::Error, - std::fmt::Debug, - std::ops::Deref, - std::{ - convert::{TryFrom, TryInto}, - io::{Cursor, Read}, - }, +use rocket::{ + data::{FromDataSimple, Outcome}, + http::Status, + response::Responder, + Data, + Outcome::*, + Request, +}; +use ruma_client_api::error::Error; +use std::{ + convert::{TryFrom, TryInto}, + fmt, + io::{Cursor, Read}, + ops::Deref, }; const MESSAGE_LIMIT: u64 = 65535; -pub struct Ruma(pub T); +pub struct Ruma { + body: T, + headers: http::HeaderMap, +} impl>>> FromDataSimple for Ruma where - T::Error: Debug, + T::Error: fmt::Debug, { type Error = (); @@ -35,10 +39,11 @@ where handle.read_to_end(&mut body).unwrap(); let http_request = http_request.body(body).unwrap(); + let headers = http_request.headers().clone(); log::info!("{:?}", http_request); match T::try_from(http_request) { - Ok(r) => Success(Ruma(r)), + Ok(t) => Success(Ruma { body: t, headers }), Err(e) => { log::error!("{:?}", e); Failure((Status::InternalServerError, ())) @@ -51,7 +56,16 @@ impl Deref for Ruma { type Target = T; fn deref(&self) -> &Self::Target { - &self.0 + &self.body + } +} + +impl fmt::Debug for Ruma { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Ruma") + .field("body", &self.body) + .field("headers", &self.headers) + .finish() } } @@ -79,6 +93,16 @@ impl<'r, T: TryInto>>> Responder<'r> for MatrixResult response .raw_header(header.0.to_string(), header.1.to_str().unwrap().to_owned()); } + + response.raw_header("Access-Control-Allow-Origin", "*"); + response.raw_header( + "Access-Control-Allow-Methods", + "GET, POST, PUT, DELETE, OPTIONS", + ); + response.raw_header( + "Access-Control-Allow-Headers", + "Origin, X-Requested-With, Content-Type, Accept, Authorization", + ); response.ok() } Err(_) => Err(Status::InternalServerError),