From b1d9ec3efccafaf887da1b54e4b3ef2bfa4d84a3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Timo=20K=C3=B6sters?= <timo@koesters.xyz>
Date: Mon, 20 Dec 2021 10:16:22 +0100
Subject: [PATCH] fix: atomic increment

---
 Cargo.toml                           |  2 +-
 src/database/abstraction/rocksdb.rs  | 24 ++++++++++++++++--------
 src/database/abstraction/watchers.rs |  8 ++++----
 3 files changed, 21 insertions(+), 13 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 0a2b4459..6241b6a8 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -90,7 +90,7 @@ backend_sled = ["sled"]
 backend_sqlite = ["sqlite"]
 backend_heed = ["heed", "crossbeam"]
 backend_rocksdb = ["rocksdb"]
-sqlite = ["rusqlite", "parking_lot", "crossbeam", "tokio/signal"]
+sqlite = ["rusqlite", "parking_lot", "tokio/signal"]
 conduit_bin = [] # TODO: add rocket to this when it is optional
 
 [[bin]]
diff --git a/src/database/abstraction/rocksdb.rs b/src/database/abstraction/rocksdb.rs
index 825c02e0..b2142dfe 100644
--- a/src/database/abstraction/rocksdb.rs
+++ b/src/database/abstraction/rocksdb.rs
@@ -1,11 +1,6 @@
-use super::super::Config;
+use super::{super::Config, watchers::Watchers, DatabaseEngine, Tree};
 use crate::{utils, Result};
-
-use std::{future::Future, pin::Pin, sync::Arc};
-
-use super::{DatabaseEngine, Tree};
-
-use std::{collections::HashMap, sync::RwLock};
+use std::{future::Future, pin::Pin, sync::Arc, collections::HashMap, sync::RwLock};
 
 pub struct Engine {
     rocks: rocksdb::DBWithThreadMode<rocksdb::MultiThreaded>,
@@ -16,6 +11,7 @@ pub struct RocksDbEngineTree<'a> {
     db: Arc<Engine>,
     name: &'a str,
     watchers: Watchers,
+    write_lock: RwLock<()>
 }
 
 impl DatabaseEngine for Engine {
@@ -77,6 +73,7 @@ impl DatabaseEngine for Engine {
             name,
             db: Arc::clone(self),
             watchers: Watchers::default(),
+            write_lock: RwLock::new(()),
         }))
     }
 
@@ -98,8 +95,12 @@ impl Tree for RocksDbEngineTree<'_> {
     }
 
     fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
+        let lock = self.write_lock.read().unwrap();
         self.db.rocks.put_cf(self.cf(), key, value)?;
+        drop(lock);
+
         self.watchers.wake(key);
+
         Ok(())
     }
 
@@ -148,20 +149,27 @@ impl Tree for RocksDbEngineTree<'_> {
     }
 
     fn increment(&self, key: &[u8]) -> Result<Vec<u8>> {
-        // TODO: make atomic
+        let lock = self.write_lock.write().unwrap();
+
         let old = self.db.rocks.get_cf(self.cf(), &key)?;
         let new = utils::increment(old.as_deref()).unwrap();
         self.db.rocks.put_cf(self.cf(), key, &new)?;
+
+        drop(lock);
         Ok(new)
     }
 
     fn increment_batch<'a>(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()> {
+        let lock = self.write_lock.write().unwrap();
+
         for key in iter {
             let old = self.db.rocks.get_cf(self.cf(), &key)?;
             let new = utils::increment(old.as_deref()).unwrap();
             self.db.rocks.put_cf(self.cf(), key, new)?;
         }
 
+        drop(lock);
+
         Ok(())
     }
 
diff --git a/src/database/abstraction/watchers.rs b/src/database/abstraction/watchers.rs
index 404f3f06..fec1f27a 100644
--- a/src/database/abstraction/watchers.rs
+++ b/src/database/abstraction/watchers.rs
@@ -1,6 +1,6 @@
-use parking_lot::RwLock;
 use std::{
     collections::{hash_map, HashMap},
+    sync::RwLock,
     future::Future,
     pin::Pin,
 };
@@ -16,7 +16,7 @@ impl Watchers {
         &'a self,
         prefix: &[u8],
     ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
-        let mut rx = match self.watchers.write().entry(prefix.to_vec()) {
+        let mut rx = match self.watchers.write().unwrap().entry(prefix.to_vec()) {
             hash_map::Entry::Occupied(o) => o.get().1.clone(),
             hash_map::Entry::Vacant(v) => {
                 let (tx, rx) = tokio::sync::watch::channel(());
@@ -31,7 +31,7 @@ impl Watchers {
         })
     }
     pub(super) fn wake(&self, key: &[u8]) {
-        let watchers = self.watchers.read();
+        let watchers = self.watchers.read().unwrap();
         let mut triggered = Vec::new();
 
         for length in 0..=key.len() {
@@ -43,7 +43,7 @@ impl Watchers {
         drop(watchers);
 
         if !triggered.is_empty() {
-            let mut watchers = self.watchers.write();
+            let mut watchers = self.watchers.write().unwrap();
             for prefix in triggered {
                 if let Some(tx) = watchers.remove(prefix) {
                     let _ = tx.0.send(());