Skip to content

Commit

Permalink
feat: proof verification (#13)
Browse files Browse the repository at this point in the history
* feat: proof verification

* accept ref

* clippy

* iterate forward and return value

* error

* filter result properly

* address comments & change api

* remove clones

* remove dep
  • Loading branch information
rkrasiuk authored May 14, 2024
1 parent d980ebc commit 4fc6151
Show file tree
Hide file tree
Showing 10 changed files with 569 additions and 64 deletions.
40 changes: 7 additions & 33 deletions src/hash_builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use super::{
nodes::{word_rlp, BranchNodeRef, ExtensionNodeRef, LeafNodeRef},
proof::ProofRetainer,
BranchNodeCompact, Nibbles, TrieMask, EMPTY_ROOT_HASH,
};
use crate::HashMap;
Expand All @@ -15,9 +16,6 @@ use alloc::{collections::BTreeMap, vec::Vec};
mod value;
pub use value::HashBuilderValue;

mod proof_retainer;
pub use proof_retainer::ProofRetainer;

/// A component used to construct the root hash of the trie. The primary purpose of a Hash Builder
/// is to build the Merkle proof that is essential for verifying the integrity and authenticity of
/// the trie's contents. It achieves this by constructing the root hash from the hashes of child
Expand Down Expand Up @@ -69,9 +67,9 @@ impl HashBuilder {
self
}

/// Enable proof retainer for the specified target nibbles.
pub fn with_proof_retainer(mut self, targets: Vec<Nibbles>) -> Self {
self.proof_retainer = Some(ProofRetainer::new(targets));
/// Enable specified proof retainer.
pub fn with_proof_retainer(mut self, retainer: ProofRetainer) -> Self {
self.proof_retainer = Some(retainer);
self
}

Expand Down Expand Up @@ -263,6 +261,7 @@ impl HashBuilder {
self.rlp_buf.clear();
hex::encode(&extension_node.rlp(&mut self.rlp_buf))
}, "extension node rlp");

self.rlp_buf.clear();
self.stack.push(extension_node.rlp(&mut self.rlp_buf));
self.retain_proof_from_buf(&current.slice(..len_from));
Expand Down Expand Up @@ -312,7 +311,7 @@ impl HashBuilder {
let state_mask = self.groups[len];
let hash_mask = self.hash_masks[len];
let branch_node = BranchNodeRef::new(&self.stack, &state_mask);
let children = branch_node.children(hash_mask);
let children = branch_node.child_hashes(hash_mask);

self.rlp_buf.clear();
let rlp = branch_node.rlp(&mut self.rlp_buf);
Expand Down Expand Up @@ -407,35 +406,10 @@ impl HashBuilder {
#[cfg(test)]
mod tests {
use super::*;
use crate::nodes::LeafNode;
use crate::{nodes::LeafNode, triehash_trie_root};
use alloy_primitives::{b256, hex, U256};
use alloy_rlp::Encodable;

fn triehash_trie_root<I, K, V>(iter: I) -> B256
where
I: IntoIterator<Item = (K, V)>,
K: AsRef<[u8]> + Ord,
V: AsRef<[u8]>,
{
struct Keccak256Hasher;
impl hash_db::Hasher for Keccak256Hasher {
type Out = B256;
type StdHasher = plain_hasher::PlainHasher;

const LENGTH: usize = 32;

fn hash(x: &[u8]) -> Self::Out {
keccak256(x)
}
}

// We use `trie_root` instead of `sec_trie_root` because we assume
// the incoming keys are already hashed, which makes sense given
// we're going to be using the Hashed tables & pre-hash the data
// on the way in.
triehash::trie_root::<Keccak256Hasher, _, _, _>(iter)
}

// Hashes the keys, RLP encodes the values, compares the trie builder with the upstream root.
fn assert_hashed_trie_root<'a, I, K>(iter: I)
where
Expand Down
28 changes: 28 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ pub use nodes::BranchNodeCompact;
pub mod hash_builder;
pub use hash_builder::HashBuilder;

pub mod proof;

mod mask;
pub use mask::TrieMask;

Expand All @@ -43,3 +45,29 @@ pub use nybbles::{self, Nibbles};
/// Root hash of an empty trie.
pub const EMPTY_ROOT_HASH: alloy_primitives::B256 =
alloy_primitives::b256!("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421");

#[cfg(test)]
pub(crate) fn triehash_trie_root<I, K, V>(iter: I) -> alloy_primitives::B256
where
I: IntoIterator<Item = (K, V)>,
K: AsRef<[u8]> + Ord,
V: AsRef<[u8]>,
{
struct Keccak256Hasher;
impl hash_db::Hasher for Keccak256Hasher {
type Out = alloy_primitives::B256;
type StdHasher = plain_hasher::PlainHasher;

const LENGTH: usize = 32;

fn hash(x: &[u8]) -> Self::Out {
alloy_primitives::keccak256(x)
}
}

// We use `trie_root` instead of `sec_trie_root` because we assume
// the incoming keys are already hashed, which makes sense given
// we're going to be using the Hashed tables & pre-hash the data
// on the way in.
triehash::trie_root::<Keccak256Hasher, _, _, _>(iter)
}
55 changes: 37 additions & 18 deletions src/nodes/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,20 @@ impl Decodable for BranchNode {
continue;
}

if bytes.len() < 32 {
return Err(alloy_rlp::Error::InputTooShort);
}

// Decode without advancing
let Header { payload_length, .. } = Header::decode(&mut &bytes[..])?;
let len = payload_length + length_of_length(payload_length);
stack.push(Vec::from(&bytes[..len]));
bytes.advance(len);
state_mask.set_bit(index);
stack.push(Vec::from(&bytes[..32]));
bytes.advance(32);
}

// Consume empty string code for branch node value.
bytes.advance(1);
let bytes = Header::decode_bytes(&mut bytes, false)?;
if !bytes.is_empty() {
return Err(alloy_rlp::Error::Custom("branch values not supported"));
}
debug_assert!(bytes.is_empty());

Ok(Self { stack, state_mask })
}
Expand Down Expand Up @@ -121,7 +124,7 @@ impl Encodable for BranchNodeRef<'_> {
// Advance the pointer to the next child.
stack_ptr += 1;
} else {
out.put_u8(EMPTY_STRING_CODE)
out.put_u8(EMPTY_STRING_CODE);
}
}

Expand Down Expand Up @@ -152,7 +155,7 @@ impl<'a> BranchNodeRef<'a> {

/// Given the hash and state mask of children present, return an iterator over the stack items
/// that match the mask.
pub fn children(&self, hash_mask: TrieMask) -> Vec<B256> {
pub fn child_hashes(&self, hash_mask: TrieMask) -> Vec<B256> {
let mut stack_ptr = self.first_child_index();
let mut children = Vec::with_capacity(hash_mask.count_ones() as usize);
for index in CHILD_INDEX_RANGE {
Expand Down Expand Up @@ -253,25 +256,41 @@ impl BranchNodeCompact {
#[cfg(test)]
mod tests {
use super::*;
use crate::nodes::{word_rlp, ExtensionNode, LeafNode};
use nybbles::Nibbles;

#[test]
fn rlp_branch_node_roundtrip() {
let empty = BranchNode::default();
let encoded = alloy_rlp::encode(&empty);
assert_eq!(BranchNode::decode(&mut &encoded[..]).unwrap(), empty);

let sparse_node = BranchNode::new(
vec![word_rlp(&B256::repeat_byte(1)), word_rlp(&B256::repeat_byte(2))],
TrieMask::new(0b1000100),
);
let encoded = alloy_rlp::encode(&sparse_node);
assert_eq!(BranchNode::decode(&mut &encoded[..]).unwrap(), sparse_node);

let leaf_child = LeafNode::new(Nibbles::from_nibbles(hex!("0203")), hex!("1234").to_vec());
let mut buf = vec![];
empty.encode(&mut buf);
assert_eq!(BranchNode::decode(&mut &buf[..]).unwrap(), empty);
let leaf_rlp = leaf_child.as_ref().rlp(&mut buf);
let branch_with_leaf = BranchNode::new(vec![leaf_rlp.clone()], TrieMask::new(0b0010));
let encoded = alloy_rlp::encode(&branch_with_leaf);
assert_eq!(BranchNode::decode(&mut &encoded[..]).unwrap(), branch_with_leaf);

let sparse_node = BranchNode::new(vec![vec![1; 32], vec![2; 32]], TrieMask::new(0b1000100));
let extension_child = ExtensionNode::new(Nibbles::from_nibbles(hex!("0203")), leaf_rlp);
let mut buf = vec![];
sparse_node.encode(&mut buf);
assert_eq!(BranchNode::decode(&mut &buf[..]).unwrap(), sparse_node);
let extension_rlp = extension_child.as_ref().rlp(&mut buf);
let branch_with_ext = BranchNode::new(vec![extension_rlp], TrieMask::new(0b00000100000));
let encoded = alloy_rlp::encode(&branch_with_ext);
assert_eq!(BranchNode::decode(&mut &encoded[..]).unwrap(), branch_with_ext);

let full = BranchNode::new(
core::iter::repeat(vec![0x23; 32]).take(16).collect(),
core::iter::repeat(word_rlp(&B256::repeat_byte(23))).take(16).collect(),
TrieMask::new(u16::MAX),
);
let mut buf = vec![];
full.encode(&mut buf);
assert_eq!(BranchNode::decode(&mut &buf[..]).unwrap(), full);
let encoded = alloy_rlp::encode(&full);
assert_eq!(BranchNode::decode(&mut &encoded[..]).unwrap(), full);
}
}
17 changes: 12 additions & 5 deletions src/nodes/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,15 @@ impl Encodable for ExtensionNode {
impl Decodable for ExtensionNode {
fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
let mut bytes = Header::decode_bytes(buf, true)?;

let encoded_key = Bytes::decode(&mut bytes)?;
if encoded_key.is_empty() {
return Err(alloy_rlp::Error::Custom("extension node key empty"));
}

// Retrieve first byte. If it's [Some], then the nibbles are odd.
let first = match encoded_key[0] & 0xf0 {
0x10 => Some(encoded_key[0] & 0x0f),
0x00 => None,
Self::ODD_FLAG => Some(encoded_key[0] & 0x0f),
Self::EVEN_FLAG => None,
_ => return Err(alloy_rlp::Error::Custom("node is not extension")),
};

Expand All @@ -63,6 +62,12 @@ impl Decodable for ExtensionNode {
}

impl ExtensionNode {
/// The flag representing the even number of nibbles in the extension key.
pub const EVEN_FLAG: u8 = 0x00;

/// The flag representing the odd number of nibbles in the extension key.
pub const ODD_FLAG: u8 = 0x10;

/// Creates a new extension node with the given key and a pointer to the child.
pub fn new(key: Nibbles, child: Vec<u8>) -> Self {
Self { key, child }
Expand Down Expand Up @@ -136,9 +141,11 @@ mod tests {
fn rlp_extension_node_roundtrip() {
let nibble = Nibbles::from_nibbles_unchecked(hex!("0604060f"));
let val = hex!("76657262");
let extension = ExtensionNode::new(nibble, val.to_vec());
let mut child = vec![];
val.to_vec().as_slice().encode(&mut child);
let extension = ExtensionNode::new(nibble, child);
let rlp = extension.as_ref().rlp(&mut vec![]);
assert_eq!(rlp, hex!("c88300646f76657262"));
assert_eq!(rlp, hex!("c98300646f8476657262"));
assert_eq!(ExtensionNode::decode(&mut &rlp[..]).unwrap(), extension);
}
}
11 changes: 8 additions & 3 deletions src/nodes/leaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,15 @@ impl Encodable for LeafNode {
impl Decodable for LeafNode {
fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
let mut bytes = Header::decode_bytes(buf, true)?;

let encoded_key = Bytes::decode(&mut bytes)?;
if encoded_key.is_empty() {
return Err(alloy_rlp::Error::Custom("leaf node key empty"));
}

// Retrieve first byte. If it's [Some], then the nibbles are odd.
let first = match encoded_key[0] & 0xf0 {
0x30 => Some(encoded_key[0] & 0x0f),
0x20 => None,
Self::ODD_FLAG => Some(encoded_key[0] & 0x0f),
Self::EVEN_FLAG => None,
_ => return Err(alloy_rlp::Error::Custom("node is not leaf")),
};

Expand All @@ -63,6 +62,12 @@ impl Decodable for LeafNode {
}

impl LeafNode {
/// The flag representing the even number of nibbles in the leaf key.
pub const EVEN_FLAG: u8 = 0x20;

/// The flag representing the odd number of nibbles in the leaf key.
pub const ODD_FLAG: u8 = 0x30;

/// Creates a new leaf node with the given key and value.
pub fn new(key: Nibbles, value: Vec<u8>) -> Self {
Self { key, value }
Expand Down
Loading

0 comments on commit 4fc6151

Please sign in to comment.