diff --git a/crates/blockifier/src/concurrency/test_utils.rs b/crates/blockifier/src/concurrency/test_utils.rs index 504e239bbf..0a55856c3d 100644 --- a/crates/blockifier/src/concurrency/test_utils.rs +++ b/crates/blockifier/src/concurrency/test_utils.rs @@ -1,3 +1,8 @@ +use rstest::fixture; +use starknet_api::core::{ClassHash, ContractAddress, PatriciaKey}; +use starknet_api::hash::StarkHash; +use starknet_api::{class_hash, contract_address, patricia_key}; + use crate::concurrency::versioned_state::{ThreadSafeVersionedState, VersionedState}; use crate::context::BlockContext; use crate::execution::call_info::CallInfo; @@ -7,6 +12,20 @@ use crate::test_utils::dict_state_reader::DictStateReader; use crate::transaction::account_transaction::AccountTransaction; use crate::transaction::transactions::ExecutableTransaction; +// Fixtures. + +#[fixture] +pub fn contract_address() -> ContractAddress { + contract_address!("0x18031991") +} + +#[fixture] +pub fn class_hash() -> ClassHash { + class_hash!(27_u8) +} + +// Macros. + #[macro_export] macro_rules! default_scheduler { ($chunk_size:ident : $chunk:expr , $($field:ident $(: $value:expr)?),+ $(,)?) => { diff --git a/crates/blockifier/src/concurrency/versioned_state.rs b/crates/blockifier/src/concurrency/versioned_state.rs index d8c4cfd9ce..8b604bd849 100644 --- a/crates/blockifier/src/concurrency/versioned_state.rs +++ b/crates/blockifier/src/concurrency/versioned_state.rs @@ -21,6 +21,7 @@ const READ_ERR: &str = "Error: read value missing in the versioned storage"; /// Represents a versioned state used as shared state between a chunk of workers. /// This state facilitates concurrent operations. /// Reader functionality is injected through initial state. +#[derive(Debug)] pub struct VersionedState { initial_state: S, storage: VersionedStorage<(ContractAddress, StorageKey), StarkFelt>, @@ -42,12 +43,25 @@ impl VersionedState { } } - fn get_writes(&mut self, from_index: TxIndex) -> StateMaps { + fn get_writes_up_to_index(&mut self, tx_index: TxIndex) -> StateMaps { StateMaps { - storage: self.storage.get_writes_from_index(from_index), - nonces: self.nonces.get_writes_from_index(from_index), - class_hashes: self.class_hashes.get_writes_from_index(from_index), - compiled_class_hashes: self.compiled_class_hashes.get_writes_from_index(from_index), + storage: self.storage.get_writes_up_to_index(tx_index), + nonces: self.nonces.get_writes_up_to_index(tx_index), + class_hashes: self.class_hashes.get_writes_up_to_index(tx_index), + compiled_class_hashes: self.compiled_class_hashes.get_writes_up_to_index(tx_index), + // TODO(OriF, 01/07/2024): Update declared_contracts initial value. + declared_contracts: HashMap::new(), + } + } + + #[cfg(any(feature = "testing", test))] + pub fn get_writes_of_index(&self, tx_index: TxIndex) -> StateMaps { + StateMaps { + storage: self.storage.get_writes_of_index(tx_index), + nonces: self.nonces.get_writes_of_index(tx_index), + class_hashes: self.class_hashes.get_writes_of_index(tx_index), + compiled_class_hashes: self.compiled_class_hashes.get_writes_of_index(tx_index), + // TODO(OriF, 01/07/2024): Update declared_contracts initial value. declared_contracts: HashMap::new(), } } @@ -56,11 +70,11 @@ impl VersionedState { where T: StateReader, { - let writes = self.get_writes(from_index); + let writes = self.get_writes_up_to_index(from_index); parent_state.update_cache(writes); parent_state.update_contract_class_cache( - self.compiled_contract_classes.get_writes_from_index(from_index), + self.compiled_contract_classes.get_writes_up_to_index(from_index), ); } @@ -139,6 +153,30 @@ impl VersionedState { self.compiled_contract_classes.write(tx_index, key, value.clone()); } } + + fn delete_writes( + &mut self, + tx_index: TxIndex, + writes: &StateMaps, + class_hash_to_class: &ContractClassMapping, + ) { + for &key in writes.storage.keys() { + self.storage.delete_write(key, tx_index); + } + for &key in writes.nonces.keys() { + self.nonces.delete_write(key, tx_index); + } + for &key in writes.class_hashes.keys() { + self.class_hashes.delete_write(key, tx_index); + } + for &key in writes.compiled_class_hashes.keys() { + self.compiled_class_hashes.delete_write(key, tx_index); + } + // TODO(OriF, 01/07/2024): Add a for loop for `declared_contracts`. + for &key in class_hash_to_class.keys() { + self.compiled_contract_classes.delete_write(key, tx_index); + } + } } pub struct ThreadSafeVersionedState(Arc>>); @@ -177,6 +215,10 @@ impl VersionedStateProxy { pub fn apply_writes(&self, writes: &StateMaps, class_hash_to_class: &ContractClassMapping) { self.state().apply_writes(self.tx_index, writes, class_hash_to_class) } + + pub fn delete_writes(&self, writes: &StateMaps, class_hash_to_class: &ContractClassMapping) { + self.state().delete_writes(self.tx_index, writes, class_hash_to_class); + } } impl StateReader for VersionedStateProxy { diff --git a/crates/blockifier/src/concurrency/versioned_state_test.rs b/crates/blockifier/src/concurrency/versioned_state_test.rs index e61fe74855..63d51af494 100644 --- a/crates/blockifier/src/concurrency/versioned_state_test.rs +++ b/crates/blockifier/src/concurrency/versioned_state_test.rs @@ -9,12 +9,15 @@ use starknet_api::transaction::{Calldata, ContractAddressSalt, Fee, TransactionV use starknet_api::{calldata, class_hash, contract_address, patricia_key, stark_felt}; use crate::abi::abi_utils::{get_fee_token_var_address, get_storage_var_address}; -use crate::concurrency::test_utils::safe_versioned_state_for_testing; +use crate::concurrency::test_utils::{ + class_hash, contract_address, safe_versioned_state_for_testing, +}; use crate::concurrency::versioned_state::{ ThreadSafeVersionedState, VersionedState, VersionedStateProxy, }; +use crate::concurrency::TxIndex; use crate::context::BlockContext; -use crate::state::cached_state::{CachedState, StateMaps}; +use crate::state::cached_state::{CachedState, ContractClassMapping, StateMaps}; use crate::state::state_api::{State, StateReader}; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::deploy_account::deploy_account_tx; @@ -27,19 +30,6 @@ use crate::transaction::test_utils::l1_resource_bounds; use crate::transaction::transactions::ExecutableTransaction; use crate::{compiled_class_hash, deploy_account_tx_args, nonce, storage_key}; -const TEST_CONTRACT_ADDRESS: &str = "0x1"; -const TEST_CLASS_HASH: u8 = 27_u8; - -#[fixture] -pub fn contract_address() -> ContractAddress { - contract_address!(TEST_CONTRACT_ADDRESS) -} - -#[fixture] -pub fn class_hash() -> ClassHash { - class_hash!(TEST_CLASS_HASH) -} - #[fixture] pub fn safe_versioned_state( contract_address: ContractAddress, @@ -371,3 +361,118 @@ fn test_apply_writes_reexecute_scenario( // The class hash should be updated. assert!(transactional_states[1].get_class_hash_at(contract_address).unwrap() == class_hash_0); } + +#[rstest] +fn test_delete_writes( + #[values(0, 1, 2)] tx_index_to_delete_writes: TxIndex, + safe_versioned_state: ThreadSafeVersionedState, +) { + let num_of_txs = 3; + let mut transactional_states: Vec>> = + (0..num_of_txs).map(|i| CachedState::from(safe_versioned_state.pin_version(i))).collect(); + // Setting 2 instances of the contract to ensure `delete_writes` removes information from + // multiple keys. Class hash values are not checked in this test. + let contract_addresses = [ + (contract_address!("0x100"), class_hash!(20_u8)), + (contract_address!("0x200"), class_hash!(21_u8)), + ]; + let feature_contract = FeatureContract::TestContract(CairoVersion::Cairo1); + for tx_state in transactional_states.iter_mut() { + // Modify the `cache` member of the CachedState. + for (contract_address, class_hash) in contract_addresses.iter() { + tx_state.set_class_hash_at(*contract_address, *class_hash).unwrap(); + } + // Modify the `class_hash_to_class` member of the CachedState. + tx_state + .set_contract_class(feature_contract.get_class_hash(), feature_contract.get_class()) + .unwrap(); + tx_state + .state + .apply_writes(&tx_state.cache.borrow().writes, &tx_state.class_hash_to_class.borrow()); + } + + transactional_states[tx_index_to_delete_writes].state.delete_writes( + &transactional_states[tx_index_to_delete_writes].cache.borrow().writes, + &transactional_states[tx_index_to_delete_writes].class_hash_to_class.borrow(), + ); + + for tx_index in 0..num_of_txs { + let should_be_empty = tx_index == tx_index_to_delete_writes; + assert_eq!( + safe_versioned_state + .0 + .lock() + .unwrap() + .get_writes_of_index(tx_index) + .class_hashes + .is_empty(), + should_be_empty + ); + + assert_eq!( + safe_versioned_state + .0 + .lock() + .unwrap() + .compiled_contract_classes + .get_writes_of_index(tx_index) + .is_empty(), + should_be_empty + ); + } +} + +#[rstest] +fn test_delete_writes_completeness( + safe_versioned_state: ThreadSafeVersionedState, +) { + let state_maps_writes = StateMaps { + nonces: HashMap::from([(contract_address!("0x1"), nonce!("0x1"))]), + class_hashes: HashMap::from([(contract_address!("0x1"), class_hash!("0x1"))]), + storage: HashMap::from([( + (contract_address!("0x1"), storage_key!("0x1")), + stark_felt!("0x1"), + )]), + compiled_class_hashes: HashMap::from([(class_hash!("0x1"), compiled_class_hash!("0x1"))]), + // TODO (OriF, 01/07/2024): Uncomment the following line and remove the line below it once + // `declared_contracts` mapping logic in StateMaps is complete. + // declared_contracts: HashMap::from([(class_hash!("0x1"), true)]), + declared_contracts: HashMap::default(), + }; + let feature_contract = FeatureContract::TestContract(CairoVersion::Cairo1); + let class_hash_to_class_writes = + HashMap::from([(feature_contract.get_class_hash(), feature_contract.get_class())]); + + let tx_index = 0; + let versioned_state_proxy = safe_versioned_state.pin_version(tx_index); + + versioned_state_proxy.apply_writes(&state_maps_writes, &class_hash_to_class_writes); + assert_eq!( + safe_versioned_state.0.lock().unwrap().get_writes_of_index(tx_index), + state_maps_writes + ); + assert_eq!( + safe_versioned_state + .0 + .lock() + .unwrap() + .compiled_contract_classes + .get_writes_of_index(tx_index), + class_hash_to_class_writes + ); + + versioned_state_proxy.delete_writes(&state_maps_writes, &class_hash_to_class_writes); + assert_eq!( + safe_versioned_state.0.lock().unwrap().get_writes_of_index(tx_index), + StateMaps::default() + ); + assert_eq!( + safe_versioned_state + .0 + .lock() + .unwrap() + .compiled_contract_classes + .get_writes_of_index(tx_index), + ContractClassMapping::default() + ); +} diff --git a/crates/blockifier/src/concurrency/versioned_storage.rs b/crates/blockifier/src/concurrency/versioned_storage.rs index 61379dc79c..dba52e54c2 100644 --- a/crates/blockifier/src/concurrency/versioned_storage.rs +++ b/crates/blockifier/src/concurrency/versioned_storage.rs @@ -12,6 +12,7 @@ pub mod test; /// It is versioned in the sense that it holds a state of write operations done on it by /// different versions of executions. /// This allows maintaining the cells with the correct values in the context of each execution. +#[derive(Debug)] pub struct VersionedStorage where K: Clone + Copy + Eq + Hash + Debug, @@ -50,6 +51,16 @@ where cell.insert(tx_index, value); } + pub fn delete_write(&mut self, key: K, tx_index: TxIndex) { + self.writes + .get_mut(&key) + .expect( + "A 'delete_write' call must be preceded by a 'write' call with the corresponding \ + key", + ) + .remove(&tx_index); + } + /// This method inserts the provided key-value pair into the cached initial values map. /// It is typically used when reading a value that is not found in the versioned storage. In /// such a scenario, the value is retrieved from the initial storage and written to the @@ -58,13 +69,24 @@ where self.cached_initial_values.insert(key, value); } - pub(crate) fn get_writes_from_index(&self, from_index: TxIndex) -> HashMap { + pub(crate) fn get_writes_up_to_index(&self, index: TxIndex) -> HashMap { let mut writes = HashMap::default(); for (&key, cell) in self.writes.iter() { - if let Some(value) = cell.range(..=from_index).next_back() { + if let Some(value) = cell.range(..=index).next_back() { writes.insert(key, value.1.clone()); } } writes } + + #[cfg(any(feature = "testing", test))] + pub fn get_writes_of_index(&self, tx_index: TxIndex) -> HashMap { + let mut writes = HashMap::default(); + for (&key, cell) in self.writes.iter() { + if let Some(value) = cell.get(&tx_index) { + writes.insert(key, value.clone()); + } + } + writes + } } diff --git a/crates/blockifier/src/concurrency/versioned_storage_test.rs b/crates/blockifier/src/concurrency/versioned_storage_test.rs index 2ee7a950ce..34b6349073 100644 --- a/crates/blockifier/src/concurrency/versioned_storage_test.rs +++ b/crates/blockifier/src/concurrency/versioned_storage_test.rs @@ -1,7 +1,14 @@ +use std::collections::{BTreeMap, HashMap}; + use pretty_assertions::assert_eq; +use rstest::rstest; +use starknet_api::core::{ClassHash, ContractAddress}; +use crate::concurrency::test_utils::{class_hash, contract_address}; use crate::concurrency::versioned_storage::VersionedStorage; +use crate::concurrency::TxIndex; +// TODO(barak, 01/07/2024): Split into test_read() and test_write(). #[test] fn test_versioned_storage() { let mut storage = VersionedStorage::default(); @@ -37,3 +44,29 @@ fn test_versioned_storage() { // Test the write. assert_eq!(storage.read(50, 100).unwrap(), 194); } + +#[rstest] +fn test_delete_write( + contract_address: ContractAddress, + class_hash: ClassHash, + #[values(0, 1, 2)] tx_index_to_delete_writes: TxIndex, +) { + // TODO(barak, 01/07/2025): Create a macro versioned_storage!. + let num_of_txs = 3; + let mut versioned_storage = VersionedStorage { + cached_initial_values: HashMap::default(), + writes: HashMap::from([( + contract_address, + // Class hash values are not checked in this test. + BTreeMap::from_iter((0..num_of_txs).map(|i| (i, class_hash))), + )]), + }; + for tx_index in 0..num_of_txs { + let should_contain_tx_index_writes = tx_index != tx_index_to_delete_writes; + versioned_storage.delete_write(contract_address, tx_index_to_delete_writes); + assert_eq!( + versioned_storage.writes.get(&contract_address).unwrap().contains_key(&tx_index), + should_contain_tx_index_writes + ) + } +}