diff --git a/mpt_trie/src/debug_tools/diff.rs b/mpt_trie/src/debug_tools/diff.rs index 880984290..c08cb80d1 100644 --- a/mpt_trie/src/debug_tools/diff.rs +++ b/mpt_trie/src/debug_tools/diff.rs @@ -20,10 +20,9 @@ //! - Top-down will find the highest point of a structural divergence and report //! it. If there are multiple divergences, then only the one that is the //! highest in the trie will be reported. -//! - Bottom-up (not implemented) is a lot more complex to implement, but will -//! attempt to find the smallest structural trie difference between the trie. -//! If there are multiple differences, then this will likely be what you want -//! to use. +//! - Bottom-up is a lot more complex to implement, but will attempt to find the +//! smallest structural trie difference between the trie. If there are +//! multiple differences, then this will likely be what you want to use. use std::fmt::{self, Debug}; use std::{fmt::Display, ops::Deref}; @@ -48,17 +47,16 @@ fn get_key_piece_from_node(n: &Node) -> Nibbles { #[derive(Clone, Debug, Eq, Hash, PartialEq)] /// The difference between two Tries, represented as the highest -/// point of a structural divergence. +/// array of `DiffPoint`s. pub struct TrieDiff { - /// The highest point of structural divergence. - pub latest_diff_res: Option, - // TODO: Later add a second pass for finding diffs from the bottom up (`earliest_diff_res`). + /// Diff points between the two tries. + pub diff_points: Vec, } impl Display for TrieDiff { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if let Some(diff) = &self.latest_diff_res { - write!(f, "{}", diff)?; + for (index, diff_point) in self.diff_points.iter().enumerate() { + writeln!(f, "{}: {}\n", index, diff_point)?; } Ok(()) @@ -176,32 +174,30 @@ impl NodeInfo { } /// Create a diff between two tries. Will perform both types of diff searches -/// (top-down & bottom-up). -pub fn create_diff_between_tries(a: &HashedPartialTrie, b: &HashedPartialTrie) -> TrieDiff { +/// (top-down & bottom-up). It will not stop on first difference. +pub fn create_full_diff_between_tries(a: &HashedPartialTrie, b: &HashedPartialTrie) -> TrieDiff { TrieDiff { - latest_diff_res: find_latest_diff_point_between_tries(a, b), + diff_points: find_all_diff_points_between_tries(a, b), } } -// Only support `HashedPartialTrie` due to it being significantly faster to -// detect differences because of caching hashes. -fn find_latest_diff_point_between_tries( +fn find_all_diff_points_between_tries( a: &HashedPartialTrie, b: &HashedPartialTrie, -) -> Option { +) -> Vec { let state = DepthDiffPerCallState::new(a, b, Nibbles::default(), 0); - let mut longest_state = DepthNodeDiffState::default(); - - find_latest_diff_point_between_tries_rec(&state, &mut longest_state); - - // If there was a node diff, we always want to prioritize displaying this over a - // hash diff. The reasoning behind this is hash diffs can become sort of - // meaningless or misleading if the trie diverges at some point (eg. saying - // there is a hash diff deep in two separate trie structures doesn't make much - // sense). - longest_state - .longest_key_node_diff - .or(longest_state.longest_key_hash_diff) + let mut longest_states = Vec::new(); + + find_all_diff_points_between_tries_rec(&state, &mut longest_states); + + longest_states + .into_iter() + .filter_map(|longest_state| { + longest_state + .longest_key_node_diff + .or(longest_state.longest_key_hash_diff) + }) + .collect() } #[derive(Debug, Default)] @@ -298,9 +294,11 @@ impl<'a> DepthDiffPerCallState<'a> { } } -fn find_latest_diff_point_between_tries_rec( +// Search for the differences between two tries. Do not stop on first +// difference. +fn find_all_diff_points_between_tries_rec( state: &DepthDiffPerCallState, - depth_state: &mut DepthNodeDiffState, + depth_states: &mut Vec, ) -> DiffDetectionState { let a_hash = state.a.hash(); let b_hash = state.b.hash(); @@ -319,19 +317,28 @@ fn find_latest_diff_point_between_tries_rec( // Note that differences in a node's `value` will be picked up by a hash // mismatch. + let mut current_depth_node_diff_state: DepthNodeDiffState = Default::default(); if (a_type, a_key_piece) != (b_type, b_key_piece) { - depth_state.try_update_longest_divergence_key_node(state); + current_depth_node_diff_state.try_update_longest_divergence_key_node(state); + depth_states.push(current_depth_node_diff_state); DiffDetectionState::NodeTypesDiffer } else { match (&state.a.node, &state.b.node) { (Node::Empty, Node::Empty) => DiffDetectionState::NoDiffDetected, (Node::Hash(a_hash), Node::Hash(b_hash)) => { - create_diff_detection_state_based_from_hashes( + match create_diff_detection_state_based_from_hashes( a_hash, b_hash, &state.new_from_parent(state.a, state.b, &Nibbles::default()), - depth_state, - ) + &mut current_depth_node_diff_state, + ) { + DiffDetectionState::NoDiffDetected => DiffDetectionState::NoDiffDetected, + result @ (DiffDetectionState::HashDiffDetected + | DiffDetectionState::NodeTypesDiffer) => { + depth_states.push(current_depth_node_diff_state); + result + } + } } ( Node::Branch { @@ -346,13 +353,13 @@ fn find_latest_diff_point_between_tries_rec( let mut most_significant_diff_found = DiffDetectionState::NoDiffDetected; for i in 0..16_usize { - let res = find_latest_diff_point_between_tries_rec( + let res = find_all_diff_points_between_tries_rec( &state.new_from_parent( &a_children[i], &b_children[i], &Nibbles::from_nibble(i as u8), ), - depth_state, + depth_states, ); most_significant_diff_found = most_significant_diff_found.pick_most_significant_state(&res); @@ -364,8 +371,18 @@ fn find_latest_diff_point_between_tries_rec( ) { most_significant_diff_found } else { - // Also run a hash check if we haven't picked anything up yet. - create_diff_detection_state_based_from_hash_and_gen_hashes(state, depth_state) + // Also run a hash check if we haven't picked anything up + match create_diff_detection_state_based_from_hash_and_gen_hashes( + state, + &mut current_depth_node_diff_state, + ) { + DiffDetectionState::NoDiffDetected => DiffDetectionState::NoDiffDetected, + result @ (DiffDetectionState::HashDiffDetected + | DiffDetectionState::NodeTypesDiffer) => { + depth_states.push(current_depth_node_diff_state); + result + } + } } } ( @@ -377,12 +394,22 @@ fn find_latest_diff_point_between_tries_rec( nibbles: _b_nibs, child: b_child, }, - ) => find_latest_diff_point_between_tries_rec( + ) => find_all_diff_points_between_tries_rec( &state.new_from_parent(a_child, b_child, a_nibs), - depth_state, + depth_states, ), (Node::Leaf { .. }, Node::Leaf { .. }) => { - create_diff_detection_state_based_from_hash_and_gen_hashes(state, depth_state) + match create_diff_detection_state_based_from_hash_and_gen_hashes( + state, + &mut current_depth_node_diff_state, + ) { + DiffDetectionState::NoDiffDetected => DiffDetectionState::NoDiffDetected, + result @ (DiffDetectionState::HashDiffDetected + | DiffDetectionState::NodeTypesDiffer) => { + depth_states.push(current_depth_node_diff_state); + result + } + } } _ => unreachable!(), } @@ -425,17 +452,22 @@ const fn get_value_from_node(n: &Node) -> Option<&Vec> { #[cfg(test)] mod tests { - use super::{create_diff_between_tries, DiffPoint, NodeInfo, TriePath}; + use std::str::FromStr; + + use ethereum_types::BigEndianHash; + use rlp_derive::{RlpDecodable, RlpEncodable}; + + use super::create_full_diff_between_tries; use crate::{ + debug_tools::diff::{DiffPoint, NodeInfo}, nibbles::Nibbles, partial_trie::{HashedPartialTrie, PartialTrie}, trie_ops::TrieOpResult, - utils::TrieNodeType, + utils::{TrieNodeType, TriePath}, }; #[test] fn depth_single_node_hash_diffs_work() -> TrieOpResult<()> { - // TODO: Reduce duplication once we identify common structures across tests... let mut a = HashedPartialTrie::default(); a.insert(0x1234, vec![0])?; let a_hash = a.hash(); @@ -444,7 +476,7 @@ mod tests { b.insert(0x1234, vec![1])?; let b_hash = b.hash(); - let diff = create_diff_between_tries(&a, &b); + let diff = create_full_diff_between_tries(&a, &b); let expected_a = NodeInfo { key: 0x1234.into(), @@ -468,18 +500,90 @@ mod tests { b_info: expected_b, }; - assert_eq!(diff.latest_diff_res, Some(expected)); - + assert_eq!(diff.diff_points[0], expected); Ok(()) } - // TODO: Will finish these tests later (low-priority). #[test] - #[ignore] - fn depth_single_node_node_diffs_work() { - todo!() + fn depth_multi_node_diff_works() -> std::result::Result<(), Box> { + use ethereum_types::{H256, U256}; + use keccak_hash::keccak; + #[derive( + RlpEncodable, RlpDecodable, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, + )] + pub struct TestAccountRlp { + pub nonce: U256, + pub balance: U256, + pub storage_root: H256, + pub code_hash: H256, + } + + let mut data = vec![ + ( + keccak(hex::decode("f0d4c12a5768d806021f80a262b4d39d26c58b8d").unwrap()), + TestAccountRlp { + nonce: U256::from(1), + balance: U256::from(2), + storage_root: H256::from_uint(&1312378.into()), + code_hash: H256::from_uint(&943221.into()), + }, + ), + ( + keccak(hex::decode("95222290dd7278aa3ddd389cc1e1d165cc4bafe5").unwrap()), + TestAccountRlp { + nonce: U256::from(2), + balance: U256::from(3), + storage_root: H256::from_uint(&1123178.into()), + code_hash: H256::from_uint(&8133221.into()), + }, + ), + ( + keccak(hex::decode("43682bcf1ce452a70b72c109551084076c6377e0").unwrap()), + TestAccountRlp { + nonce: U256::from(100), + balance: U256::from(101), + storage_root: H256::from_uint(&12345678.into()), + code_hash: H256::from_uint(&94321.into()), + }, + ), + ( + keccak(hex::decode("97a9a15168c22b3c137e6381037e1499c8ad0978").unwrap()), + TestAccountRlp { + nonce: U256::from(3000), + balance: U256::from(3002), + storage_root: H256::from_uint(&123456781.into()), + code_hash: H256::from_uint(&943214141.into()), + }, + ), + ]; + + let create_trie_with_data = |trie: &Vec<(H256, TestAccountRlp)>| -> Result> { + let mut tr = HashedPartialTrie::default(); + tr.insert::(Nibbles::from_str(&hex::encode(trie[0].0.as_bytes()))?, rlp::encode(&trie[0].1).as_ref())?; + tr.insert::(Nibbles::from_str(&hex::encode(trie[1].0.as_bytes()))?, rlp::encode(&trie[1].1).as_ref())?; + tr.insert::(Nibbles::from_str(&hex::encode(trie[2].0.as_bytes()))?, rlp::encode(&trie[2].1).as_ref())?; + tr.insert::(Nibbles::from_str(&hex::encode(trie[3].0.as_bytes()))?, rlp::encode(&trie[3].1).as_ref())?; + Ok(tr) + }; + + let a = create_trie_with_data(&data)?; + + // Change data on multiple accounts + data[1].1.balance += U256::from(1); + data[3].1.nonce += U256::from(2); + data[3].1.storage_root = H256::from_uint(&4445556.into()); + let b = create_trie_with_data(&data)?; + + let diff = create_full_diff_between_tries(&a, &b); + + assert_eq!(diff.diff_points.len(), 2); + assert_eq!(&diff.diff_points[0].key.to_string(), "0x3"); + assert_eq!(&diff.diff_points[1].key.to_string(), "0x55"); + + Ok(()) } + // TODO: Will finish these tests later (low-priority). #[test] #[ignore] fn depth_multi_node_single_node_hash_diffs_work() { diff --git a/zero/src/trie_diff/mod.rs b/zero/src/trie_diff/mod.rs index f659dc86e..088e2d835 100644 --- a/zero/src/trie_diff/mod.rs +++ b/zero/src/trie_diff/mod.rs @@ -1,6 +1,6 @@ use evm_arithmetization::generation::mpt::{AccountRlp, LegacyReceiptRlp}; use evm_arithmetization::generation::DebugOutputTries; -use mpt_trie::debug_tools::diff::{create_diff_between_tries, DiffPoint}; +use mpt_trie::debug_tools::diff::{create_full_diff_between_tries, DiffPoint}; use mpt_trie::utils::TrieNodeType; use tracing::info; @@ -24,13 +24,17 @@ pub fn compare_tries( V: rlp::Decodable + std::fmt::Debug, >( trie_name: &str, - diff_point: Option, + diff_point: Vec, block_number: u64, batch_index: usize, decode_key: bool, decode_data: bool, ) -> anyhow::Result<()> { - if let Some(ref trie_diff_point) = diff_point { + if diff_point.is_empty() { + info!("{trie_name} for block {block_number} batch {batch_index} matches."); + return Ok(()); + } + for (index, trie_diff_point) in diff_point.into_iter().enumerate() { if trie_diff_point.a_info.node_type == TrieNodeType::Leaf { if let Some(ref td_value) = trie_diff_point.a_info.value { let td_key_str: &str = if decode_key { @@ -73,19 +77,17 @@ pub fn compare_tries( } info!( - "{trie_name} block {block_number} batch {batch_index} diff: {:#?}", + "Diff {index} {trie_name} block {block_number} batch {batch_index} diff:\n{}\n", trie_diff_point ); - } else { - info!("{trie_name} for block {block_number} batch {batch_index} matches."); } Ok(()) } - let state_trie_diff = create_diff_between_tries(&left.state_trie, &right.state_trie); + let state_trie_diff = create_full_diff_between_tries(&left.state_trie, &right.state_trie); compare_tries_and_output_results::( "state trie", - state_trie_diff.latest_diff_res, + state_trie_diff.diff_points, block_number, batch_index, false, @@ -93,20 +95,20 @@ pub fn compare_tries( )?; let transaction_trie_diff = - create_diff_between_tries(&left.transaction_trie, &right.transaction_trie); + create_full_diff_between_tries(&left.transaction_trie, &right.transaction_trie); compare_tries_and_output_results::( "transaction trie", - transaction_trie_diff.latest_diff_res, + transaction_trie_diff.diff_points, block_number, batch_index, false, true, )?; - let receipt_trie_diff = create_diff_between_tries(&left.receipt_trie, &right.receipt_trie); + let receipt_trie_diff = create_full_diff_between_tries(&left.receipt_trie, &right.receipt_trie); compare_tries_and_output_results::( "receipt trie", - receipt_trie_diff.latest_diff_res, + receipt_trie_diff.diff_points, block_number, batch_index, true,