Skip to content

Commit

Permalink
iterate forward and return value
Browse files Browse the repository at this point in the history
  • Loading branch information
rkrasiuk committed May 13, 2024
1 parent dcff118 commit 67051af
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 112 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ proptest = { version = "1.4", optional = true }
proptest-derive = { version = "0.4", optional = true }

[dev-dependencies]
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
alloy-rlp = { version = "0.3", default-features = false, features = ["derive", "arrayvec"] }
hash-db = "0.15"
plain_hasher = "0.2"
triehash = "0.8.4"
Expand Down
36 changes: 4 additions & 32 deletions src/proof/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,7 @@ pub enum ProofVerificationError {
/// Value in the proof.
got: Box<Bytes>,
/// Expected value.
expected: Box<Bytes>,
},
/// Unexpected key encountered in proof during verification.
UnexpectedKey {
/// Path at which unexpected key was encountered.
path: Box<Nibbles>,
/// Unexpected key. Empty means entry is missing from branch node at given path.
key: Box<Nibbles>,
},
/// Branch node child is missing at specified path.
MissingBranchChild {
/// Full path at which child is missing.
path: Box<Nibbles>,
expected: Option<Box<Bytes>>,
},
/// Error during RLP decoding of trie node.
Rlp(alloy_rlp::Error),
Expand Down Expand Up @@ -60,13 +48,7 @@ impl fmt::Display for ProofVerificationError {
write!(f, "root mismatch. got: {got}. expected: {expected}")
}
ProofVerificationError::ValueMismatch { path, got, expected } => {
write!(f, "value mismatch at path {path:?}. got: {got}. expected: {expected}")
}
ProofVerificationError::UnexpectedKey { path, key } => {
write!(f, "unexpected node key {key:?} at path {path:?}")
}
ProofVerificationError::MissingBranchChild { path } => {
write!(f, "missing branch child at path {path:?}")
write!(f, "value mismatch at path {path:?}. got: {got}. expected: {expected:?}")
}
ProofVerificationError::Rlp(error) => fmt::Display::fmt(error, f),
}
Expand All @@ -81,21 +63,11 @@ impl From<alloy_rlp::Error> for ProofVerificationError {

impl ProofVerificationError {
/// Create [ProofVerificationError::ValueMismatch] error variant.
pub fn value_mismatch(path: Nibbles, got: Bytes, expected: Bytes) -> Self {
pub fn value_mismatch(path: Nibbles, got: Bytes, expected: Option<Bytes>) -> Self {
Self::ValueMismatch {
path: Box::new(path),
got: Box::new(got),
expected: Box::new(expected),
expected: expected.map(Box::new),
}
}

/// Create [ProofVerificationError::UnexpectedKey] error variant.
pub fn unexpected_key(path: Nibbles, key: Nibbles) -> Self {
Self::UnexpectedKey { path: Box::new(path), key: Box::new(key) }
}

/// Create [ProofVerificationError::MissingBranchChild] error variant.
pub fn missing_branch_child(path: Nibbles) -> Self {
Self::MissingBranchChild { path: Box::new(path) }
}
}
157 changes: 78 additions & 79 deletions src/proof/verify.rs
Original file line number Diff line number Diff line change
@@ -1,104 +1,77 @@
//! Proof verification logic.

use crate::{
nodes::{rlp_node, TrieNode, CHILD_INDEX_RANGE},
nodes::{rlp_node, word_rlp, TrieNode, CHILD_INDEX_RANGE},
proof::ProofVerificationError,
EMPTY_ROOT_HASH,
};
use alloc::vec::Vec;
use alloy_primitives::{keccak256, Bytes, B256};
use alloy_primitives::{Bytes, B256};
use alloy_rlp::Decodable;
use nybbles::Nibbles;

/// Verify the proof for given key value pair against the provided state root.
/// Returns the leaf node value for the given key.
pub fn verify_proof<'a, I>(
proof: I,
root: B256,
key: B256,
value: Vec<u8>,
) -> Result<(), ProofVerificationError>
) -> Result<Option<Vec<u8>>, ProofVerificationError>
where
I: IntoIterator<Item = &'a Bytes>,
I::IntoIter: DoubleEndedIterator,
{
let mut proof = proof.into_iter().rev().peekable();

if root == EMPTY_ROOT_HASH && proof.peek().is_none() {
return Ok(());
let mut proof = proof.into_iter().peekable();

if proof.peek().is_none() {
return if root == EMPTY_ROOT_HASH {
Ok(None)
} else {
return Err(ProofVerificationError::RootMismatch {
got: EMPTY_ROOT_HASH,
expected: root,
});
};
}

let mut target = Nibbles::unpack(key);
let mut expected_value = value;

let target = Nibbles::unpack(key);
let mut walked_path = Nibbles::default();
let mut expected_value = Some(word_rlp(&root));
for node in proof {
let nibbles_verified = match TrieNode::decode(&mut &node[..])? {
TrieNode::Branch(branch) => {
let value = 'val: {
if let Some(last) = target.last() {
let mut stack_ptr = branch.as_ref().first_child_index();
for index in CHILD_INDEX_RANGE {
if branch.state_mask.is_bit_set(index) {
if index == last {
break 'val &branch.stack[stack_ptr];
}
stack_ptr += 1;
if Some(rlp_node(&node)) != expected_value {
let got = Bytes::copy_from_slice(&node);
let expected = expected_value.map(|b| Bytes::copy_from_slice(&b));
return Err(ProofVerificationError::value_mismatch(walked_path, got, expected));
}

expected_value = match TrieNode::decode(&mut &node[..])? {
TrieNode::Branch(branch) => 'val: {
if let Some(next) = target.get(walked_path.len()) {
let mut stack_ptr = branch.as_ref().first_child_index();
for index in CHILD_INDEX_RANGE {
if branch.state_mask.is_bit_set(index) {
if index == *next {
walked_path.push(*next);
break 'val Some(branch.stack[stack_ptr].clone());
}
stack_ptr += 1;
}
}

return Err(ProofVerificationError::missing_branch_child(target));
};

if value != &expected_value {
let got = Bytes::copy_from_slice(value.as_slice());
let expected = Bytes::from(expected_value);
return Err(ProofVerificationError::value_mismatch(target, got, expected));
}

1
None
}
TrieNode::Extension(extension) => {
if !target.ends_with(&extension.key) {
return Err(ProofVerificationError::unexpected_key(target, extension.key));
}

if extension.child != expected_value {
let got = Bytes::copy_from_slice(extension.child.as_slice());
let expected = Bytes::from(expected_value);
return Err(ProofVerificationError::value_mismatch(target, got, expected));
}

extension.key.len()
walked_path.extend_from_slice(&extension.key);
Some(extension.child).filter(|_| target.starts_with(&walked_path))
}
TrieNode::Leaf(leaf) => {
if !target.ends_with(&leaf.key) {
return Err(ProofVerificationError::unexpected_key(target, leaf.key));
}

if leaf.value != expected_value {
let got = Bytes::copy_from_slice(leaf.value.as_slice());
let expected = Bytes::from(expected_value);
return Err(ProofVerificationError::value_mismatch(target, got, expected));
}

leaf.key.len()
walked_path.extend_from_slice(&leaf.key);
Some(leaf.value.clone()).filter(|_| target.starts_with(&walked_path))
}
};
target.truncate(target.len() - nibbles_verified);
expected_value = rlp_node(node);
}

let computed_root = if expected_value.len() == B256::len_bytes() + 1 {
B256::from_slice(&expected_value[1..])
} else {
keccak256(expected_value)
};

if root == computed_root {
Ok(())
} else {
Err(ProofVerificationError::RootMismatch { got: computed_root, expected: root })
}
Ok(expected_value)
}

#[cfg(test)]
Expand All @@ -113,34 +86,60 @@ mod tests {
let mut hash_builder = HashBuilder::default().with_proof_retainer(ProofRetainer::default());
let root = hash_builder.root();
let proof = hash_builder.take_proofs();
assert_eq!(verify_proof(proof.values(), root, key, vec![]), Ok(()));
assert_eq!(verify_proof(proof.values(), root, key), Ok(None));

let mut dummy_proof = vec![];
BranchNode::default().encode(&mut dummy_proof);
assert_eq!(
verify_proof([&Bytes::from(dummy_proof)], root, key, vec![]),
Err(ProofVerificationError::missing_branch_child(Nibbles::unpack(key)))
verify_proof([&Bytes::from(dummy_proof.clone())], root, key),
Err(ProofVerificationError::value_mismatch(
Nibbles::default(),
Bytes::from(dummy_proof),
Some(Bytes::from(word_rlp(&EMPTY_ROOT_HASH)))
))
);
}

#[test]
fn single_leaf_trie_proof_verifcation() {
fn single_leaf_trie_proof_verification() {
let target = B256::with_last_byte(0x2);
let non_existent_target = B256::with_last_byte(0x3);

let retainer = ProofRetainer::from_iter([target].map(Nibbles::unpack));
let retainer = ProofRetainer::from_iter([target, non_existent_target].map(Nibbles::unpack));
let mut hash_builder = HashBuilder::default().with_proof_retainer(retainer);
hash_builder.add_leaf(Nibbles::unpack(target), &target[..]);
let root = hash_builder.root();
assert_eq!(root, triehash_trie_root([(target, target)]));

let proof = hash_builder.take_proofs();
assert_eq!(verify_proof(proof.values(), root, target, target.to_vec()), Ok(()));
assert_eq!(verify_proof(proof.values(), root, target), Ok(Some(target.to_vec())));
}

#[test]
fn non_existent_proof_verification() {
let range = 0..=0xf;
let target = B256::with_last_byte(0xff);

let retainer = ProofRetainer::from_iter([target].map(Nibbles::unpack));
let mut hash_builder = HashBuilder::default().with_proof_retainer(retainer);
for key in range.clone() {
let hash = B256::with_last_byte(key);
hash_builder.add_leaf(Nibbles::unpack(hash), &hash[..]);
}
let root = hash_builder.root();
assert_eq!(
root,
triehash_trie_root(range.map(|b| (B256::with_last_byte(b), B256::with_last_byte(b))))
);

let proof = hash_builder.take_proofs();
assert_eq!(verify_proof(proof.values(), root, target), Ok(None));
}

#[test]
fn extension_root_trie_proof_verification() {
let range = 0..=0xf; // 0xff
let target = B256::with_last_byte(0x2); // 0x42
let range = 0..=0xff;
let target = B256::with_last_byte(0x42);

let retainer = ProofRetainer::from_iter([target].map(Nibbles::unpack));
let mut hash_builder = HashBuilder::default().with_proof_retainer(retainer);
Expand All @@ -155,7 +154,7 @@ mod tests {
);

let proof = hash_builder.take_proofs();
assert_eq!(verify_proof(proof.values(), root, target, target.to_vec()), Ok(()));
assert_eq!(verify_proof(proof.values(), root, target), Ok(Some(target.to_vec())));
}

#[test]
Expand All @@ -180,11 +179,11 @@ mod tests {

let proof1 =
proof.iter().filter_map(|(k, v)| Nibbles::unpack(target1).starts_with(k).then_some(v));
assert_eq!(verify_proof(proof1, root, target1, target1.to_vec()), Ok(()));
assert_eq!(verify_proof(proof1, root, target1), Ok(Some(target1.to_vec())));

let proof2 =
proof.iter().filter_map(|(k, v)| Nibbles::unpack(target2).starts_with(k).then_some(v));
assert_eq!(verify_proof(proof2, root, target2, target2.to_vec()), Ok(()));
assert_eq!(verify_proof(proof2, root, target2), Ok(Some(target2.to_vec())));
}

#[test]
Expand All @@ -211,7 +210,7 @@ mod tests {
let proofs = hash_builder.take_proofs();
for (key, value) in hashed {
let proof = proofs.iter().filter_map(|(k, v)| Nibbles::unpack(key).starts_with(k).then_some(v));
assert_eq!(verify_proof(proof, root, key, value), Ok(()));
assert_eq!(verify_proof(proof, root, key), Ok(Some(value)));
}
});
}
Expand Down

0 comments on commit 67051af

Please sign in to comment.