Skip to content

Commit

Permalink
feat: full tries comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
atanmarko committed Sep 24, 2024
1 parent 8825ea8 commit 8473d71
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 64 deletions.
208 changes: 156 additions & 52 deletions mpt_trie/src/debug_tools/diff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@
//! - Top-down will find the highest point of a structural divergence and report
//! it. If there are multiple divergences, then only the one that is the
//! highest in the trie will be reported.
//! - Bottom-up (not implemented) is a lot more complex to implement, but will
//! attempt to find the smallest structural trie difference between the trie.
//! If there are multiple differences, then this will likely be what you want
//! to use.
//! - Bottom-up is a lot more complex to implement, but will attempt to find the
//! smallest structural difference between the two tries. If there are
//! multiple differences, then this will likely be what you want to use.

use std::fmt::{self, Debug};
use std::{fmt::Display, ops::Deref};
Expand All @@ -48,17 +47,16 @@ fn get_key_piece_from_node<T: PartialTrie>(n: &Node<T>) -> Nibbles {

#[derive(Clone, Debug, Eq, Hash, PartialEq)]
/// The difference between two Tries, represented as the highest
/// point of a structural divergence.
/// array of `DiffPoint`s.
pub struct TrieDiff {
/// The highest point of structural divergence.
pub latest_diff_res: Option<DiffPoint>,
// TODO: Later add a second pass for finding diffs from the bottom up (`earliest_diff_res`).
/// Diff points between the two tries.
pub diff_points: Vec<DiffPoint>,
}

impl Display for TrieDiff {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if let Some(diff) = &self.latest_diff_res {
write!(f, "{}", diff)?;
for (index, diff_point) in self.diff_points.iter().enumerate() {
writeln!(f, "{}: {}\n", index, diff_point)?;
}

Ok(())
Expand Down Expand Up @@ -176,32 +174,30 @@ impl NodeInfo {
}

/// Create a diff between two tries. Will perform both types of diff searches
/// (top-down & bottom-up).
pub fn create_diff_between_tries(a: &HashedPartialTrie, b: &HashedPartialTrie) -> TrieDiff {
/// (top-down & bottom-up). It does not stop at the first difference.
pub fn create_full_diff_between_tries(a: &HashedPartialTrie, b: &HashedPartialTrie) -> TrieDiff {
TrieDiff {
latest_diff_res: find_latest_diff_point_between_tries(a, b),
diff_points: find_all_diff_points_between_tries(a, b),
}
}

// Only support `HashedPartialTrie` due to it being significantly faster to
// detect differences because of caching hashes.
fn find_latest_diff_point_between_tries(
fn find_all_diff_points_between_tries(
a: &HashedPartialTrie,
b: &HashedPartialTrie,
) -> Option<DiffPoint> {
) -> Vec<DiffPoint> {
let state = DepthDiffPerCallState::new(a, b, Nibbles::default(), 0);
let mut longest_state = DepthNodeDiffState::default();

find_latest_diff_point_between_tries_rec(&state, &mut longest_state);

// If there was a node diff, we always want to prioritize displaying this over a
// hash diff. The reasoning behind this is hash diffs can become sort of
// meaningless or misleading if the trie diverges at some point (eg. saying
// there is a hash diff deep in two separate trie structures doesn't make much
// sense).
longest_state
.longest_key_node_diff
.or(longest_state.longest_key_hash_diff)
let mut longest_states = Vec::new();

find_all_diff_points_between_tries_rec(&state, &mut longest_states);

longest_states
.into_iter()
.filter_map(|longest_state| {
longest_state
.longest_key_node_diff
.or(longest_state.longest_key_hash_diff)
})
.collect()
}

#[derive(Debug, Default)]
Expand Down Expand Up @@ -298,9 +294,11 @@ impl<'a> DepthDiffPerCallState<'a> {
}
}

fn find_latest_diff_point_between_tries_rec(
// Searches for all differences between two tries; it does not stop at the
// first difference found.
fn find_all_diff_points_between_tries_rec(
state: &DepthDiffPerCallState,
depth_state: &mut DepthNodeDiffState,
depth_states: &mut Vec<DepthNodeDiffState>,
) -> DiffDetectionState {
let a_hash = state.a.hash();
let b_hash = state.b.hash();
Expand All @@ -319,19 +317,28 @@ fn find_latest_diff_point_between_tries_rec(

// Note that differences in a node's `value` will be picked up by a hash
// mismatch.
let mut current_depth_node_diff_state: DepthNodeDiffState = Default::default();
if (a_type, a_key_piece) != (b_type, b_key_piece) {
depth_state.try_update_longest_divergence_key_node(state);
current_depth_node_diff_state.try_update_longest_divergence_key_node(state);
depth_states.push(current_depth_node_diff_state);
DiffDetectionState::NodeTypesDiffer
} else {
match (&state.a.node, &state.b.node) {
(Node::Empty, Node::Empty) => DiffDetectionState::NoDiffDetected,
(Node::Hash(a_hash), Node::Hash(b_hash)) => {
create_diff_detection_state_based_from_hashes(
match create_diff_detection_state_based_from_hashes(
a_hash,
b_hash,
&state.new_from_parent(state.a, state.b, &Nibbles::default()),
depth_state,
)
&mut current_depth_node_diff_state,
) {
DiffDetectionState::NoDiffDetected => DiffDetectionState::NoDiffDetected,
result @ (DiffDetectionState::HashDiffDetected
| DiffDetectionState::NodeTypesDiffer) => {
depth_states.push(current_depth_node_diff_state);
result
}
}
}
(
Node::Branch {
Expand All @@ -346,13 +353,13 @@ fn find_latest_diff_point_between_tries_rec(
let mut most_significant_diff_found = DiffDetectionState::NoDiffDetected;

for i in 0..16_usize {
let res = find_latest_diff_point_between_tries_rec(
let res = find_all_diff_points_between_tries_rec(
&state.new_from_parent(
&a_children[i],
&b_children[i],
&Nibbles::from_nibble(i as u8),
),
depth_state,
depth_states,
);
most_significant_diff_found =
most_significant_diff_found.pick_most_significant_state(&res);
Expand All @@ -364,8 +371,18 @@ fn find_latest_diff_point_between_tries_rec(
) {
most_significant_diff_found
} else {
// Also run a hash check if we haven't picked anything up yet.
create_diff_detection_state_based_from_hash_and_gen_hashes(state, depth_state)
// Also run a hash check if we haven't picked anything up
match create_diff_detection_state_based_from_hash_and_gen_hashes(
state,
&mut current_depth_node_diff_state,
) {
DiffDetectionState::NoDiffDetected => DiffDetectionState::NoDiffDetected,
result @ (DiffDetectionState::HashDiffDetected
| DiffDetectionState::NodeTypesDiffer) => {
depth_states.push(current_depth_node_diff_state);
result
}
}
}
}
(
Expand All @@ -377,12 +394,22 @@ fn find_latest_diff_point_between_tries_rec(
nibbles: _b_nibs,
child: b_child,
},
) => find_latest_diff_point_between_tries_rec(
) => find_all_diff_points_between_tries_rec(
&state.new_from_parent(a_child, b_child, a_nibs),
depth_state,
depth_states,
),
(Node::Leaf { .. }, Node::Leaf { .. }) => {
create_diff_detection_state_based_from_hash_and_gen_hashes(state, depth_state)
match create_diff_detection_state_based_from_hash_and_gen_hashes(
state,
&mut current_depth_node_diff_state,
) {
DiffDetectionState::NoDiffDetected => DiffDetectionState::NoDiffDetected,
result @ (DiffDetectionState::HashDiffDetected
| DiffDetectionState::NodeTypesDiffer) => {
depth_states.push(current_depth_node_diff_state);
result
}
}
}
_ => unreachable!(),
}
Expand Down Expand Up @@ -425,17 +452,22 @@ const fn get_value_from_node<T: PartialTrie>(n: &Node<T>) -> Option<&Vec<u8>> {

#[cfg(test)]
mod tests {
use super::{create_diff_between_tries, DiffPoint, NodeInfo, TriePath};
use std::str::FromStr;

use ethereum_types::BigEndianHash;
use rlp_derive::{RlpDecodable, RlpEncodable};

use super::create_full_diff_between_tries;
use crate::{
debug_tools::diff::{DiffPoint, NodeInfo},
nibbles::Nibbles,
partial_trie::{HashedPartialTrie, PartialTrie},
trie_ops::TrieOpResult,
utils::TrieNodeType,
utils::{TrieNodeType, TriePath},
};

#[test]
fn depth_single_node_hash_diffs_work() -> TrieOpResult<()> {
// TODO: Reduce duplication once we identify common structures across tests...
let mut a = HashedPartialTrie::default();
a.insert(0x1234, vec![0])?;
let a_hash = a.hash();
Expand All @@ -444,7 +476,7 @@ mod tests {
b.insert(0x1234, vec![1])?;
let b_hash = b.hash();

let diff = create_diff_between_tries(&a, &b);
let diff = create_full_diff_between_tries(&a, &b);

let expected_a = NodeInfo {
key: 0x1234.into(),
Expand All @@ -468,18 +500,90 @@ mod tests {
b_info: expected_b,
};

assert_eq!(diff.latest_diff_res, Some(expected));

assert_eq!(diff.diff_points[0], expected);
Ok(())
}

// TODO: Will finish these tests later (low-priority).
#[test]
#[ignore]
fn depth_single_node_node_diffs_work() {
todo!()
fn depth_multi_node_diff_works() -> std::result::Result<(), Box<dyn std::error::Error>> {
use ethereum_types::{H256, U256};
use keccak_hash::keccak;
#[derive(
RlpEncodable, RlpDecodable, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord,
)]
pub struct TestAccountRlp {
pub nonce: U256,
pub balance: U256,
pub storage_root: H256,
pub code_hash: H256,
}

let mut data = vec![
(
keccak(hex::decode("f0d4c12a5768d806021f80a262b4d39d26c58b8d").unwrap()),
TestAccountRlp {
nonce: U256::from(1),
balance: U256::from(2),
storage_root: H256::from_uint(&1312378.into()),
code_hash: H256::from_uint(&943221.into()),
},
),
(
keccak(hex::decode("95222290dd7278aa3ddd389cc1e1d165cc4bafe5").unwrap()),
TestAccountRlp {
nonce: U256::from(2),
balance: U256::from(3),
storage_root: H256::from_uint(&1123178.into()),
code_hash: H256::from_uint(&8133221.into()),
},
),
(
keccak(hex::decode("43682bcf1ce452a70b72c109551084076c6377e0").unwrap()),
TestAccountRlp {
nonce: U256::from(100),
balance: U256::from(101),
storage_root: H256::from_uint(&12345678.into()),
code_hash: H256::from_uint(&94321.into()),
},
),
(
keccak(hex::decode("97a9a15168c22b3c137e6381037e1499c8ad0978").unwrap()),
TestAccountRlp {
nonce: U256::from(3000),
balance: U256::from(3002),
storage_root: H256::from_uint(&123456781.into()),
code_hash: H256::from_uint(&943214141.into()),
},
),
];

let create_trie_with_data = |trie: &Vec<(H256, TestAccountRlp)>| -> Result<HashedPartialTrie, Box<dyn std::error::Error>> {
let mut tr = HashedPartialTrie::default();
tr.insert::<Nibbles, &[u8]>(Nibbles::from_str(&hex::encode(trie[0].0.as_bytes()))?, rlp::encode(&trie[0].1).as_ref())?;
tr.insert::<Nibbles, &[u8]>(Nibbles::from_str(&hex::encode(trie[1].0.as_bytes()))?, rlp::encode(&trie[1].1).as_ref())?;
tr.insert::<Nibbles, &[u8]>(Nibbles::from_str(&hex::encode(trie[2].0.as_bytes()))?, rlp::encode(&trie[2].1).as_ref())?;
tr.insert::<Nibbles, &[u8]>(Nibbles::from_str(&hex::encode(trie[3].0.as_bytes()))?, rlp::encode(&trie[3].1).as_ref())?;
Ok(tr)
};

let a = create_trie_with_data(&data)?;

// Change data on multiple accounts
data[1].1.balance += U256::from(1);
data[3].1.nonce += U256::from(2);
data[3].1.storage_root = H256::from_uint(&4445556.into());
let b = create_trie_with_data(&data)?;

let diff = create_full_diff_between_tries(&a, &b);

assert_eq!(diff.diff_points.len(), 2);
assert_eq!(&diff.diff_points[0].key.to_string(), "0x3");
assert_eq!(&diff.diff_points[1].key.to_string(), "0x55");

Ok(())
}

// TODO: Will finish these tests later (low-priority).
#[test]
#[ignore]
fn depth_multi_node_single_node_hash_diffs_work() {
Expand Down
Loading

0 comments on commit 8473d71

Please sign in to comment.