diff --git a/crates/blockifier/src/blockifier/stateful_validator.rs b/crates/blockifier/src/blockifier/stateful_validator.rs index bb89e3d5fc..c88e16eab2 100644 --- a/crates/blockifier/src/blockifier/stateful_validator.rs +++ b/crates/blockifier/src/blockifier/stateful_validator.rs @@ -15,6 +15,7 @@ use crate::fee::fee_checks::PostValidationReport; use crate::state::cached_state::CachedState; use crate::state::errors::StateError; use crate::state::state_api::StateReader; +use crate::state::visited_pcs::VisitedPcs; use crate::transaction::account_transaction::AccountTransaction; use crate::transaction::errors::{TransactionExecutionError, TransactionPreValidationError}; use crate::transaction::transaction_execution::Transaction; @@ -39,12 +40,12 @@ pub enum StatefulValidatorError { pub type StatefulValidatorResult = Result; /// Manages state related transaction validations for pre-execution flows. -pub struct StatefulValidator { - tx_executor: TransactionExecutor, +pub struct StatefulValidator { + tx_executor: TransactionExecutor, } -impl StatefulValidator { - pub fn create(state: CachedState, block_context: BlockContext) -> Self { +impl StatefulValidator { + pub fn create(state: CachedState, block_context: BlockContext) -> Self { let tx_executor = TransactionExecutor::new(state, block_context, TransactionExecutorConfig::default()); Self { tx_executor } diff --git a/crates/blockifier/src/blockifier/transaction_executor.rs b/crates/blockifier/src/blockifier/transaction_executor.rs index 4cb04c6341..cf489084bd 100644 --- a/crates/blockifier/src/blockifier/transaction_executor.rs +++ b/crates/blockifier/src/blockifier/transaction_executor.rs @@ -1,3 +1,4 @@ +use std::fmt::Debug; #[cfg(feature = "concurrency")] use std::panic::{self, catch_unwind, AssertUnwindSafe}; #[cfg(feature = "concurrency")] @@ -18,6 +19,7 @@ use crate::context::BlockContext; use crate::state::cached_state::{CachedState, CommitmentStateDiff, TransactionalState}; use crate::state::errors::StateError; use crate::state::state_api::StateReader; +use crate::state::visited_pcs::VisitedPcs; use crate::transaction::errors::TransactionExecutionError; use crate::transaction::objects::TransactionExecutionInfo; use crate::transaction::transaction_execution::Transaction; @@ -43,7 +45,7 @@ pub type TransactionExecutorResult = Result; pub type VisitedSegmentsMapping = Vec<(ClassHash, Vec)>; // TODO(Gilad): make this hold TransactionContext instead of BlockContext. -pub struct TransactionExecutor { +pub struct TransactionExecutor { pub block_context: BlockContext, pub bouncer: Bouncer, // Note: this config must not affect the execution result (e.g. state diff and traces). @@ -54,12 +56,12 @@ pub struct TransactionExecutor { // block state to the worker executor - operating at the chunk level - and gets it back after // committing the chunk. The block state is wrapped with an Option<_> to allow setting it to // `None` while it is moved to the worker executor. - pub block_state: Option>, + pub block_state: Option>, } -impl TransactionExecutor { +impl TransactionExecutor { pub fn new( - block_state: CachedState, + block_state: CachedState, block_context: BlockContext, config: TransactionExecutorConfig, ) -> Self { @@ -85,9 +87,10 @@ impl TransactionExecutor { &mut self, tx: &Transaction, ) -> TransactionExecutorResult { - let mut transactional_state = TransactionalState::create_transactional( - self.block_state.as_mut().expect(BLOCK_STATE_ACCESS_ERR), - ); + let mut transactional_state: TransactionalState<'_, _, V> = + TransactionalState::create_transactional( + self.block_state.as_mut().expect(BLOCK_STATE_ACCESS_ERR), + ); // Executing a single transaction cannot be done in a concurrent mode. let execution_flags = ExecutionFlags { charge_fee: true, validate: true, concurrency_mode: false }; @@ -157,7 +160,8 @@ impl TransactionExecutor { .as_ref() .expect(BLOCK_STATE_ACCESS_ERR) .get_compiled_contract_class(*class_hash)?; - Ok((*class_hash, contract_class.get_visited_segments(class_visited_pcs)?)) + let class_visited_pcs = V::to_set(class_visited_pcs.clone()); + Ok((*class_hash, contract_class.get_visited_segments(&class_visited_pcs)?)) }) .collect::>()?; @@ -170,7 +174,7 @@ impl TransactionExecutor { } } -impl TransactionExecutor { +impl TransactionExecutor { /// Executes the given transactions on the state maintained by the executor. /// Stops if and when there is no more room in the block, and returns the executed transactions' /// results. @@ -219,7 +223,7 @@ impl TransactionExecutor { chunk: &[Transaction], ) -> Vec> { use crate::concurrency::utils::AbortIfPanic; - use crate::state::cached_state::VisitedPcs; + use crate::concurrency::worker_logic::ExecutionTaskOutput; let block_state = self.block_state.take().expect("The block state should be `Some`."); @@ -263,20 +267,20 @@ impl TransactionExecutor { let n_committed_txs = worker_executor.scheduler.get_n_committed_txs(); let mut tx_execution_results = Vec::new(); - let mut visited_pcs: VisitedPcs = VisitedPcs::new(); + let mut visited_pcs: V = V::new(); for execution_output in worker_executor.execution_outputs.iter() { if tx_execution_results.len() >= n_committed_txs { break; } - let locked_execution_output = execution_output + let locked_execution_output: ExecutionTaskOutput = execution_output .lock() .expect("Failed to lock execution output.") .take() .expect("Output must be ready."); tx_execution_results .push(locked_execution_output.result.map_err(TransactionExecutorError::from)); - for (class_hash, class_visited_pcs) in locked_execution_output.visited_pcs { - visited_pcs.entry(class_hash).or_default().extend(class_visited_pcs); + for (class_hash, class_visited_pcs) in locked_execution_output.visited_pcs.iter() { + visited_pcs.extend(class_hash, class_visited_pcs); } } diff --git a/crates/blockifier/src/blockifier/transaction_executor_test.rs b/crates/blockifier/src/blockifier/transaction_executor_test.rs index b0139de8ba..ae7740164d 100644 --- a/crates/blockifier/src/blockifier/transaction_executor_test.rs +++ b/crates/blockifier/src/blockifier/transaction_executor_test.rs @@ -13,6 +13,7 @@ use crate::bouncer::{Bouncer, BouncerWeights}; use crate::context::BlockContext; use crate::state::cached_state::CachedState; use crate::state::state_api::StateReader; +use crate::state::visited_pcs::VisitedPcs; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::declare::declare_tx; use crate::test_utils::deploy_account::deploy_account_tx; @@ -30,8 +31,8 @@ use crate::transaction::transaction_execution::Transaction; use crate::transaction::transactions::L1HandlerTransaction; use crate::{declare_tx_args, deploy_account_tx_args, invoke_tx_args, nonce}; -fn tx_executor_test_body( - state: CachedState, +fn tx_executor_test_body( + state: CachedState, block_context: BlockContext, tx: Transaction, expected_bouncer_weights: BouncerWeights, diff --git a/crates/blockifier/src/bouncer_test.rs b/crates/blockifier/src/bouncer_test.rs index 9976ca0f13..d0a744704a 100644 --- a/crates/blockifier/src/bouncer_test.rs +++ b/crates/blockifier/src/bouncer_test.rs @@ -13,6 +13,7 @@ use crate::bouncer::{verify_tx_weights_in_bounds, Bouncer, BouncerWeights, Built use crate::context::BlockContext; use crate::execution::call_info::ExecutionSummary; use crate::state::cached_state::{StateChangesKeys, TransactionalState}; +use crate::state::visited_pcs::VisitedPcsSet; use crate::storage_key; use crate::test_utils::initial_test_state::test_state; use crate::transaction::errors::TransactionExecutionError; @@ -187,7 +188,8 @@ fn test_bouncer_try_update( use crate::transaction::objects::TransactionResources; let state = &mut test_state(&BlockContext::create_for_account_testing().chain_info, 0, &[]); - let mut transactional_state = TransactionalState::create_transactional(state); + let mut transactional_state: TransactionalState<'_, _, VisitedPcsSet> = + TransactionalState::create_transactional(state); // Setup the bouncer. let block_max_capacity = BouncerWeights { diff --git a/crates/blockifier/src/concurrency/fee_utils.rs b/crates/blockifier/src/concurrency/fee_utils.rs index b9ad04942e..a5386dae73 100644 --- a/crates/blockifier/src/concurrency/fee_utils.rs +++ b/crates/blockifier/src/concurrency/fee_utils.rs @@ -10,6 +10,7 @@ use crate::execution::call_info::CallInfo; use crate::fee::fee_utils::get_sequencer_balance_keys; use crate::state::cached_state::{ContractClassMapping, StateMaps}; use crate::state::state_api::UpdatableState; +use crate::state::visited_pcs::VisitedPcs; use crate::transaction::objects::TransactionExecutionInfo; #[cfg(test)] @@ -22,10 +23,10 @@ mod test; pub(crate) const STORAGE_READ_SEQUENCER_BALANCE_INDICES: (usize, usize) = (2, 3); // Completes the fee transfer flow if needed (if the transfer was made in concurrent mode). -pub fn complete_fee_transfer_flow( +pub fn complete_fee_transfer_flow>( tx_context: &TransactionContext, tx_execution_info: &mut TransactionExecutionInfo, - state: &mut impl UpdatableState, + state: &mut U, ) { if tx_context.is_sequencer_the_sender() { // When the sequencer is the sender, we use the sequential (full) fee transfer. @@ -93,9 +94,9 @@ pub fn fill_sequencer_balance_reads( storage_read_values[high_index] = high; } -pub fn add_fee_to_sequencer_balance( +pub fn add_fee_to_sequencer_balance>( fee_token_address: ContractAddress, - state: &mut impl UpdatableState, + state: &mut U, actual_fee: Fee, block_context: &BlockContext, sequencer_balance: (Felt, Felt), @@ -120,5 +121,5 @@ pub fn add_fee_to_sequencer_balance( ]), ..StateMaps::default() }; - state.apply_writes(&writes, &ContractClassMapping::default(), &HashMap::default()); + state.apply_writes(&writes, &ContractClassMapping::default(), &V::default()); } diff --git a/crates/blockifier/src/concurrency/flow_test.rs b/crates/blockifier/src/concurrency/flow_test.rs index c89644940b..4e6a7e1cac 100644 --- a/crates/blockifier/src/concurrency/flow_test.rs +++ b/crates/blockifier/src/concurrency/flow_test.rs @@ -9,9 +9,10 @@ use starknet_api::{contract_address, felt, patricia_key}; use crate::abi::sierra_types::{SierraType, SierraU128}; use crate::concurrency::scheduler::{Scheduler, Task, TransactionStatus}; use crate::concurrency::test_utils::{safe_versioned_state_for_testing, DEFAULT_CHUNK_SIZE}; -use crate::concurrency::versioned_state::ThreadSafeVersionedState; +use crate::concurrency::versioned_state::{ThreadSafeVersionedState, VersionedStateProxy}; use crate::state::cached_state::{CachedState, ContractClassMapping, StateMaps}; use crate::state::state_api::UpdatableState; +use crate::state::visited_pcs::VisitedPcsSet; use crate::storage_key; use crate::test_utils::dict_state_reader::DictStateReader; @@ -27,6 +28,9 @@ fn scheduler_flow_test( // transactions with multiple threads, where every transaction depends on its predecessor. Each // transaction sequentially advances a counter by reading the previous value and bumping it by // 1. + + use crate::concurrency::versioned_state::VersionedStateProxy; + use crate::state::visited_pcs::VisitedPcsSet; let scheduler = Arc::new(Scheduler::new(DEFAULT_CHUNK_SIZE)); let versioned_state = safe_versioned_state_for_testing(CachedState::from(DictStateReader::default())); @@ -53,7 +57,7 @@ fn scheduler_flow_test( state_proxy.apply_writes( &new_writes, &ContractClassMapping::default(), - &HashMap::default(), + &VisitedPcsSet::default(), ); scheduler.finish_execution_during_commit(tx_index); } @@ -66,13 +70,14 @@ fn scheduler_flow_test( versioned_state.pin_version(tx_index).apply_writes( &writes, &ContractClassMapping::default(), - &HashMap::default(), + &VisitedPcsSet::default(), ); scheduler.finish_execution(tx_index); Task::AskForTask } Task::ValidationTask(tx_index) => { - let state_proxy = versioned_state.pin_version(tx_index); + let state_proxy: VersionedStateProxy<_, VisitedPcsSet> = + versioned_state.pin_version(tx_index); let (reads, writes) = get_reads_writes_for(Task::ValidationTask(tx_index), &versioned_state); let read_set_valid = state_proxy.validate_reads(&reads); @@ -120,11 +125,11 @@ fn scheduler_flow_test( fn get_reads_writes_for( task: Task, - versioned_state: &ThreadSafeVersionedState>, + versioned_state: &ThreadSafeVersionedState>, ) -> (StateMaps, StateMaps) { match task { Task::ExecutionTask(tx_index) => { - let state_proxy = match tx_index { + let state_proxy: VersionedStateProxy<_, VisitedPcsSet> = match tx_index { 0 => { return ( state_maps_with_single_storage_entry(0), @@ -146,7 +151,8 @@ fn get_reads_writes_for( ) } Task::ValidationTask(tx_index) => { - let state_proxy = versioned_state.pin_version(tx_index); + let state_proxy: VersionedStateProxy<_, VisitedPcsSet> = + versioned_state.pin_version(tx_index); let tx_written_value = SierraU128::from_storage( &state_proxy, &contract_address!(CONTRACT_ADDRESS), diff --git a/crates/blockifier/src/concurrency/test_utils.rs b/crates/blockifier/src/concurrency/test_utils.rs index 87722b1171..a19cc98bef 100644 --- a/crates/blockifier/src/concurrency/test_utils.rs +++ b/crates/blockifier/src/concurrency/test_utils.rs @@ -7,6 +7,7 @@ use crate::context::BlockContext; use crate::execution::call_info::CallInfo; use crate::state::cached_state::{CachedState, TransactionalState}; use crate::state::state_api::StateReader; +use crate::state::visited_pcs::{VisitedPcs, VisitedPcsSet}; use crate::test_utils::dict_state_reader::DictStateReader; use crate::transaction::account_transaction::AccountTransaction; use crate::transaction::transactions::{ExecutableTransaction, ExecutionFlags}; @@ -61,21 +62,22 @@ macro_rules! default_scheduler { // TODO(meshi, 01/06/2024): Consider making this a macro. pub fn safe_versioned_state_for_testing( - block_state: CachedState, -) -> ThreadSafeVersionedState> { + block_state: CachedState, +) -> ThreadSafeVersionedState> { ThreadSafeVersionedState::new(VersionedState::new(block_state)) } // Utils. // Note: this function does not mutate the state. -pub fn create_fee_transfer_call_info( - state: &mut CachedState, +pub fn create_fee_transfer_call_info( + state: &mut CachedState, account_tx: &AccountTransaction, concurrency_mode: bool, ) -> CallInfo { let block_context = BlockContext::create_for_account_testing(); - let mut transactional_state = TransactionalState::create_transactional(state); + let mut transactional_state: TransactionalState<'_, _, V> = + TransactionalState::create_transactional(state); let execution_flags = ExecutionFlags { charge_fee: true, validate: true, concurrency_mode }; let execution_info = account_tx.execute_raw(&mut transactional_state, &block_context, execution_flags).unwrap(); diff --git a/crates/blockifier/src/concurrency/versioned_state.rs b/crates/blockifier/src/concurrency/versioned_state.rs index fe80fa38e3..d2e8692484 100644 --- a/crates/blockifier/src/concurrency/versioned_state.rs +++ b/crates/blockifier/src/concurrency/versioned_state.rs @@ -1,3 +1,4 @@ +use std::marker::PhantomData; use std::sync::{Arc, Mutex, MutexGuard}; use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; @@ -7,9 +8,10 @@ use starknet_types_core::felt::Felt; use crate::concurrency::versioned_storage::VersionedStorage; use crate::concurrency::TxIndex; use crate::execution::contract_class::ContractClass; -use crate::state::cached_state::{ContractClassMapping, StateMaps, VisitedPcs}; +use crate::state::cached_state::{ContractClassMapping, StateMaps}; use crate::state::errors::StateError; use crate::state::state_api::{StateReader, StateResult, UpdatableState}; +use crate::state::visited_pcs::VisitedPcs; #[cfg(test)] #[path = "versioned_state_test.rs"] @@ -197,11 +199,11 @@ impl VersionedState { } } -impl VersionedState { +impl> VersionedState { pub fn commit_chunk_and_recover_block_state( mut self, n_committed_txs: usize, - visited_pcs: VisitedPcs, + visited_pcs: V, ) -> U { if n_committed_txs == 0 { return self.into_initial_state(); @@ -228,8 +230,8 @@ impl ThreadSafeVersionedState { ThreadSafeVersionedState(Arc::new(Mutex::new(versioned_state))) } - pub fn pin_version(&self, tx_index: TxIndex) -> VersionedStateProxy { - VersionedStateProxy { tx_index, state: self.0.clone() } + pub fn pin_version(&self, tx_index: TxIndex) -> VersionedStateProxy { + VersionedStateProxy { tx_index, state: self.0.clone(), _marker: PhantomData } } pub fn into_inner_state(self) -> VersionedState { @@ -251,12 +253,13 @@ impl Clone for ThreadSafeVersionedState { } } -pub struct VersionedStateProxy { +pub struct VersionedStateProxy { pub tx_index: TxIndex, pub state: Arc>>, + _marker: PhantomData, } -impl VersionedStateProxy { +impl VersionedStateProxy { fn state(&self) -> LockedVersionedState<'_, S> { self.state.lock().expect("Failed to acquire state lock.") } @@ -271,18 +274,20 @@ impl VersionedStateProxy { } // TODO(Noa, 15/5/24): Consider using visited_pcs. -impl UpdatableState for VersionedStateProxy { +impl UpdatableState for VersionedStateProxy { + type T = V; + fn apply_writes( &mut self, writes: &StateMaps, class_hash_to_class: &ContractClassMapping, - _visited_pcs: &VisitedPcs, + _visited_pcs: &V, ) { self.state().apply_writes(self.tx_index, writes, class_hash_to_class) } } -impl StateReader for VersionedStateProxy { +impl StateReader for VersionedStateProxy { fn get_storage_at( &self, contract_address: ContractAddress, diff --git a/crates/blockifier/src/concurrency/versioned_state_test.rs b/crates/blockifier/src/concurrency/versioned_state_test.rs index ab99698b4e..bec9541a10 100644 --- a/crates/blockifier/src/concurrency/versioned_state_test.rs +++ b/crates/blockifier/src/concurrency/versioned_state_test.rs @@ -24,6 +24,7 @@ use crate::state::cached_state::{ }; use crate::state::errors::StateError; use crate::state::state_api::{State, StateReader, UpdatableState}; +use crate::state::visited_pcs::{VisitedPcs, VisitedPcsSet}; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::deploy_account::deploy_account_tx; use crate::test_utils::dict_state_reader::DictStateReader; @@ -39,7 +40,7 @@ use crate::{compiled_class_hash, deploy_account_tx_args, nonce, storage_key}; pub fn safe_versioned_state( contract_address: ContractAddress, class_hash: ClassHash, -) -> ThreadSafeVersionedState> { +) -> ThreadSafeVersionedState> { let init_state = DictStateReader { address_to_class_hash: HashMap::from([(contract_address, class_hash)]), ..Default::default() @@ -72,8 +73,9 @@ fn test_versioned_state_proxy() { let versioned_state = Arc::new(Mutex::new(VersionedState::new(cached_state))); let safe_versioned_state = ThreadSafeVersionedState(Arc::clone(&versioned_state)); - let versioned_state_proxys: Vec>> = - (0..20).map(|i| safe_versioned_state.pin_version(i)).collect(); + let versioned_state_proxys: Vec< + VersionedStateProxy, VisitedPcsSet>, + > = (0..20).map(|i| safe_versioned_state.pin_version(i)).collect(); // Read initial data assert_eq!(versioned_state_proxys[5].get_nonce_at(contract_address).unwrap(), nonce); @@ -208,10 +210,14 @@ fn test_run_parallel_txs(max_resource_bounds: ResourceBoundsMapping) { )))); let safe_versioned_state = ThreadSafeVersionedState(Arc::clone(&versioned_state)); - let mut versioned_state_proxy_1 = safe_versioned_state.pin_version(1); - let mut state_1 = TransactionalState::create_transactional(&mut versioned_state_proxy_1); - let mut versioned_state_proxy_2 = safe_versioned_state.pin_version(2); - let mut state_2 = TransactionalState::create_transactional(&mut versioned_state_proxy_2); + let mut versioned_state_proxy_1: VersionedStateProxy<_, VisitedPcsSet> = + safe_versioned_state.pin_version(1); + let mut state_1: TransactionalState<'_, _, VisitedPcsSet> = + TransactionalState::create_transactional(&mut versioned_state_proxy_1); + let mut versioned_state_proxy_2: VersionedStateProxy<_, VisitedPcsSet> = + safe_versioned_state.pin_version(2); + let mut state_2: TransactionalState<'_, _, VisitedPcsSet> = + TransactionalState::create_transactional(&mut versioned_state_proxy_2); // Prepare transactions let deploy_account_tx_1 = deploy_account_tx( @@ -248,10 +254,12 @@ fn test_run_parallel_txs(max_resource_bounds: ResourceBoundsMapping) { let block_context_1 = block_context.clone(); let block_context_2 = block_context.clone(); + // Execute transactions thread::scope(|s| { s.spawn(move || { let result = account_tx_1.execute(&mut state_1, &block_context_1, true, true); + assert_eq!(result.is_err(), enforce_fee); }); s.spawn(move || { @@ -276,15 +284,19 @@ fn test_run_parallel_txs(max_resource_bounds: ResourceBoundsMapping) { fn test_validate_reads( contract_address: ContractAddress, class_hash: ClassHash, - safe_versioned_state: ThreadSafeVersionedState>, + safe_versioned_state: ThreadSafeVersionedState>, ) { let storage_key = storage_key!(0x10_u8); - let mut version_state_proxy = safe_versioned_state.pin_version(1); - let transactional_state = TransactionalState::create_transactional(&mut version_state_proxy); + let mut version_state_proxy: VersionedStateProxy<_, VisitedPcsSet> = + safe_versioned_state.pin_version(1); + let transactional_state: TransactionalState<'_, _, VisitedPcsSet> = + TransactionalState::create_transactional(&mut version_state_proxy); // Validating tx index 0 always succeeds. - assert!(safe_versioned_state.pin_version(0).validate_reads(&StateMaps::default())); + assert!( + safe_versioned_state.pin_version::(0).validate_reads(&StateMaps::default()) + ); assert!(transactional_state.cache.borrow().initial_reads.storage.is_empty()); transactional_state.get_storage_at(contract_address, storage_key).unwrap(); @@ -313,7 +325,7 @@ fn test_validate_reads( assert!( safe_versioned_state - .pin_version(1) + .pin_version::(1) .validate_reads(&transactional_state.cache.borrow().initial_reads) ); } @@ -366,16 +378,17 @@ fn test_validate_reads( fn test_false_validate_reads( #[case] tx_1_reads: StateMaps, #[case] tx_0_writes: StateMaps, - safe_versioned_state: ThreadSafeVersionedState>, + safe_versioned_state: ThreadSafeVersionedState>, ) { - let version_state_proxy = safe_versioned_state.pin_version(0); + let version_state_proxy: VersionedStateProxy<_, VisitedPcsSet> = + safe_versioned_state.pin_version(0); version_state_proxy.state().apply_writes(0, &tx_0_writes, &HashMap::default()); - assert!(!safe_versioned_state.pin_version(1).validate_reads(&tx_1_reads)); + assert!(!safe_versioned_state.pin_version::(1).validate_reads(&tx_1_reads)); } #[rstest] fn test_false_validate_reads_declared_contracts( - safe_versioned_state: ThreadSafeVersionedState>, + safe_versioned_state: ThreadSafeVersionedState>, ) { let tx_1_reads = StateMaps { declared_contracts: HashMap::from([(class_hash!(1_u8), false)]), @@ -385,24 +398,24 @@ fn test_false_validate_reads_declared_contracts( declared_contracts: HashMap::from([(class_hash!(1_u8), true)]), ..Default::default() }; - let version_state_proxy = safe_versioned_state.pin_version(0); + let version_state_proxy: VersionedStateProxy<_, VisitedPcsSet> = + safe_versioned_state.pin_version(0); let compiled_contract_calss = FeatureContract::TestContract(CairoVersion::Cairo1).get_class(); let class_hash_to_class = HashMap::from([(class_hash!(1_u8), compiled_contract_calss)]); version_state_proxy.state().apply_writes(0, &tx_0_writes, &class_hash_to_class); - assert!(!safe_versioned_state.pin_version(1).validate_reads(&tx_1_reads)); + assert!(!safe_versioned_state.pin_version::(1).validate_reads(&tx_1_reads)); } #[rstest] fn test_apply_writes( contract_address: ContractAddress, class_hash: ClassHash, - safe_versioned_state: ThreadSafeVersionedState>, + safe_versioned_state: ThreadSafeVersionedState>, ) { - let mut versioned_proxy_states: Vec>> = + let mut versioned_proxy_states: Vec> = (0..2).map(|i| safe_versioned_state.pin_version(i)).collect(); - let mut transactional_states: Vec< - TransactionalState<'_, VersionedStateProxy>>, - > = versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect(); + let mut transactional_states: Vec> = + versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect(); // Transaction 0 class hash. let class_hash_0 = class_hash!(76_u8); @@ -419,7 +432,7 @@ fn test_apply_writes( safe_versioned_state.pin_version(0).apply_writes( &transactional_states[0].cache.borrow().writes, &transactional_states[0].class_hash_to_class.borrow().clone(), - &HashMap::default(), + &VisitedPcsSet::default(), ); assert!(transactional_states[1].get_class_hash_at(contract_address).unwrap() == class_hash_0); assert!( @@ -432,13 +445,12 @@ fn test_apply_writes( fn test_apply_writes_reexecute_scenario( contract_address: ContractAddress, class_hash: ClassHash, - safe_versioned_state: ThreadSafeVersionedState>, + safe_versioned_state: ThreadSafeVersionedState>, ) { - let mut versioned_proxy_states: Vec>> = + let mut versioned_proxy_states: Vec> = (0..2).map(|i| safe_versioned_state.pin_version(i)).collect(); - let mut transactional_states: Vec< - TransactionalState<'_, VersionedStateProxy>>, - > = versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect(); + let mut transactional_states: Vec> = + versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect(); // Transaction 0 class hash. let class_hash_0 = class_hash!(76_u8); @@ -451,7 +463,7 @@ fn test_apply_writes_reexecute_scenario( safe_versioned_state.pin_version(0).apply_writes( &transactional_states[0].cache.borrow().writes, &transactional_states[0].class_hash_to_class.borrow().clone(), - &HashMap::default(), + &VisitedPcsSet::default(), ); // Although transaction 0 wrote to the shared state, version 1 needs to be re-executed to see // the new value (its read value has already been cached). @@ -468,14 +480,13 @@ fn test_apply_writes_reexecute_scenario( #[rstest] fn test_delete_writes( #[values(0, 1, 2)] tx_index_to_delete_writes: TxIndex, - safe_versioned_state: ThreadSafeVersionedState>, + safe_versioned_state: ThreadSafeVersionedState>, ) { let num_of_txs = 3; - let mut versioned_proxy_states: Vec>> = + let mut versioned_proxy_states: Vec> = (0..num_of_txs).map(|i| safe_versioned_state.pin_version(i)).collect(); - let mut transactional_states: Vec< - TransactionalState<'_, VersionedStateProxy>>, - > = versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect(); + let mut transactional_states: Vec> = + versioned_proxy_states.iter_mut().map(TransactionalState::create_transactional).collect(); // Setting 2 instances of the contract to ensure `delete_writes` removes information from // multiple keys. Class hash values are not checked in this test. @@ -496,11 +507,11 @@ fn test_delete_writes( safe_versioned_state.pin_version(i).apply_writes( &tx_state.cache.borrow().writes, &tx_state.class_hash_to_class.borrow(), - &HashMap::default(), + &VisitedPcsSet::default(), ); } - safe_versioned_state.pin_version(tx_index_to_delete_writes).delete_writes( + safe_versioned_state.pin_version::(tx_index_to_delete_writes).delete_writes( &transactional_states[tx_index_to_delete_writes].cache.borrow().writes, &transactional_states[tx_index_to_delete_writes].class_hash_to_class.borrow(), ); @@ -533,7 +544,7 @@ fn test_delete_writes( #[rstest] fn test_delete_writes_completeness( - safe_versioned_state: ThreadSafeVersionedState>, + safe_versioned_state: ThreadSafeVersionedState>, ) { let feature_contract = FeatureContract::TestContract(CairoVersion::Cairo1); let state_maps_writes = StateMaps { @@ -558,7 +569,7 @@ fn test_delete_writes_completeness( versioned_state_proxy.apply_writes( &state_maps_writes, &class_hash_to_class_writes, - &HashMap::default(), + &VisitedPcsSet::default(), ); assert_eq!( safe_versioned_state.0.lock().unwrap().get_writes_of_index(tx_index), @@ -592,15 +603,16 @@ fn test_delete_writes_completeness( #[rstest] fn test_versioned_proxy_state_flow( - safe_versioned_state: ThreadSafeVersionedState>, + safe_versioned_state: ThreadSafeVersionedState>, ) { let contract_address = contract_address!("0x1"); let class_hash = ClassHash(felt!(27_u8)); - let mut versioned_proxy_states: Vec>> = + let mut versioned_proxy_states: Vec> = (0..4).map(|i| safe_versioned_state.pin_version(i)).collect(); - let mut transactional_states = Vec::with_capacity(4); + let mut transactional_states: Vec> = + Vec::with_capacity(4); for proxy_state in &mut versioned_proxy_states { transactional_states.push(TransactionalState::create_transactional(proxy_state)); } @@ -635,7 +647,7 @@ fn test_versioned_proxy_state_flow( } let modified_block_state = safe_versioned_state .into_inner_state() - .commit_chunk_and_recover_block_state(4, HashMap::new()); + .commit_chunk_and_recover_block_state(4, VisitedPcsSet::new()); assert!(modified_block_state.get_class_hash_at(contract_address).unwrap() == class_hash_3); assert!( diff --git a/crates/blockifier/src/concurrency/worker_logic.rs b/crates/blockifier/src/concurrency/worker_logic.rs index 150af07e9e..859a69d888 100644 --- a/crates/blockifier/src/concurrency/worker_logic.rs +++ b/crates/blockifier/src/concurrency/worker_logic.rs @@ -14,9 +14,10 @@ use crate::concurrency::versioned_state::ThreadSafeVersionedState; use crate::concurrency::TxIndex; use crate::context::BlockContext; use crate::state::cached_state::{ - ContractClassMapping, StateChanges, StateMaps, TransactionalState, VisitedPcs, + ContractClassMapping, StateChanges, StateMaps, TransactionalState, }; use crate::state::state_api::{StateReader, UpdatableState}; +use crate::state::visited_pcs::VisitedPcs; use crate::transaction::objects::{TransactionExecutionInfo, TransactionExecutionResult}; use crate::transaction::transaction_execution::Transaction; use crate::transaction::transactions::{ExecutableTransaction, ExecutionFlags}; @@ -28,23 +29,23 @@ pub mod test; const EXECUTION_OUTPUTS_UNWRAP_ERROR: &str = "Execution task outputs should not be None."; #[derive(Debug)] -pub struct ExecutionTaskOutput { +pub struct ExecutionTaskOutput { pub reads: StateMaps, pub writes: StateMaps, pub contract_classes: ContractClassMapping, - pub visited_pcs: VisitedPcs, + pub visited_pcs: V, pub result: TransactionExecutionResult, } -pub struct WorkerExecutor<'a, S: StateReader> { +pub struct WorkerExecutor<'a, S: StateReader, V: VisitedPcs> { pub scheduler: Scheduler, pub state: ThreadSafeVersionedState, pub chunk: &'a [Transaction], - pub execution_outputs: Box<[Mutex>]>, + pub execution_outputs: Box<[Mutex>>]>, pub block_context: &'a BlockContext, pub bouncer: Mutex<&'a mut Bouncer>, } -impl<'a, S: StateReader> WorkerExecutor<'a, S> { +impl<'a, S: StateReader, V: VisitedPcs + Default + Debug> WorkerExecutor<'a, S, V> { pub fn new( state: ThreadSafeVersionedState, chunk: &'a [Transaction], @@ -135,7 +136,7 @@ impl<'a, S: StateReader> WorkerExecutor<'a, S> { self.state.pin_version(tx_index).apply_writes( &transactional_state.cache.borrow().writes, &transactional_state.class_hash_to_class.borrow(), - &HashMap::default(), + &V::default(), ); } @@ -145,7 +146,7 @@ impl<'a, S: StateReader> WorkerExecutor<'a, S> { // In case of a failed transaction, we don't record its writes and visited pcs. let (writes, contract_classes, visited_pcs) = match execution_result { Ok(_) => (tx_reads_writes.writes, class_hash_to_class, transactional_state.visited_pcs), - Err(_) => (StateMaps::default(), HashMap::default(), HashMap::default()), + Err(_) => (StateMaps::default(), HashMap::default(), V::default()), }; let mut execution_output = lock_mutex_in_array(&self.execution_outputs, tx_index); *execution_output = Some(ExecutionTaskOutput { @@ -158,7 +159,7 @@ impl<'a, S: StateReader> WorkerExecutor<'a, S> { } fn validate(&self, tx_index: TxIndex) -> Task { - let tx_versioned_state = self.state.pin_version(tx_index); + let tx_versioned_state = self.state.pin_version::(tx_index); let execution_output = lock_mutex_in_array(&self.execution_outputs, tx_index); let execution_output = execution_output.as_ref().expect(EXECUTION_OUTPUTS_UNWRAP_ERROR); let reads = &execution_output.reads; @@ -191,7 +192,7 @@ impl<'a, S: StateReader> WorkerExecutor<'a, S> { let execution_output_ref = execution_output.as_ref().expect(EXECUTION_OUTPUTS_UNWRAP_ERROR); let reads = &execution_output_ref.reads; - let mut tx_versioned_state = self.state.pin_version(tx_index); + let mut tx_versioned_state = self.state.pin_version::(tx_index); let reads_valid = tx_versioned_state.validate_reads(reads); // First, re-validate the transaction. @@ -258,12 +259,8 @@ impl<'a, S: StateReader> WorkerExecutor<'a, S> { } } -impl<'a, U: UpdatableState> WorkerExecutor<'a, U> { - pub fn commit_chunk_and_recover_block_state( - self, - n_committed_txs: usize, - visited_pcs: VisitedPcs, - ) -> U { +impl<'a, V: VisitedPcs, U: UpdatableState> WorkerExecutor<'a, U, V> { + pub fn commit_chunk_and_recover_block_state(self, n_committed_txs: usize, visited_pcs: V) -> U { self.state .into_inner_state() .commit_chunk_and_recover_block_state(n_committed_txs, visited_pcs) diff --git a/crates/blockifier/src/concurrency/worker_logic_test.rs b/crates/blockifier/src/concurrency/worker_logic_test.rs index 1f28b1ee4f..da61c09aab 100644 --- a/crates/blockifier/src/concurrency/worker_logic_test.rs +++ b/crates/blockifier/src/concurrency/worker_logic_test.rs @@ -22,6 +22,7 @@ use crate::context::{BlockContext, TransactionContext}; use crate::fee::fee_utils::get_sequencer_balance_keys; use crate::state::cached_state::StateMaps; use crate::state::state_api::StateReader; +use crate::state::visited_pcs::VisitedPcsSet; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::declare::declare_tx; use crate::test_utils::initial_test_state::test_state; @@ -61,7 +62,7 @@ fn verify_sequencer_balance_update( expected_sequencer_balance_low: u128, ) { let TransactionContext { block_context, tx_info } = tx_context; - let tx_version_state = state.pin_version(tx_index); + let tx_version_state = state.pin_version::(tx_index); let (sequencer_balance_key_low, sequencer_balance_key_high) = get_sequencer_balance_keys(block_context); for (expected_balance, storage_key) in [ @@ -105,7 +106,7 @@ pub fn test_commit_tx() { let cached_state = test_state(&block_context.chain_info, BALANCE, &[(account, 1), (test_contract, 1)]); let versioned_state = safe_versioned_state_for_testing(cached_state); - let executor = + let executor: WorkerExecutor<'_, _, VisitedPcsSet> = WorkerExecutor::new(versioned_state, &txs, &block_context, Mutex::new(&mut bouncer)); // Execute transactions. @@ -205,14 +206,14 @@ fn test_commit_tx_when_sender_is_sequencer() { let state = test_state(&block_context.chain_info, BALANCE, &[(account, 1), (test_contract, 1)]); let versioned_state = safe_versioned_state_for_testing(state); - let executor = WorkerExecutor::new( + let executor: WorkerExecutor<'_, _, VisitedPcsSet> = WorkerExecutor::new( versioned_state, &sequencer_tx, &block_context, Mutex::new(&mut bouncer), ); let tx_index = 0; - let tx_versioned_state = executor.state.pin_version(tx_index); + let tx_versioned_state = executor.state.pin_version::(tx_index); // Execute and save the execution result. executor.execute_tx(tx_index); @@ -312,7 +313,7 @@ fn test_worker_execute(max_resource_bounds: ResourceBoundsMapping) { .collect::>(); let mut bouncer = Bouncer::new(block_context.bouncer_config.clone()); - let worker_executor = WorkerExecutor::new( + let worker_executor: WorkerExecutor<'_, _, VisitedPcsSet> = WorkerExecutor::new( safe_versioned_state.clone(), &txs, &block_context, @@ -330,7 +331,7 @@ fn test_worker_execute(max_resource_bounds: ResourceBoundsMapping) { // Read a write made by the transaction. assert_eq!( safe_versioned_state - .pin_version(tx_index) + .pin_version::(tx_index) .get_storage_at(test_contract_address, storage_key) .unwrap(), storage_value @@ -383,14 +384,17 @@ fn test_worker_execute(max_resource_bounds: ResourceBoundsMapping) { assert_eq!(execution_output.writes, writes); assert_eq!(execution_output.reads, reads); - assert_ne!(execution_output.visited_pcs, HashMap::default()); + assert_ne!(execution_output.visited_pcs, VisitedPcsSet::default()); // Failed execution. let tx_index = 1; worker_executor.execute(tx_index); // No write was made by the transaction. assert_eq!( - safe_versioned_state.pin_version(tx_index).get_nonce_at(account_address).unwrap(), + safe_versioned_state + .pin_version::(tx_index) + .get_nonce_at(account_address) + .unwrap(), nonce!(1_u8) ); let execution_output = worker_executor.execution_outputs[tx_index].lock().unwrap(); @@ -402,21 +406,24 @@ fn test_worker_execute(max_resource_bounds: ResourceBoundsMapping) { }; assert_eq!(execution_output.reads, reads); assert_eq!(execution_output.writes, StateMaps::default()); - assert_eq!(execution_output.visited_pcs, HashMap::default()); + assert_eq!(execution_output.visited_pcs, VisitedPcsSet::default()); // Reverted execution. let tx_index = 2; worker_executor.execute(tx_index); // Read a write made by the transaction. assert_eq!( - safe_versioned_state.pin_version(tx_index).get_nonce_at(account_address).unwrap(), + safe_versioned_state + .pin_version::(tx_index) + .get_nonce_at(account_address) + .unwrap(), nonce!(2_u8) ); let execution_output = worker_executor.execution_outputs[tx_index].lock().unwrap(); let execution_output = execution_output.as_ref().unwrap(); assert!(execution_output.result.as_ref().unwrap().is_reverted()); assert_ne!(execution_output.writes, StateMaps::default()); - assert_ne!(execution_output.visited_pcs, HashMap::default()); + assert_ne!(execution_output.visited_pcs, VisitedPcsSet::default()); // Validate status change. for tx_index in 0..3 { @@ -474,7 +481,7 @@ fn test_worker_validate(max_resource_bounds: ResourceBoundsMapping) { .collect::>(); let mut bouncer = Bouncer::new(block_context.bouncer_config.clone()); - let worker_executor = WorkerExecutor::new( + let worker_executor: WorkerExecutor<'_, _, VisitedPcsSet> = WorkerExecutor::new( safe_versioned_state.clone(), &txs, &block_context, @@ -500,7 +507,7 @@ fn test_worker_validate(max_resource_bounds: ResourceBoundsMapping) { // Verify writes exist in state. assert_eq!( safe_versioned_state - .pin_version(tx_index) + .pin_version::(tx_index) .get_storage_at(test_contract_address, storage_key) .unwrap(), storage_value0 @@ -515,7 +522,7 @@ fn test_worker_validate(max_resource_bounds: ResourceBoundsMapping) { // Verify writes were removed. assert_eq!( safe_versioned_state - .pin_version(tx_index) + .pin_version::(tx_index) .get_storage_at(test_contract_address, storage_key) .unwrap(), storage_value0 @@ -587,7 +594,7 @@ fn test_deploy_before_declare( .collect::>(); let mut bouncer = Bouncer::new(block_context.bouncer_config.clone()); - let worker_executor = + let worker_executor: WorkerExecutor<'_, _, VisitedPcsSet> = WorkerExecutor::new(safe_versioned_state, &txs, &block_context, Mutex::new(&mut bouncer)); // Creates 2 active tasks. @@ -659,7 +666,7 @@ fn test_worker_commit_phase(max_resource_bounds: ResourceBoundsMapping) { .collect::>(); let mut bouncer = Bouncer::new(block_context.bouncer_config.clone()); - let worker_executor = + let worker_executor: WorkerExecutor<'_, _, VisitedPcsSet> = WorkerExecutor::new(safe_versioned_state, &txs, &block_context, Mutex::new(&mut bouncer)); // Try to commit before any transaction is ready. @@ -749,7 +756,7 @@ fn test_worker_commit_phase_with_halt() { .collect::>(); let mut bouncer = Bouncer::new(block_context.bouncer_config.clone()); - let worker_executor = + let worker_executor: WorkerExecutor<'_, _, VisitedPcsSet> = WorkerExecutor::new(safe_versioned_state, &txs, &block_context, Mutex::new(&mut bouncer)); // Creates 2 active tasks. diff --git a/crates/blockifier/src/execution/contract_address_test.rs b/crates/blockifier/src/execution/contract_address_test.rs index 405360da1e..55b65cc3c5 100644 --- a/crates/blockifier/src/execution/contract_address_test.rs +++ b/crates/blockifier/src/execution/contract_address_test.rs @@ -9,6 +9,7 @@ use crate::execution::call_info::{CallExecution, Retdata}; use crate::execution::entry_point::CallEntryPoint; use crate::retdata; use crate::state::cached_state::CachedState; +use crate::state::visited_pcs::VisitedPcsSet; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::dict_state_reader::DictStateReader; use crate::test_utils::initial_test_state::test_state; @@ -27,7 +28,7 @@ fn test_calculate_contract_address() { constructor_calldata: &Calldata, calldata: Calldata, deployer_address: ContractAddress, - state: &mut CachedState, + state: &mut CachedState, ) { let versioned_constants = VersionedConstants::create_for_testing(); let entry_point_call = CallEntryPoint { diff --git a/crates/blockifier/src/execution/entry_point_execution.rs b/crates/blockifier/src/execution/entry_point_execution.rs index e63b4085d5..147d2caef9 100644 --- a/crates/blockifier/src/execution/entry_point_execution.rs +++ b/crates/blockifier/src/execution/entry_point_execution.rs @@ -1,5 +1,3 @@ -use std::collections::HashSet; - use cairo_vm::types::builtin_name::BuiltinName; use cairo_vm::types::layout_name::LayoutName; use cairo_vm::types::relocatable::{MaybeRelocatable, Relocatable}; @@ -114,7 +112,7 @@ fn register_visited_pcs( program_segment_size: usize, bytecode_length: usize, ) -> EntryPointExecutionResult<()> { - let mut class_visited_pcs = HashSet::new(); + let mut class_visited_pcs = Vec::new(); // Relocate the trace, putting the program segment at address 1 and the execution segment right // after it. // TODO(lior): Avoid unnecessary relocation once the VM has a non-relocated `get_trace()` @@ -131,7 +129,7 @@ fn register_visited_pcs( // Jumping to a PC that is not inside the bytecode is possible. For example, to obtain // the builtin costs. Filter out these values. if real_pc < bytecode_length { - class_visited_pcs.insert(real_pc); + class_visited_pcs.push(real_pc); } } state.add_visited_pcs(class_hash, &class_visited_pcs); diff --git a/crates/blockifier/src/execution/entry_point_test.rs b/crates/blockifier/src/execution/entry_point_test.rs index 07abce7eb9..a9610f9b15 100644 --- a/crates/blockifier/src/execution/entry_point_test.rs +++ b/crates/blockifier/src/execution/entry_point_test.rs @@ -12,6 +12,7 @@ use crate::context::ChainInfo; use crate::execution::call_info::{CallExecution, CallInfo, Retdata}; use crate::execution::entry_point::CallEntryPoint; use crate::state::cached_state::CachedState; +use crate::state::visited_pcs::VisitedPcsSet; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::dict_state_reader::DictStateReader; use crate::test_utils::initial_test_state::test_state; @@ -187,7 +188,7 @@ fn test_storage_var() { /// Runs test scenarios that could fail the OS run and therefore must be caught in the Blockifier. fn run_security_test( - state: &mut CachedState, + state: &mut CachedState, security_contract: FeatureContract, expected_error: &str, entry_point_name: &str, diff --git a/crates/blockifier/src/state.rs b/crates/blockifier/src/state.rs index e027d2b301..3bef337429 100644 --- a/crates/blockifier/src/state.rs +++ b/crates/blockifier/src/state.rs @@ -4,3 +4,4 @@ pub mod error_format_test; pub mod errors; pub mod global_cache; pub mod state_api; +pub mod visited_pcs; diff --git a/crates/blockifier/src/state/cached_state.rs b/crates/blockifier/src/state/cached_state.rs index 74ba043d63..4941154ae6 100644 --- a/crates/blockifier/src/state/cached_state.rs +++ b/crates/blockifier/src/state/cached_state.rs @@ -7,6 +7,7 @@ use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; use starknet_api::state::StorageKey; use starknet_types_core::felt::Felt; +use super::visited_pcs::VisitedPcs; use crate::abi::abi_utils::get_fee_token_var_address; use crate::context::TransactionContext; use crate::execution::contract_class::ContractClass; @@ -21,30 +22,28 @@ mod test; pub type ContractClassMapping = HashMap; -pub type VisitedPcs = HashMap>; - /// Caches read and write requests. /// /// Writer functionality is builtin, whereas Reader functionality is injected through /// initialization. #[derive(Debug)] -pub struct CachedState { +pub struct CachedState { pub state: S, // Invariant: read/write access is managed by CachedState. // Using interior mutability to update caches during `State`'s immutable getters. pub(crate) cache: RefCell, pub(crate) class_hash_to_class: RefCell, /// A map from class hash to the set of PC values that were visited in the class. - pub visited_pcs: VisitedPcs, + pub visited_pcs: V, } -impl CachedState { +impl CachedState { pub fn new(state: S) -> Self { Self { state, cache: RefCell::new(StateCache::default()), class_hash_to_class: RefCell::new(HashMap::default()), - visited_pcs: VisitedPcs::default(), + visited_pcs: V::default(), } } @@ -75,9 +74,9 @@ impl CachedState { self.class_hash_to_class.get_mut().extend(local_contract_cache_updates); } - pub fn update_visited_pcs_cache(&mut self, visited_pcs: &VisitedPcs) { - for (class_hash, class_visited_pcs) in visited_pcs { - self.add_visited_pcs(*class_hash, class_visited_pcs); + pub fn update_visited_pcs_cache(&mut self, visited_pcs: &V) { + for (class_hash, class_visited_pcs) in visited_pcs.iter() { + V::add_visited_pcs(self, class_hash, class_visited_pcs.clone()) } } @@ -109,12 +108,14 @@ impl CachedState { } } -impl UpdatableState for CachedState { +impl UpdatableState for CachedState { + type T = V; + fn apply_writes( &mut self, writes: &StateMaps, class_hash_to_class: &ContractClassMapping, - visited_pcs: &VisitedPcs, + visited_pcs: &V, ) { // TODO(Noa,15/5/24): Reconsider the clone. self.update_cache(writes, class_hash_to_class.clone()); @@ -123,13 +124,13 @@ impl UpdatableState for CachedState { } #[cfg(any(feature = "testing", test))] -impl From for CachedState { +impl From for CachedState { fn from(state_reader: S) -> Self { CachedState::new(state_reader) } } -impl StateReader for CachedState { +impl StateReader for CachedState { fn get_storage_at( &self, contract_address: ContractAddress, @@ -224,7 +225,7 @@ impl StateReader for CachedState { } } -impl State for CachedState { +impl State for CachedState { fn set_storage_at( &mut self, contract_address: ContractAddress, @@ -277,13 +278,18 @@ impl State for CachedState { Ok(()) } - fn add_visited_pcs(&mut self, class_hash: ClassHash, pcs: &HashSet) { - self.visited_pcs.entry(class_hash).or_default().extend(pcs); + fn add_visited_pcs(&mut self, class_hash: ClassHash, pcs: &Vec) { + self.visited_pcs.insert(&class_hash, pcs); } } #[cfg(any(feature = "testing", test))] -impl Default for CachedState { +impl Default + for CachedState< + crate::test_utils::dict_state_reader::DictStateReader, + super::visited_pcs::VisitedPcsSet, + > +{ fn default() -> Self { Self { state: Default::default(), @@ -506,14 +512,14 @@ impl<'a, S: StateReader + ?Sized> StateReader for MutRefState<'a, S> { } } -pub type TransactionalState<'a, U> = CachedState>; +pub type TransactionalState<'a, U, V> = CachedState, V>; -impl<'a, S: StateReader> TransactionalState<'a, S> { +impl<'a, S: StateReader, V: VisitedPcs> TransactionalState<'a, S, V> { /// Creates a transactional instance from the given updatable state. /// It allows performing buffered modifying actions on the given state, which /// will either all happen (will be updated in the state and committed) /// or none of them (will be discarded). - pub fn create_transactional(state: &mut S) -> TransactionalState<'_, S> { + pub fn create_transactional(state: &mut S) -> TransactionalState<'_, S, V> { CachedState::new(MutRefState::new(state)) } @@ -522,7 +528,7 @@ impl<'a, S: StateReader> TransactionalState<'a, S> { } /// Adds the ability to perform a transactional execution. -impl<'a, U: UpdatableState> TransactionalState<'a, U> { +impl<'a, V: VisitedPcs, U: UpdatableState> TransactionalState<'a, U, V> { /// Commits changes in the child (wrapping) state to its parent. pub fn commit(self) { let state = self.state.0; diff --git a/crates/blockifier/src/state/cached_state_test.rs b/crates/blockifier/src/state/cached_state_test.rs index 37c8e72ba5..1d80abdc4c 100644 --- a/crates/blockifier/src/state/cached_state_test.rs +++ b/crates/blockifier/src/state/cached_state_test.rs @@ -9,6 +9,7 @@ use starknet_api::{class_hash, contract_address, felt, patricia_key}; use crate::context::{BlockContext, ChainInfo}; use crate::state::cached_state::*; +use crate::state::visited_pcs::VisitedPcsSet; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::dict_state_reader::DictStateReader; use crate::test_utils::initial_test_state::test_state; @@ -17,7 +18,7 @@ use crate::{compiled_class_hash, nonce, storage_key}; const CONTRACT_ADDRESS: &str = "0x100"; fn set_initial_state_values( - state: &mut CachedState, + state: &mut CachedState, class_hash_to_class: ContractClassMapping, nonce_initial_values: HashMap, class_hash_initial_values: HashMap, @@ -33,7 +34,7 @@ fn set_initial_state_values( #[test] fn get_uninitialized_storage_value() { - let state: CachedState = CachedState::default(); + let state: CachedState = CachedState::default(); let contract_address = contract_address!("0x1"); let key = storage_key!(0x10_u16); @@ -49,13 +50,14 @@ fn get_and_set_storage_value() { let storage_val0: Felt = felt!("0x1"); let storage_val1: Felt = felt!("0x5"); - let mut state = CachedState::from(DictStateReader { - storage_view: HashMap::from([ - ((contract_address0, key0), storage_val0), - ((contract_address1, key1), storage_val1), - ]), - ..Default::default() - }); + let mut state: CachedState = + CachedState::from(DictStateReader { + storage_view: HashMap::from([ + ((contract_address0, key0), storage_val0), + ((contract_address1, key1), storage_val1), + ]), + ..Default::default() + }); assert_eq!(state.get_storage_at(contract_address0, key0).unwrap(), storage_val0); assert_eq!(state.get_storage_at(contract_address1, key1).unwrap(), storage_val1); @@ -98,7 +100,7 @@ fn cast_between_storage_mapping_types() { #[test] fn get_uninitialized_value() { - let state: CachedState = CachedState::default(); + let state: CachedState = CachedState::default(); let contract_address = contract_address!("0x1"); assert_eq!(state.get_nonce_at(contract_address).unwrap(), Nonce::default()); @@ -106,7 +108,8 @@ fn get_uninitialized_value() { #[test] fn declare_contract() { - let mut state = CachedState::from(DictStateReader { ..Default::default() }); + let mut state: CachedState = + CachedState::from(DictStateReader { ..Default::default() }); let test_contract = FeatureContract::TestContract(CairoVersion::Cairo0); let class_hash = test_contract.get_class_hash(); let contract_class = test_contract.get_class(); @@ -135,13 +138,14 @@ fn get_and_increment_nonce() { let contract_address2 = contract_address!("0x200"); let initial_nonce = Nonce(felt!(1_u8)); - let mut state = CachedState::from(DictStateReader { - address_to_nonce: HashMap::from([ - (contract_address1, initial_nonce), - (contract_address2, initial_nonce), - ]), - ..Default::default() - }); + let mut state: CachedState = + CachedState::from(DictStateReader { + address_to_nonce: HashMap::from([ + (contract_address1, initial_nonce), + (contract_address2, initial_nonce), + ]), + ..Default::default() + }); assert_eq!(state.get_nonce_at(contract_address1).unwrap(), initial_nonce); assert_eq!(state.get_nonce_at(contract_address2).unwrap(), initial_nonce); @@ -181,7 +185,7 @@ fn get_contract_class() { #[test] fn get_uninitialized_class_hash_value() { - let state: CachedState = CachedState::default(); + let state: CachedState = CachedState::default(); let valid_contract_address = contract_address!("0x1"); assert_eq!(state.get_class_hash_at(valid_contract_address).unwrap(), ClassHash::default()); @@ -190,7 +194,7 @@ fn get_uninitialized_class_hash_value() { #[test] fn set_and_get_contract_hash() { let contract_address = contract_address!("0x1"); - let mut state: CachedState = CachedState::default(); + let mut state: CachedState = CachedState::default(); let class_hash = class_hash!("0x10"); assert!(state.set_class_hash_at(contract_address, class_hash).is_ok()); @@ -199,7 +203,7 @@ fn set_and_get_contract_hash() { #[test] fn cannot_set_class_hash_to_uninitialized_contract() { - let mut state: CachedState = CachedState::default(); + let mut state: CachedState = CachedState::default(); let uninitialized_contract_address = ContractAddress::default(); let class_hash = class_hash!("0x100"); @@ -289,8 +293,8 @@ fn cached_state_state_diff_conversion() { assert_eq!(expected_state_diff, state.to_state_diff().unwrap().into()); } -fn create_state_changes_for_test( - state: &mut CachedState, +fn create_state_changes_for_test( + state: &mut CachedState, sender_address: Option, fee_token_address: ContractAddress, ) -> StateChanges { @@ -331,7 +335,7 @@ fn create_state_changes_for_test( fn test_from_state_changes_for_fee_charge( #[values(Some(contract_address!("0x102")), None)] sender_address: Option, ) { - let mut state: CachedState = CachedState::default(); + let mut state: CachedState = CachedState::default(); let fee_token_address = contract_address!("0x17"); let state_changes = create_state_changes_for_test(&mut state, sender_address, fee_token_address); @@ -352,7 +356,7 @@ fn test_state_changes_merge( ) { // Create a transactional state containing the `create_state_changes_for_test` logic, get the // state changes and then commit. - let mut state: CachedState = CachedState::default(); + let mut state: CachedState = CachedState::default(); let mut transactional_state = TransactionalState::create_transactional(&mut state); let block_context = BlockContext::create_for_testing(); let fee_token_address = block_context.chain_info.fee_token_addresses.eth_fee_token_address; @@ -422,7 +426,7 @@ fn test_contract_cache_is_used() { let contract_class = test_contract.get_class(); let mut reader = DictStateReader::default(); reader.class_hash_to_class.insert(class_hash, contract_class.clone()); - let state = CachedState::new(reader); + let state: CachedState = CachedState::new(reader); // Assert local cache is initialized empty. assert!(state.class_hash_to_class.borrow().get(&class_hash).is_none()); diff --git a/crates/blockifier/src/state/state_api.rs b/crates/blockifier/src/state/state_api.rs index 8c6a1db559..3ca626aa72 100644 --- a/crates/blockifier/src/state/state_api.rs +++ b/crates/blockifier/src/state/state_api.rs @@ -1,10 +1,8 @@ -use std::collections::HashSet; - use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; use starknet_api::state::StorageKey; use starknet_types_core::felt::Felt; -use super::cached_state::{ContractClassMapping, StateMaps, VisitedPcs}; +use super::cached_state::{ContractClassMapping, StateMaps}; use crate::abi::abi_utils::get_fee_token_var_address; use crate::abi::sierra_types::next_storage_key; use crate::execution::contract_class::ContractClass; @@ -107,15 +105,28 @@ pub trait State: StateReader { /// Marks the given set of PC values as visited for the given class hash. // TODO(lior): Once we have a BlockResources object, move this logic there. Make sure reverted // entry points do not affect the final set of PCs. - fn add_visited_pcs(&mut self, class_hash: ClassHash, pcs: &HashSet); + fn add_visited_pcs(&mut self, class_hash: ClassHash, pcs: &Vec); } /// A class defining the API for updating a state with transactions writes. pub trait UpdatableState: StateReader { + type T; + + fn apply_writes( + &mut self, + writes: &StateMaps, + class_hash_to_class: &ContractClassMapping, + visited_pcs: &Self::T, + ); +} + +pub trait UpdatableStatetTest: StateReader { + type T; + fn apply_writes( &mut self, writes: &StateMaps, class_hash_to_class: &ContractClassMapping, - visited_pcs: &VisitedPcs, + visited_pcs: &Self::T, ); } diff --git a/crates/blockifier/src/state/visited_pcs.rs b/crates/blockifier/src/state/visited_pcs.rs new file mode 100644 index 0000000000..6bfca65d0c --- /dev/null +++ b/crates/blockifier/src/state/visited_pcs.rs @@ -0,0 +1,89 @@ +use std::collections::hash_map::{Entry, IntoIter, Iter}; +use std::collections::{HashMap, HashSet}; +use std::fmt::Debug; + +use starknet_api::core::ClassHash; + +use super::state_api::State; + +/// This trait is used in `CachedState` to record visited pcs of an entry point call. +pub trait VisitedPcs +where + Self: Default + Debug, +{ + /// This is the type which contains visited program counters. + type T: Clone; + + fn new() -> Self; + + /// The function `insert` reads the program counters returned by the cairo vm trace. + /// + /// The elements of the vector `pcs` match the type of field `pc` in + /// [`cairo_vm::vm::trace::trace_entry::RelocatedTraceEntry`] + fn insert(&mut self, class_hash: &ClassHash, pcs: &Vec); + + /// The function `extend` is used to extend an instance of `VisitedPcs` with another one. + fn extend(&mut self, class_hash: &ClassHash, pcs: &Self::T); + + /// This function returns an iterator of `VisitedPcs`. + fn iter(&self) -> impl Iterator; + + /// Get the recorded visited program counters for a specific `class_hash`. + fn entry(&mut self, class_hash: ClassHash) -> Entry<'_, ClassHash, Self::T>; + + /// Marks the given PC values as visited for the given class hash. + fn add_visited_pcs(state: &mut dyn State, class_hash: &ClassHash, pcs: Self::T); + + /// This function returns the program counters in a set. + fn to_set(pcs: Self::T) -> HashSet; +} + +#[derive(Debug, Default, PartialEq, Eq)] +pub struct VisitedPcsSet(HashMap>); +impl VisitedPcs for VisitedPcsSet { + type T = HashSet; + + fn new() -> Self { + VisitedPcsSet(HashMap::default()) + } + + fn insert(&mut self, class_hash: &ClassHash, pcs: &Vec) { + self.0.entry(*class_hash).or_default().extend(pcs); + } + + fn iter(&self) -> impl Iterator { + self.0.iter() + } + + fn entry(&mut self, class_hash: ClassHash) -> Entry<'_, ClassHash, HashSet> { + self.0.entry(class_hash) + } + + fn add_visited_pcs(state: &mut dyn State, class_hash: &ClassHash, pcs: Self::T) { + state.add_visited_pcs(*class_hash, &Vec::from_iter(pcs)); + } + + fn extend(&mut self, class_hash: &ClassHash, pcs: &Self::T) { + self.0.entry(*class_hash).or_default().extend(pcs); + } + + fn to_set(pcs: Self::T) -> HashSet { + pcs + } +} +impl IntoIterator for VisitedPcsSet { + type Item = (ClassHash, HashSet); + type IntoIter = IntoIter>; + + fn into_iter(self) -> IntoIter> { + self.0.into_iter() + } +} +impl<'a> IntoIterator for &'a VisitedPcsSet { + type Item = (&'a ClassHash, &'a HashSet); + type IntoIter = Iter<'a, ClassHash, HashSet>; + + fn into_iter(self) -> Iter<'a, ClassHash, HashSet> { + self.0.iter() + } +} diff --git a/crates/blockifier/src/test_utils/initial_test_state.rs b/crates/blockifier/src/test_utils/initial_test_state.rs index 6e0268cb29..6bf2fc56b3 100644 --- a/crates/blockifier/src/test_utils/initial_test_state.rs +++ b/crates/blockifier/src/test_utils/initial_test_state.rs @@ -7,6 +7,7 @@ use strum::IntoEnumIterator; use crate::abi::abi_utils::get_fee_token_var_address; use crate::context::ChainInfo; use crate::state::cached_state::CachedState; +use crate::state::visited_pcs::VisitedPcsSet; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::dict_state_reader::DictStateReader; use crate::test_utils::CairoVersion; @@ -40,7 +41,7 @@ pub fn test_state_inner( initial_balances: u128, contract_instances: &[(FeatureContract, u16)], erc20_contract_version: CairoVersion, -) -> CachedState { +) -> CachedState { let mut class_hash_to_class = HashMap::new(); let mut address_to_class_hash = HashMap::new(); @@ -87,6 +88,6 @@ pub fn test_state( chain_info: &ChainInfo, initial_balances: u128, contract_instances: &[(FeatureContract, u16)], -) -> CachedState { +) -> CachedState { test_state_inner(chain_info, initial_balances, contract_instances, CairoVersion::Cairo0) } diff --git a/crates/blockifier/src/test_utils/transfers_generator.rs b/crates/blockifier/src/test_utils/transfers_generator.rs index 3d3a6a911d..1bd11e587f 100644 --- a/crates/blockifier/src/test_utils/transfers_generator.rs +++ b/crates/blockifier/src/test_utils/transfers_generator.rs @@ -10,6 +10,7 @@ use crate::blockifier::config::{ConcurrencyConfig, TransactionExecutorConfig}; use crate::blockifier::transaction_executor::TransactionExecutor; use crate::context::{BlockContext, ChainInfo}; use crate::invoke_tx_args; +use crate::state::visited_pcs::VisitedPcsSet; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::dict_state_reader::DictStateReader; use crate::test_utils::initial_test_state::test_state; @@ -73,7 +74,7 @@ pub enum RecipientGeneratorType { pub struct TransfersGenerator { account_addresses: Vec, chain_info: ChainInfo, - executor: TransactionExecutor, + executor: TransactionExecutor, nonce_manager: NonceManager, sender_index: usize, random_recipient_generator: Option, diff --git a/crates/blockifier/src/transaction/account_transaction.rs b/crates/blockifier/src/transaction/account_transaction.rs index 51ee47b660..3390a8f8ed 100644 --- a/crates/blockifier/src/transaction/account_transaction.rs +++ b/crates/blockifier/src/transaction/account_transaction.rs @@ -21,6 +21,7 @@ use crate::fee::gas_usage::{compute_discounted_gas_from_gas_vector, estimate_min use crate::retdata; use crate::state::cached_state::{StateChanges, TransactionalState}; use crate::state::state_api::{State, StateReader, UpdatableState}; +use crate::state::visited_pcs::VisitedPcs; use crate::transaction::constants; use crate::transaction::errors::{ TransactionExecutionError, TransactionFeeError, TransactionPreValidationError, @@ -303,9 +304,9 @@ impl AccountTransaction { Ok(()) } - fn handle_fee( + fn handle_fee( &self, - state: &mut TransactionalState<'_, S>, + state: &mut TransactionalState<'_, S, V>, tx_context: Arc, actual_fee: Fee, charge_fee: bool, @@ -370,8 +371,8 @@ impl AccountTransaction { /// manipulates the state to avoid that part. /// Note: the returned transfer call info is partial, and should be completed at the commit /// stage, as well as the actual sequencer balance. - fn concurrency_execute_fee_transfer( - state: &mut TransactionalState<'_, S>, + fn concurrency_execute_fee_transfer( + state: &mut TransactionalState<'_, S, V>, tx_context: Arc, actual_fee: Fee, ) -> TransactionExecutionResult { @@ -379,7 +380,8 @@ impl AccountTransaction { let fee_address = block_context.chain_info.fee_token_address(&tx_info.fee_type()); let (sequencer_balance_key_low, sequencer_balance_key_high) = get_sequencer_balance_keys(block_context); - let mut transfer_state = TransactionalState::create_transactional(state); + let mut transfer_state: TransactionalState<'_, _, V> = + TransactionalState::create_transactional(state); // Set the initial sequencer balance to avoid tarnishing the read-set of the transaction. let cache = transfer_state.cache.get_mut(); @@ -411,9 +413,9 @@ impl AccountTransaction { } } - fn run_non_revertible( + fn run_non_revertible( &self, - state: &mut TransactionalState<'_, S>, + state: &mut TransactionalState<'_, S, V>, tx_context: Arc, remaining_gas: &mut u64, validate: bool, @@ -474,9 +476,9 @@ impl AccountTransaction { } } - fn run_revertible( + fn run_revertible( &self, - state: &mut TransactionalState<'_, S>, + state: &mut TransactionalState<'_, S, V>, tx_context: Arc, remaining_gas: &mut u64, validate: bool, @@ -508,7 +510,8 @@ impl AccountTransaction { // Create copies of state and resources for the execution. // Both will be rolled back if the execution is reverted or committed upon success. let mut execution_resources = resources.clone(); - let mut execution_state = TransactionalState::create_transactional(state); + let mut execution_state: TransactionalState<'_, _, V> = + TransactionalState::create_transactional(state); let execution_result = self.run_execute( &mut execution_state, @@ -615,9 +618,9 @@ impl AccountTransaction { } /// Runs validation and execution. - fn run_or_revert( + fn run_or_revert( &self, - state: &mut TransactionalState<'_, S>, + state: &mut TransactionalState<'_, S, V>, remaining_gas: &mut u64, tx_context: Arc, validate: bool, @@ -631,10 +634,10 @@ impl AccountTransaction { } } -impl ExecutableTransaction for AccountTransaction { +impl> ExecutableTransaction for AccountTransaction { fn execute_raw( &self, - state: &mut TransactionalState<'_, U>, + state: &mut TransactionalState<'_, U, V>, block_context: &BlockContext, execution_flags: ExecutionFlags, ) -> TransactionExecutionResult { diff --git a/crates/blockifier/src/transaction/execution_flavors_test.rs b/crates/blockifier/src/transaction/execution_flavors_test.rs index e186ad8151..2a1c29026e 100644 --- a/crates/blockifier/src/transaction/execution_flavors_test.rs +++ b/crates/blockifier/src/transaction/execution_flavors_test.rs @@ -13,6 +13,7 @@ use crate::execution::syscalls::SyscallSelector; use crate::fee::fee_utils::get_fee_by_gas_vector; use crate::state::cached_state::CachedState; use crate::state::state_api::StateReader; +use crate::state::visited_pcs::{VisitedPcs, VisitedPcsSet}; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::dict_state_reader::DictStateReader; use crate::test_utils::initial_test_state::test_state; @@ -34,7 +35,7 @@ use crate::{invoke_tx_args, nonce}; const VALIDATE_GAS_OVERHEAD: u64 = 21; struct FlavorTestInitialState { - pub state: CachedState, + pub state: CachedState, pub account_address: ContractAddress, pub faulty_account_address: ContractAddress, pub test_contract_address: ContractAddress, @@ -64,9 +65,9 @@ fn create_flavors_test_state( /// Checks that balance of the account decreased if and only if `charge_fee` is true. /// Returns the new balance. -fn check_balance( +fn check_balance( current_balance: Felt, - state: &mut CachedState, + state: &mut CachedState, account_address: ContractAddress, chain_info: &ChainInfo, fee_type: &FeeType, @@ -185,10 +186,10 @@ fn test_simulate_validate_charge_fee_pre_validate( // First scenario: invalid nonce. Regardless of flags, should fail. let invalid_nonce = nonce!(7_u8); let account_nonce = state.get_nonce_at(account_address).unwrap(); - let result = account_invoke_tx( + let account_tx = account_invoke_tx( invoke_tx_args! {nonce: invalid_nonce, ..pre_validation_base_args.clone()}, - ) - .execute(&mut state, &block_context, charge_fee, validate); + ); + let result = account_tx.execute(&mut state, &block_context, charge_fee, validate); assert_matches!( result.unwrap_err(), TransactionExecutionError::TransactionPreValidationError( @@ -210,13 +211,13 @@ fn test_simulate_validate_charge_fee_pre_validate( validate, &fee_type, ); - let result = account_invoke_tx(invoke_tx_args! { + let account_tx = account_invoke_tx(invoke_tx_args! { max_fee: Fee(10), resource_bounds: l1_resource_bounds(10, 10), nonce: nonce_manager.next(account_address), ..pre_validation_base_args.clone() - }) - .execute(&mut state, &block_context, charge_fee, validate); + }); + let result = account_tx.execute(&mut state, &block_context, charge_fee, validate); if !charge_fee { check_gas_and_fee( &block_context, @@ -254,13 +255,13 @@ fn test_simulate_validate_charge_fee_pre_validate( // TODO(Ori, 1/2/2024): Write an indicative expect message explaining why the conversion works. let balance_over_gas_price: u64 = (BALANCE / gas_price).try_into().expect("Failed to convert u128 to u64."); - let result = account_invoke_tx(invoke_tx_args! { + let account_tx = account_invoke_tx(invoke_tx_args! { max_fee: Fee(BALANCE + 1), resource_bounds: l1_resource_bounds(balance_over_gas_price + 10, gas_price.into()), nonce: nonce_manager.next(account_address), ..pre_validation_base_args.clone() - }) - .execute(&mut state, &block_context, charge_fee, validate); + }); + let result = account_tx.execute(&mut state, &block_context, charge_fee, validate); if !charge_fee { check_gas_and_fee( &block_context, @@ -295,12 +296,12 @@ fn test_simulate_validate_charge_fee_pre_validate( // Fourth scenario: L1 gas price bound lower than the price on the block. if !is_deprecated { - let result = account_invoke_tx(invoke_tx_args! { + let account_tx = account_invoke_tx(invoke_tx_args! { resource_bounds: l1_resource_bounds(MAX_L1_GAS_AMOUNT, u128::from(gas_price) - 1), nonce: nonce_manager.next(account_address), ..pre_validation_base_args - }) - .execute(&mut state, &block_context, charge_fee, validate); + }); + let result = account_tx.execute(&mut state, &block_context, charge_fee, validate); if !charge_fee { check_gas_and_fee( &block_context, @@ -355,7 +356,7 @@ fn test_simulate_validate_charge_fee_fail_validate( validate, &fee_type, ); - let result = account_invoke_tx(invoke_tx_args! { + let account_tx = account_invoke_tx(invoke_tx_args! { max_fee, resource_bounds: max_resource_bounds, signature: TransactionSignature(vec![ @@ -367,8 +368,8 @@ fn test_simulate_validate_charge_fee_fail_validate( version, nonce: nonce_manager.next(faulty_account_address), only_query, - }) - .execute(&mut falliable_state, &block_context, charge_fee, validate); + }); + let result = account_tx.execute(&mut falliable_state, &block_context, charge_fee, validate); if !validate { // The reported fee should be the actual cost, regardless of whether or not fee is charged. check_gas_and_fee( @@ -434,13 +435,13 @@ fn test_simulate_validate_charge_fee_mid_execution( validate, &fee_type, ); - let tx_execution_info = account_invoke_tx(invoke_tx_args! { + let account_tx = account_invoke_tx(invoke_tx_args! { calldata: recurse_calldata(test_contract_address, true, 3), nonce: nonce_manager.next(account_address), ..execution_base_args.clone() - }) - .execute(&mut state, &block_context, charge_fee, validate) - .unwrap(); + }); + let tx_execution_info = + account_tx.execute(&mut state, &block_context, charge_fee, validate).unwrap(); assert!(tx_execution_info.is_reverted()); check_gas_and_fee( &block_context, @@ -474,15 +475,15 @@ fn test_simulate_validate_charge_fee_mid_execution( validate, &fee_type, ); - let tx_execution_info = account_invoke_tx(invoke_tx_args! { + let account_tx = account_invoke_tx(invoke_tx_args! { max_fee: fee_bound, resource_bounds: l1_resource_bounds(gas_bound, gas_price.into()), calldata: recurse_calldata(test_contract_address, false, 1000), nonce: nonce_manager.next(account_address), ..execution_base_args.clone() - }) - .execute(&mut state, &block_context, charge_fee, validate) - .unwrap(); + }); + let tx_execution_info = + account_tx.execute(&mut state, &block_context, charge_fee, validate).unwrap(); assert_eq!(tx_execution_info.is_reverted(), charge_fee); if charge_fee { assert!(tx_execution_info.revert_error.clone().unwrap().contains("no remaining steps")); @@ -526,15 +527,15 @@ fn test_simulate_validate_charge_fee_mid_execution( GasVector::from_l1_gas(block_limit_gas.into()), &fee_type, ); - let tx_execution_info = account_invoke_tx(invoke_tx_args! { + let account_tx = account_invoke_tx(invoke_tx_args! { max_fee: huge_fee, resource_bounds: l1_resource_bounds(huge_gas_limit, gas_price.into()), calldata: recurse_calldata(test_contract_address, false, 10000), nonce: nonce_manager.next(account_address), ..execution_base_args - }) - .execute(&mut state, &low_step_block_context, charge_fee, validate) - .unwrap(); + }); + let tx_execution_info = + account_tx.execute(&mut state, &low_step_block_context, charge_fee, validate).unwrap(); assert!(tx_execution_info.revert_error.clone().unwrap().contains("no remaining steps")); // Complete resources used are reported as transaction_receipt.resources; but only the charged // final fee is shown in actual_fee. As a sanity check, verify that the fee derived directly @@ -607,7 +608,7 @@ fn test_simulate_validate_charge_fee_post_execution( validate, &fee_type, ); - let tx_execution_info = account_invoke_tx(invoke_tx_args! { + let account_tx = account_invoke_tx(invoke_tx_args! { max_fee: just_not_enough_fee_bound, resource_bounds: l1_resource_bounds(just_not_enough_gas_bound, gas_price.into()), calldata: recurse_calldata(test_contract_address, false, 1000), @@ -615,9 +616,9 @@ fn test_simulate_validate_charge_fee_post_execution( sender_address: account_address, version, only_query, - }) - .execute(&mut state, &block_context, charge_fee, validate) - .unwrap(); + }); + let tx_execution_info = + account_tx.execute(&mut state, &block_context, charge_fee, validate).unwrap(); assert_eq!(tx_execution_info.is_reverted(), charge_fee); if charge_fee { assert!(tx_execution_info.revert_error.clone().unwrap().starts_with(if is_deprecated { @@ -672,7 +673,7 @@ fn test_simulate_validate_charge_fee_post_execution( felt!(0_u8), ], ); - let tx_execution_info = account_invoke_tx(invoke_tx_args! { + let account_tx = account_invoke_tx(invoke_tx_args! { max_fee: actual_fee, resource_bounds: l1_resource_bounds(success_actual_gas, gas_price.into()), calldata: transfer_calldata, @@ -680,9 +681,9 @@ fn test_simulate_validate_charge_fee_post_execution( sender_address: account_address, version, only_query, - }) - .execute(&mut state, &block_context, charge_fee, validate) - .unwrap(); + }); + let tx_execution_info = + account_tx.execute(&mut state, &block_context, charge_fee, validate).unwrap(); assert_eq!(tx_execution_info.is_reverted(), charge_fee); if charge_fee { assert!( diff --git a/crates/blockifier/src/transaction/test_utils.rs b/crates/blockifier/src/transaction/test_utils.rs index a5a353a056..88a4605703 100644 --- a/crates/blockifier/src/transaction/test_utils.rs +++ b/crates/blockifier/src/transaction/test_utils.rs @@ -14,6 +14,7 @@ use crate::context::{BlockContext, ChainInfo}; use crate::execution::contract_class::{ClassInfo, ContractClass}; use crate::state::cached_state::CachedState; use crate::state::state_api::State; +use crate::state::visited_pcs::VisitedPcsSet; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::declare::declare_tx; use crate::test_utils::deploy_account::{deploy_account_tx, DeployAccountTxArgs}; @@ -79,7 +80,7 @@ pub fn block_context() -> BlockContext { /// Struct containing the data usually needed to initialize a test. pub struct TestInitData { - pub state: CachedState, + pub state: CachedState, pub account_address: ContractAddress, pub contract_address: ContractAddress, pub nonce_manager: NonceManager, @@ -88,7 +89,7 @@ pub struct TestInitData { /// Deploys a new account with the given class hash, funds with both fee tokens, and returns the /// deploy tx and address. pub fn deploy_and_fund_account( - state: &mut CachedState, + state: &mut CachedState, nonce_manager: &mut NonceManager, chain_info: &ChainInfo, deploy_tx_args: DeployAccountTxArgs, @@ -268,11 +269,12 @@ pub fn account_invoke_tx(invoke_args: InvokeTxArgs) -> AccountTransaction { } pub fn run_invoke_tx( - state: &mut CachedState, + state: &mut CachedState, block_context: &BlockContext, invoke_args: InvokeTxArgs, ) -> TransactionExecutionResult { - account_invoke_tx(invoke_args).execute(state, block_context, true, true) + let account_tx = account_invoke_tx(invoke_args); + account_tx.execute(state, block_context, true, true) } /// Creates a `ResourceBoundsMapping` with the given `max_amount` and `max_price` for L1 gas limits. diff --git a/crates/blockifier/src/transaction/transaction_execution.rs b/crates/blockifier/src/transaction/transaction_execution.rs index e617ad1c5f..4256e94fe5 100644 --- a/crates/blockifier/src/transaction/transaction_execution.rs +++ b/crates/blockifier/src/transaction/transaction_execution.rs @@ -11,6 +11,7 @@ use crate::execution::entry_point::EntryPointExecutionContext; use crate::fee::actual_cost::TransactionReceipt; use crate::state::cached_state::TransactionalState; use crate::state::state_api::UpdatableState; +use crate::state::visited_pcs::VisitedPcs; use crate::transaction::account_transaction::AccountTransaction; use crate::transaction::errors::TransactionFeeError; use crate::transaction::objects::{ @@ -100,10 +101,10 @@ impl TransactionInfoCreator for Transaction { } } -impl ExecutableTransaction for L1HandlerTransaction { +impl> ExecutableTransaction for L1HandlerTransaction { fn execute_raw( &self, - state: &mut TransactionalState<'_, U>, + state: &mut TransactionalState<'_, U, V>, block_context: &BlockContext, _execution_flags: ExecutionFlags, ) -> TransactionExecutionResult { @@ -151,10 +152,10 @@ impl ExecutableTransaction for L1HandlerTransaction { } } -impl ExecutableTransaction for Transaction { +impl> ExecutableTransaction for Transaction { fn execute_raw( &self, - state: &mut TransactionalState<'_, U>, + state: &mut TransactionalState<'_, U, V>, block_context: &BlockContext, execution_flags: ExecutionFlags, ) -> TransactionExecutionResult { diff --git a/crates/blockifier/src/transaction/transactions.rs b/crates/blockifier/src/transaction/transactions.rs index 4e3188c150..c4f8121791 100644 --- a/crates/blockifier/src/transaction/transactions.rs +++ b/crates/blockifier/src/transaction/transactions.rs @@ -21,6 +21,7 @@ use crate::execution::execution_utils::execute_deployment; use crate::state::cached_state::TransactionalState; use crate::state::errors::StateError; use crate::state::state_api::{State, UpdatableState}; +use crate::state::visited_pcs::VisitedPcs; use crate::transaction::constants; use crate::transaction::errors::TransactionExecutionError; use crate::transaction::objects::{ @@ -48,7 +49,7 @@ pub struct ExecutionFlags { pub concurrency_mode: bool, } -pub trait ExecutableTransaction: Sized { +pub trait ExecutableTransaction>: Sized { /// Executes the transaction in a transactional manner /// (if it fails, given state does not modify). fn execute( @@ -84,7 +85,7 @@ pub trait ExecutableTransaction: Sized { /// for automatic handling of such cases. fn execute_raw( &self, - state: &mut TransactionalState<'_, U>, + state: &mut TransactionalState<'_, U, V>, block_context: &BlockContext, execution_flags: ExecutionFlags, ) -> TransactionExecutionResult; diff --git a/crates/blockifier/src/transaction/transactions_test.rs b/crates/blockifier/src/transaction/transactions_test.rs index e21c90de2e..dfb0532f69 100644 --- a/crates/blockifier/src/transaction/transactions_test.rs +++ b/crates/blockifier/src/transaction/transactions_test.rs @@ -41,6 +41,7 @@ use crate::fee::gas_usage::{ use crate::state::cached_state::{CachedState, StateChangesCount, TransactionalState}; use crate::state::errors::StateError; use crate::state::state_api::{State, StateReader}; +use crate::state::visited_pcs::VisitedPcsSet; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::declare::declare_tx; use crate::test_utils::deploy_account::deploy_account_tx; @@ -248,7 +249,7 @@ fn get_expected_cairo_resources( /// and the sequencer (in both fee types) are as expected (assuming the initial sequencer balances /// are zero). fn validate_final_balances( - state: &mut CachedState, + state: &mut CachedState, chain_info: &ChainInfo, expected_actual_fee: Fee, erc20_account_balance_key: StorageKey, @@ -496,7 +497,7 @@ fn test_invoke_tx( // Verifies the storage after each invoke execution in test_invoke_tx_advanced_operations. fn verify_storage_after_invoke_advanced_operations( - state: &mut CachedState, + state: &mut CachedState, contract_address: ContractAddress, account_address: ContractAddress, index: Felt, @@ -740,14 +741,14 @@ fn test_state_get_fee_token_balance( } fn assert_failure_if_resource_bounds_exceed_balance( - state: &mut CachedState, + state: &mut CachedState, block_context: &BlockContext, invalid_tx: AccountTransaction, ) { match block_context.to_tx_context(&invalid_tx).tx_info { TransactionInfo::Deprecated(context) => { assert_matches!( - invalid_tx.execute(state, block_context, true, true).unwrap_err(), + invalid_tx.execute( state, block_context, true, true).unwrap_err(), TransactionExecutionError::TransactionPreValidationError( TransactionPreValidationError::TransactionFeeError( TransactionFeeError::MaxFeeExceedsBalance{ max_fee, .. })) @@ -757,7 +758,7 @@ fn assert_failure_if_resource_bounds_exceed_balance( TransactionInfo::Current(context) => { let l1_bounds = context.l1_resource_bounds().unwrap(); assert_matches!( - invalid_tx.execute(state, block_context, true, true).unwrap_err(), + invalid_tx.execute( state, block_context, true, true).unwrap_err(), TransactionExecutionError::TransactionPreValidationError( TransactionPreValidationError::TransactionFeeError( TransactionFeeError::L1GasBoundsExceedBalance{ max_amount, max_price, .. })) @@ -981,7 +982,8 @@ fn test_invalid_nonce( calldata: create_trivial_calldata(test_contract.get_instance_address(0)), resource_bounds: max_resource_bounds, }; - let mut transactional_state = TransactionalState::create_transactional(state); + let mut transactional_state: TransactionalState<'_, _, VisitedPcsSet> = + TransactionalState::create_transactional(state); // Strict, negative flow: account nonce = 0, incoming tx nonce = 1. let invalid_nonce = nonce!(1_u8); diff --git a/crates/native_blockifier/src/py_block_executor.rs b/crates/native_blockifier/src/py_block_executor.rs index bf3c244eba..c29b7f5f00 100644 --- a/crates/native_blockifier/src/py_block_executor.rs +++ b/crates/native_blockifier/src/py_block_executor.rs @@ -8,6 +8,7 @@ use blockifier::context::{BlockContext, ChainInfo, FeeTokenAddresses}; use blockifier::execution::call_info::CallInfo; use blockifier::state::cached_state::CachedState; use blockifier::state::global_cache::GlobalContractCache; +use blockifier::state::visited_pcs::VisitedPcsSet; use blockifier::transaction::objects::{GasVector, ResourcesMapping, TransactionExecutionInfo}; use blockifier::transaction::transaction_execution::Transaction; use blockifier::versioned_constants::VersionedConstants; @@ -82,7 +83,7 @@ pub struct PyBlockExecutor { pub tx_executor_config: TransactionExecutorConfig, pub chain_info: ChainInfo, pub versioned_constants: VersionedConstants, - pub tx_executor: Option>, + pub tx_executor: Option>, /// `Send` trait is required for `pyclass` compatibility as Python objects must be threadsafe. pub storage: Box, pub global_contract_cache: GlobalContractCache, @@ -370,7 +371,7 @@ impl PyBlockExecutor { } impl PyBlockExecutor { - pub fn tx_executor(&mut self) -> &mut TransactionExecutor { + pub fn tx_executor(&mut self) -> &mut TransactionExecutor { self.tx_executor.as_mut().expect("Transaction executor should be initialized") } diff --git a/crates/native_blockifier/src/py_test_utils.rs b/crates/native_blockifier/src/py_test_utils.rs index 0e66423790..e5c7fccc4a 100644 --- a/crates/native_blockifier/src/py_test_utils.rs +++ b/crates/native_blockifier/src/py_test_utils.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use blockifier::execution::contract_class::ContractClassV0; use blockifier::state::cached_state::CachedState; +use blockifier::state::visited_pcs::VisitedPcsSet; use blockifier::test_utils::dict_state_reader::DictStateReader; use starknet_api::core::ClassHash; use starknet_api::{class_hash, felt}; @@ -12,7 +13,7 @@ pub const TOKEN_FOR_TESTING_CONTRACT_PATH: &str = "./src/starkware/starknet/core/test_contract/starknet_compiled_contracts_lib/starkware/\ starknet/core/test_contract/token_for_testing.json"; -pub fn create_py_test_state() -> CachedState { +pub fn create_py_test_state() -> CachedState { let class_hash_to_class = HashMap::from([( class_hash!(TOKEN_FOR_TESTING_CLASS_HASH), ContractClassV0::from_file(TOKEN_FOR_TESTING_CONTRACT_PATH).into(), diff --git a/crates/native_blockifier/src/py_validator.rs b/crates/native_blockifier/src/py_validator.rs index 8398e48c1a..3b547c3e3a 100644 --- a/crates/native_blockifier/src/py_validator.rs +++ b/crates/native_blockifier/src/py_validator.rs @@ -2,6 +2,7 @@ use blockifier::blockifier::stateful_validator::{StatefulValidator, StatefulVali use blockifier::bouncer::BouncerConfig; use blockifier::context::BlockContext; use blockifier::state::cached_state::CachedState; +use blockifier::state::visited_pcs::VisitedPcsSet; use blockifier::transaction::account_transaction::AccountTransaction; use blockifier::transaction::objects::TransactionInfoCreator; use blockifier::transaction::transaction_types::TransactionType; @@ -21,7 +22,7 @@ use crate::state_readers::py_state_reader::PyStateReader; #[pyclass] pub struct PyValidator { - pub stateful_validator: StatefulValidator, + pub stateful_validator: StatefulValidator, pub max_nonce_for_validation_skip: Nonce, } diff --git a/crates/native_blockifier/src/state_readers/papyrus_state_test.rs b/crates/native_blockifier/src/state_readers/papyrus_state_test.rs index e999276084..89a5ba133c 100644 --- a/crates/native_blockifier/src/state_readers/papyrus_state_test.rs +++ b/crates/native_blockifier/src/state_readers/papyrus_state_test.rs @@ -7,6 +7,7 @@ use blockifier::retdata; use blockifier::state::cached_state::CachedState; use blockifier::state::global_cache::{GlobalContractCache, GLOBAL_CONTRACT_CACHE_SIZE_FOR_TEST}; use blockifier::state::state_api::StateReader; +use blockifier::state::visited_pcs::VisitedPcsSet; use blockifier::test_utils::contracts::FeatureContract; use blockifier::test_utils::{trivial_external_entry_point_new, CairoVersion}; use indexmap::IndexMap; @@ -56,7 +57,7 @@ fn test_entry_point_with_papyrus_state() -> papyrus_storage::StorageResult<()> { block_number, GlobalContractCache::new(GLOBAL_CONTRACT_CACHE_SIZE_FOR_TEST), ); - let mut state = CachedState::from(papyrus_reader); + let mut state: CachedState<_, VisitedPcsSet> = CachedState::from(papyrus_reader); // Call entrypoint that want to write to storage, which updates the cached state's write cache. let key = felt!(1234_u16);