From 3a76fda92b8694c12d2489f35eaa92213725035c Mon Sep 17 00:00:00 2001
From: Jonathan de Jong <jonathan@automatia.nl>
Date: Sun, 11 Jul 2021 15:41:10 +0200
Subject: [PATCH] incorperate feedback

---
 Cargo.toml                          |  1 -
 src/database.rs                     | 39 ++++++++++++++---------------
 src/database/abstraction/rocksdb.rs |  8 +++---
 src/database/abstraction/sled.rs    |  6 ++---
 src/database/abstraction/sqlite.rs  | 28 ++++++++++-----------
 src/main.rs                         |  3 ---
 6 files changed, 40 insertions(+), 45 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 7edf6417..896140cb 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -25,7 +25,6 @@ tokio = "1.2.0"
 # Used for storing data permanently
 sled = { version = "0.34.6", features = ["compression", "no_metrics"], optional = true }
 rocksdb = { version = "0.16.0", features = ["multi-threaded-cf"], optional = true }
-# sqlx = { version = "0.5.5", features = ["sqlite", "runtime-tokio-rustls"], optional = true }
 #sled = { git = "https://github.com/spacejam/sled.git", rev = "e4640e0773595229f398438886f19bca6f7326a2", features = ["compression"] }
 
 # Used for the http request / response body type for Ruma endpoints used with reqwest
diff --git a/src/database.rs b/src/database.rs
index e5bb8656..7fcee02a 100644
--- a/src/database.rs
+++ b/src/database.rs
@@ -42,10 +42,8 @@ use self::proxy::ProxyConfig;
 pub struct Config {
     server_name: Box<ServerName>,
     database_path: String,
-    #[serde(default = "default_cache_capacity")]
-    cache_capacity: u32,
-    #[serde(default = "default_sqlite_cache_kib")]
-    sqlite_cache_kib: u32,
+    #[serde(default = "default_db_cache_capacity")]
+    db_cache_capacity: u32,
     #[serde(default = "default_sqlite_read_pool_size")]
     sqlite_read_pool_size: usize,
     #[serde(default = "false_fn")]
@@ -83,14 +81,10 @@ fn true_fn() -> bool {
     true
 }
 
-fn default_cache_capacity() -> u32 {
+fn default_db_cache_capacity() -> u32 {
     1024 * 1024 * 1024
 }
 
-fn default_sqlite_cache_kib() -> u32 {
-    2000
-}
-
 fn default_sqlite_read_pool_size() -> usize {
     num_cpus::get().max(1)
 }
@@ -116,13 +110,13 @@ fn default_log() -> String {
 }
 
 #[cfg(feature = "sled")]
-pub type Engine = abstraction::sled::SledEngine;
+pub type Engine = abstraction::sled::Engine;
 
 #[cfg(feature = "rocksdb")]
-pub type Engine = abstraction::rocksdb::RocksDbEngine;
+pub type Engine = abstraction::rocksdb::Engine;
 
 #[cfg(feature = "sqlite")]
-pub type Engine = abstraction::sqlite::SqliteEngine;
+pub type Engine = abstraction::sqlite::Engine;
 
 pub struct Database {
     _db: Arc<Engine>,
@@ -268,7 +262,7 @@ impl Database {
             globals: globals::Globals::load(
                 builder.open_tree("global")?,
                 builder.open_tree("server_signingkeys")?,
-                config,
+                config.clone(),
             )?,
         }));
 
@@ -372,6 +366,9 @@ impl Database {
 
         drop(guard);
 
+        #[cfg(feature = "sqlite")]
+        Self::start_wal_clean_task(&db, &config).await;
+
         Ok(db)
     }
 
@@ -481,6 +478,7 @@ impl Database {
     #[cfg(feature = "sqlite")]
     pub async fn start_wal_clean_task(lock: &Arc<TokioRwLock<Self>>, config: &Config) {
         use tokio::{
+            select,
             signal::unix::{signal, SignalKind},
             time::{interval, timeout},
         };
@@ -501,13 +499,14 @@ impl Database {
             let mut s = signal(SignalKind::hangup()).unwrap();
 
             loop {
-                if do_timer {
-                    i.tick().await;
-                    log::info!(target: "wal-trunc", "Timer ticked")
-                } else {
-                    s.recv().await;
-                    log::info!(target: "wal-trunc", "Received SIGHUP")
-                }
+                select! {
+                    _ = i.tick(), if do_timer => {
+                        log::info!(target: "wal-trunc", "Timer ticked")
+                    }
+                    _ = s.recv() => {
+                        log::info!(target: "wal-trunc", "Received SIGHUP")
+                    }
+                };
 
                 if let Some(arc) = Weak::upgrade(&weak) {
                     log::info!(target: "wal-trunc", "Locking...");
diff --git a/src/database/abstraction/rocksdb.rs b/src/database/abstraction/rocksdb.rs
index 88b6297c..b9961302 100644
--- a/src/database/abstraction/rocksdb.rs
+++ b/src/database/abstraction/rocksdb.rs
@@ -7,15 +7,15 @@ use super::{DatabaseEngine, Tree};
 
 use std::{collections::BTreeMap, sync::RwLock};
 
-pub struct RocksDbEngine(rocksdb::DBWithThreadMode<rocksdb::MultiThreaded>);
+pub struct Engine(rocksdb::DBWithThreadMode<rocksdb::MultiThreaded>);
 
 pub struct RocksDbEngineTree<'a> {
-    db: Arc<RocksDbEngine>,
+    db: Arc<Engine>,
     name: &'a str,
     watchers: RwLock<BTreeMap<Vec<u8>, Vec<tokio::sync::oneshot::Sender<()>>>>,
 }
 
-impl DatabaseEngine for RocksDbEngine {
+impl DatabaseEngine for Engine {
     fn open(config: &Config) -> Result<Arc<Self>> {
         let mut db_opts = rocksdb::Options::default();
         db_opts.create_if_missing(true);
@@ -45,7 +45,7 @@ impl DatabaseEngine for RocksDbEngine {
                 .map(|name| rocksdb::ColumnFamilyDescriptor::new(name, options.clone())),
         )?;
 
-        Ok(Arc::new(RocksDbEngine(db)))
+        Ok(Arc::new(Engine(db)))
     }
 
     fn open_tree(self: &Arc<Self>, name: &'static str) -> Result<Arc<dyn Tree>> {
diff --git a/src/database/abstraction/sled.rs b/src/database/abstraction/sled.rs
index 557e8a0e..8c7f80d3 100644
--- a/src/database/abstraction/sled.rs
+++ b/src/database/abstraction/sled.rs
@@ -5,13 +5,13 @@ use std::{future::Future, pin::Pin, sync::Arc};
 
 use super::{DatabaseEngine, Tree};
 
-pub struct SledEngine(sled::Db);
+pub struct Engine(sled::Db);
 
 pub struct SledEngineTree(sled::Tree);
 
-impl DatabaseEngine for SledEngine {
+impl DatabaseEngine for Engine {
     fn open(config: &Config) -> Result<Arc<Self>> {
-        Ok(Arc::new(SledEngine(
+        Ok(Arc::new(Engine(
             sled::Config::default()
                 .path(&config.database_path)
                 .cache_capacity(config.cache_capacity as u64)
diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs
index ff06fab2..fe548133 100644
--- a/src/database/abstraction/sqlite.rs
+++ b/src/database/abstraction/sqlite.rs
@@ -119,22 +119,22 @@ impl Pool {
     }
 }
 
-pub struct SqliteEngine {
+pub struct Engine {
     pool: Pool,
 }
 
-impl DatabaseEngine for SqliteEngine {
+impl DatabaseEngine for Engine {
     fn open(config: &Config) -> Result<Arc<Self>> {
         let pool = Pool::new(
             Path::new(&config.database_path).join("conduit.db"),
             config.sqlite_read_pool_size,
-            config.sqlite_cache_kib,
+            config.db_cache_capacity / 1024, // bytes -> kb
         )?;
 
         pool.write_lock()
             .execute("CREATE TABLE IF NOT EXISTS _noop (\"key\" INT)", params![])?;
 
-        let arc = Arc::new(SqliteEngine { pool });
+        let arc = Arc::new(Engine { pool });
 
         Ok(arc)
     }
@@ -166,7 +166,7 @@ impl DatabaseEngine for SqliteEngine {
     }
 }
 
-impl SqliteEngine {
+impl Engine {
     pub fn flush_wal(self: &Arc<Self>) -> Result<()> {
         self.pool
             .write_lock()
@@ -185,7 +185,7 @@ impl SqliteEngine {
 }
 
 pub struct SqliteTable {
-    engine: Arc<SqliteEngine>,
+    engine: Arc<Engine>,
     name: String,
     watchers: RwLock<BTreeMap<Vec<u8>, Vec<Sender<()>>>>,
 }
@@ -257,19 +257,19 @@ impl Tree for SqliteTable {
     }
 
     fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
-        {
-            let guard = self.engine.pool.write_lock();
+        let guard = self.engine.pool.write_lock();
 
-            let start = Instant::now();
+        let start = Instant::now();
 
-            self.insert_with_guard(&guard, key, value)?;
+        self.insert_with_guard(&guard, key, value)?;
 
-            let elapsed = start.elapsed();
-            if elapsed > MILLI {
-                debug!("insert:    took {:012?} : {}", elapsed, &self.name);
-            }
+        let elapsed = start.elapsed();
+        if elapsed > MILLI {
+            debug!("insert:    took {:012?} : {}", elapsed, &self.name);
         }
 
+        drop(guard);
+
         let watchers = self.watchers.read();
         let mut triggered = Vec::new();
 
diff --git a/src/main.rs b/src/main.rs
index 22c44b54..034c39ea 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -202,9 +202,6 @@ async fn main() {
         .await
         .expect("config is valid");
 
-    #[cfg(feature = "sqlite")]
-    Database::start_wal_clean_task(&db, &config).await;
-
     if config.allow_jaeger {
         let (tracer, _uninstall) = opentelemetry_jaeger::new_pipeline()
             .with_service_name("conduit")