diff --git a/Cargo.lock b/Cargo.lock
index ac26621d4d..1d2f413586 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3345,6 +3345,7 @@ version = "2.2.7"
 dependencies = [
  "aleo-std",
  "indexmap 2.5.0",
+ "lru",
  "parking_lot",
  "snarkvm",
  "tracing",
diff --git a/node/bft/storage-service/Cargo.toml b/node/bft/storage-service/Cargo.toml
index c74fb09307..4aa60eab29 100644
--- a/node/bft/storage-service/Cargo.toml
+++ b/node/bft/storage-service/Cargo.toml
@@ -29,6 +29,9 @@ workspace = true
 version = "2.1"
 features = [ "serde", "rayon" ]
 
+[dependencies.lru]
+version = "0.12.1"
+
 [dependencies.parking_lot]
 version = "0.12"
 optional = true
diff --git a/node/bft/storage-service/src/persistent.rs b/node/bft/storage-service/src/persistent.rs
index 8d7e31bca2..7ccc6c578a 100644
--- a/node/bft/storage-service/src/persistent.rs
+++ b/node/bft/storage-service/src/persistent.rs
@@ -16,6 +16,7 @@ use crate::StorageService;
 use snarkvm::{
     ledger::{
+        committee::Committee,
         narwhal::{BatchHeader, Transmission, TransmissionID},
         store::{
             cow_to_cloned,
@@ -34,9 +35,12 @@ use snarkvm::{
 
 use aleo_std::StorageMode;
 use indexmap::{indexset, IndexSet};
+use lru::LruCache;
+use parking_lot::Mutex;
 use std::{
     borrow::Cow,
     collections::{HashMap, HashSet},
+    num::NonZeroUsize,
 };
 use tracing::error;
@@ -47,11 +51,20 @@ pub struct BFTPersistentStorage<N: Network> {
     transmissions: DataMap<TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>)>,
     /// The map of `aborted transmission ID` to `certificate IDs` entries.
     aborted_transmission_ids: DataMap<TransmissionID<N>, IndexSet<Field<N>>>,
+    /// The LRU cache for `transmission ID` to `(transmission, certificate IDs)` entries that are part of the persistent storage.
+    cache_transmissions: Mutex<LruCache<TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>)>>,
+    /// The LRU cache for `aborted transmission ID` to `certificate IDs` entries that are part of the persistent storage.
+    cache_aborted_transmission_ids: Mutex<LruCache<TransmissionID<N>, IndexSet<Field<N>>>>,
 }
 
 impl<N: Network> BFTPersistentStorage<N> {
     /// Initializes a new BFT persistent storage service.
     pub fn open(storage_mode: StorageMode) -> Result<Self> {
+        let capacity = NonZeroUsize::new(
+            (Committee::<N>::MAX_COMMITTEE_SIZE as usize) * (BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH) * 2,
+        )
+        .unwrap();
+
         Ok(Self {
             transmissions: internal::RocksDB::open_map(N::ID, storage_mode.clone(), MapID::BFT(BFTMap::Transmissions))?,
             aborted_transmission_ids: internal::RocksDB::open_map(
@@ -59,12 +72,19 @@ impl<N: Network> BFTPersistentStorage<N> {
                 storage_mode,
                 MapID::BFT(BFTMap::AbortedTransmissionIDs),
             )?,
+            cache_transmissions: Mutex::new(LruCache::new(capacity)),
+            cache_aborted_transmission_ids: Mutex::new(LruCache::new(capacity)),
         })
     }
 
-    /// Initializes a new BFT persistent storage service.
+    /// Initializes a new BFT persistent storage service for testing.
     #[cfg(any(test, feature = "test"))]
     pub fn open_testing(temp_dir: std::path::PathBuf, dev: Option<u16>) -> Result<Self> {
+        let capacity = NonZeroUsize::new(
+            (Committee::<N>::MAX_COMMITTEE_SIZE as usize) * (BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH) * 2,
+        )
+        .unwrap();
+
         Ok(Self {
             transmissions: internal::RocksDB::open_map_testing(
                 temp_dir.clone(),
@@ -76,6 +96,8 @@ impl<N: Network> BFTPersistentStorage<N> {
                 dev,
                 MapID::BFT(BFTMap::AbortedTransmissionIDs),
             )?,
+            cache_transmissions: Mutex::new(LruCache::new(capacity)),
+            cache_aborted_transmission_ids: Mutex::new(LruCache::new(capacity)),
         })
     }
 }
@@ -102,7 +124,12 @@ impl<N: Network> StorageService<N> for BFTPersistentStorage<N> {
     /// Returns the transmission for the given `transmission ID`.
     /// If the transmission ID does not exist in storage, `None` is returned.
     fn get_transmission(&self, transmission_id: TransmissionID<N>) -> Option<Transmission<N>> {
-        // Get the transmission.
+        // Try to get the transmission from the cache first.
+        if let Some((transmission, _)) = self.cache_transmissions.lock().get_mut(&transmission_id) {
+            return Some(transmission.clone());
+        }
+
+        // If not found in cache, check persistent storage.
         match self.transmissions.get_confirmed(&transmission_id) {
             Ok(Some(Cow::Owned((transmission, _)))) => Some(transmission),
             Ok(Some(Cow::Borrowed((transmission, _)))) => Some(transmission.clone()),
@@ -153,24 +180,19 @@ impl<N: Network> StorageService<N> for BFTPersistentStorage<N> {
         aborted_transmission_ids: HashSet<TransmissionID<N>>,
         mut missing_transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
     ) {
-        // Inserts the following:
-        //   - Inserts **only the missing** transmissions from storage.
-        //   - Inserts the certificate ID into the corresponding set for **all** transmissions.
+        // First, handle the non-aborted transmissions.
         'outer: for transmission_id in transmission_ids {
-            // Retrieve the transmission entry.
-            match self.transmissions.get_confirmed(&transmission_id) {
+            // Try to fetch from the persistent storage.
+            let (transmission, certificate_ids) = match self.transmissions.get_confirmed(&transmission_id) {
                 Ok(Some(entry)) => {
+                    // The transmission exists in storage; update its certificate IDs.
                     let (transmission, mut certificate_ids) = cow_to_cloned!(entry);
-                    // Insert the certificate ID into the set.
                     certificate_ids.insert(certificate_id);
-                    // Update the transmission entry.
-                    if let Err(e) = self.transmissions.insert(transmission_id, (transmission, certificate_ids)) {
-                        error!("Failed to insert transmission {transmission_id} into storage - {e}");
-                        continue 'outer;
-                    }
+                    (transmission, certificate_ids)
                 }
                 Ok(None) => {
-                    // Retrieve the missing transmission.
+                    // The transmission is missing from persistent storage.
+                    // Check if it exists in the `missing_transmissions` map provided.
                    let Some(transmission) = missing_transmissions.remove(&transmission_id) else {
                         if !aborted_transmission_ids.contains(&transmission_id)
                             && !self.contains_transmission(transmission_id)
@@ -181,45 +203,46 @@ impl<N: Network> StorageService<N> for BFTPersistentStorage<N> {
                     };
                     // Prepare the set of certificate IDs.
                     let certificate_ids = indexset! { certificate_id };
-                    // Insert the transmission and a new set with the certificate ID.
-                    if let Err(e) = self.transmissions.insert(transmission_id, (transmission, certificate_ids)) {
-                        error!("Failed to insert transmission {transmission_id} into storage - {e}");
-                        continue 'outer;
-                    }
+                    (transmission, certificate_ids)
                 }
                 Err(e) => {
+                    // Handle any errors during the retrieval.
                     error!("Failed to process the 'insert' for transmission {transmission_id} into storage - {e}");
-                    continue 'outer;
+                    continue;
                 }
+            };
+            // Insert the transmission into persistent storage.
+            if let Err(e) = self.transmissions.insert(transmission_id, (transmission.clone(), certificate_ids.clone()))
+            {
+                error!("Failed to insert transmission {transmission_id} into storage - {e}");
             }
+            // Insert the transmission into the cache.
+            self.cache_transmissions.lock().put(transmission_id, (transmission, certificate_ids));
         }
-        // Inserts the aborted transmission IDs.
+
+        // Next, handle the aborted transmission IDs.
         for aborted_transmission_id in aborted_transmission_ids {
-            // Retrieve the transmission entry.
-            match self.aborted_transmission_ids.get_confirmed(&aborted_transmission_id) {
+            let certificate_ids = match self.aborted_transmission_ids.get_confirmed(&aborted_transmission_id) {
                 Ok(Some(entry)) => {
                     let mut certificate_ids = cow_to_cloned!(entry);
                     // Insert the certificate ID into the set.
                     certificate_ids.insert(certificate_id);
-                    // Update the transmission entry.
-                    if let Err(e) = self.aborted_transmission_ids.insert(aborted_transmission_id, certificate_ids) {
-                        error!("Failed to insert aborted transmission ID {aborted_transmission_id} into storage - {e}");
-                    }
-                }
-                Ok(None) => {
-                    // Prepare the set of certificate IDs.
-                    let certificate_ids = indexset! { certificate_id };
-                    // Insert the transmission and a new set with the certificate ID.
-                    if let Err(e) = self.aborted_transmission_ids.insert(aborted_transmission_id, certificate_ids) {
-                        error!("Failed to insert aborted transmission ID {aborted_transmission_id} into storage - {e}");
-                    }
+                    certificate_ids
                 }
+                Ok(None) => indexset! { certificate_id },
                 Err(e) => {
                     error!(
                         "Failed to process the 'insert' for aborted transmission ID {aborted_transmission_id} into storage - {e}"
                     );
+                    continue;
                 }
+            };
+            // Insert the certificate IDs into the persistent storage.
+            if let Err(e) = self.aborted_transmission_ids.insert(aborted_transmission_id, certificate_ids.clone()) {
+                error!("Failed to insert aborted transmission ID {aborted_transmission_id} into storage - {e}");
             }
+            // Insert the certificate IDs into the cache.
+            self.cache_aborted_transmission_ids.lock().put(aborted_transmission_id, certificate_ids);
         }
     }
diff --git a/node/router/tests/disconnect.rs b/node/router/tests/disconnect.rs
index 4250ea5048..fca0bba09f 100644
--- a/node/router/tests/disconnect.rs
+++ b/node/router/tests/disconnect.rs
@@ -19,6 +19,7 @@ use common::*;
 use snarkos_node_tcp::{protocols::Handshake, P2P};
 
 use core::time::Duration;
+use deadline::deadline;
 
 #[tokio::test]
 async fn test_disconnect_without_handshake() {
@@ -34,8 +35,12 @@ async fn test_disconnect_without_handshake() {
     // Connect node0 to node1.
     node0.connect(node1.local_ip());
-    // Sleep briefly.
-    tokio::time::sleep(Duration::from_millis(200)).await;
+    // Await both nodes being connected.
+    let node0_ = node0.clone();
+    let node1_ = node1.clone();
+    deadline!(Duration::from_secs(1), move || {
+        node0_.tcp().num_connected() == 1 && node1_.tcp().num_connected() == 1
+    });
 
     print_tcp!(node0);
     print_tcp!(node1);
@@ -50,8 +55,9 @@ async fn test_disconnect_without_handshake() {
     // collection of connected peers is only altered during the handshake,
     // as well as the address resolver needed for the higher-level calls
     node0.tcp().disconnect(node1.local_ip()).await;
-    // Sleep briefly.
-    tokio::time::sleep(Duration::from_millis(100)).await;
+    // Await disconnection.
+    let node0_ = node0.clone();
+    deadline!(Duration::from_secs(1), move || { node0_.tcp().num_connected() == 0 });
 
     print_tcp!(node0);
     print_tcp!(node1);
@@ -80,8 +86,12 @@ async fn test_disconnect_with_handshake() {
     // Connect node0 to node1.
     node0.connect(node1.local_ip());
-    // Sleep briefly.
-    tokio::time::sleep(Duration::from_millis(1000)).await;
+    // Await for the nodes to be connected.
+    let node0_ = node0.clone();
+    let node1_ = node1.clone();
+    deadline!(Duration::from_secs(1), move || {
+        node0_.tcp().num_connected() == 1 && node1_.tcp().num_connected() == 1
+    });
 
     print_tcp!(node0);
     print_tcp!(node1);
@@ -98,8 +108,9 @@ async fn test_disconnect_with_handshake() {
     // Disconnect node0 from node1.
     node0.disconnect(node1.local_ip());
-    // Sleep briefly.
-    tokio::time::sleep(Duration::from_millis(100)).await;
+    // Await nodes being disconnected.
+    let node0_ = node0.clone();
+    deadline!(Duration::from_secs(1), move || { node0_.tcp().num_connected() == 0 });
 
     print_tcp!(node0);
     print_tcp!(node1);
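For context, the storage-service change above follows a read-through / write-through caching pattern: reads consult a bounded `LruCache` behind a `parking_lot::Mutex` before falling back to RocksDB, and writes update both the persistent map and the cache. The sketch below is a minimal, self-contained illustration of that pattern, not the snarkOS implementation; the `Store` name, the `String`/`u64` key and value types, and the capacity of 4 are placeholder assumptions standing in for `TransmissionID<N>`, `Transmission<N>`, and the `MAX_COMMITTEE_SIZE * MAX_TRANSMISSIONS_PER_BATCH * 2` bound used in the diff.

```rust
use std::collections::HashMap;
use std::num::NonZeroUsize;

use lru::LruCache;
use parking_lot::Mutex;

/// Stand-in store pairing a "persistent" map with a bounded LRU cache.
struct Store {
    /// Placeholder for the persistent RocksDB-backed map.
    persistent: Mutex<HashMap<String, u64>>,
    /// Bounded cache sitting in front of the persistent map.
    cache: Mutex<LruCache<String, u64>>,
}

impl Store {
    fn new(capacity: NonZeroUsize) -> Self {
        Self { persistent: Mutex::new(HashMap::new()), cache: Mutex::new(LruCache::new(capacity)) }
    }

    /// Write-through: every insert updates the persistent map and the cache.
    fn insert(&self, key: String, value: u64) {
        self.persistent.lock().insert(key.clone(), value);
        self.cache.lock().put(key, value);
    }

    /// Read-through: check the cache first, then fall back to the persistent map.
    fn get(&self, key: &str) -> Option<u64> {
        if let Some(value) = self.cache.lock().get(key) {
            return Some(*value);
        }
        self.persistent.lock().get(key).copied()
    }
}

fn main() {
    // Placeholder capacity; the diff derives its bound from committee/batch constants.
    let store = Store::new(NonZeroUsize::new(4).unwrap());
    store.insert("tx-1".to_string(), 42);
    assert_eq!(store.get("tx-1"), Some(42)); // hit: served from the cache
    assert_eq!(store.get("tx-2"), None); // miss in both the cache and the persistent map
}
```

The test changes in `disconnect.rs` are independent of the cache: they replace fixed sleeps with the `deadline!` macro, which, as shown in the diff, polls a condition closure until it holds or the timeout elapses, making the tests less timing-sensitive.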