Skip to content

Commit

Permalink
feat: branch node decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
rkrasiuk committed May 9, 2024
1 parent a3288b7 commit 7939aa3
Show file tree
Hide file tree
Showing 7 changed files with 223 additions and 72 deletions.
18 changes: 9 additions & 9 deletions src/hash_builder/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
//! The implementation of the hash builder.

use super::{
nodes::{word_rlp, BranchNode, ExtensionNodeRef, LeafNodeRef},
nodes::{word_rlp, BranchNodeRef, ExtensionNodeRef, LeafNodeRef},
BranchNodeCompact, Nibbles, TrieMask, EMPTY_ROOT_HASH,
};
use crate::HashMap;
use alloy_primitives::{keccak256, Bytes, B256};
use alloy_primitives::{hex, keccak256, Bytes, B256};
use core::cmp;
use tracing::trace;

Expand Down Expand Up @@ -106,7 +106,7 @@ impl HashBuilder {
pub fn print_stack(&self) {
println!("============ STACK ===============");
for item in &self.stack {
println!("{}", alloy_primitives::hex::encode(item));
println!("{}", hex::encode(item));
}
println!("============ END STACK ===============");
}
Expand Down Expand Up @@ -230,7 +230,7 @@ impl HashBuilder {
trace!(target: "trie::hash_builder", ?leaf_node, "pushing leaf node");
trace!(target: "trie::hash_builder", rlp = {
self.rlp_buf.clear();
alloy_primitives::hex::encode(&leaf_node.rlp(&mut self.rlp_buf))
hex::encode(&leaf_node.rlp(&mut self.rlp_buf))
}, "leaf node rlp");

self.rlp_buf.clear();
Expand Down Expand Up @@ -261,7 +261,7 @@ impl HashBuilder {
trace!(target: "trie::hash_builder", ?extension_node, "pushing extension node");
trace!(target: "trie::hash_builder", rlp = {
self.rlp_buf.clear();
alloy_primitives::hex::encode(&extension_node.rlp(&mut self.rlp_buf))
hex::encode(&extension_node.rlp(&mut self.rlp_buf))
}, "extension node rlp");
self.rlp_buf.clear();
self.stack.push(extension_node.rlp(&mut self.rlp_buf));
Expand Down Expand Up @@ -311,11 +311,11 @@ impl HashBuilder {
fn push_branch_node(&mut self, current: &Nibbles, len: usize) -> Vec<B256> {
let state_mask = self.groups[len];
let hash_mask = self.hash_masks[len];
let branch_node = BranchNode::new(&self.stack);
let children = branch_node.children(state_mask, hash_mask);
let branch_node = BranchNodeRef::new(&self.stack, &state_mask);
let children = branch_node.children(hash_mask);

self.rlp_buf.clear();
let rlp = branch_node.rlp(state_mask, &mut self.rlp_buf);
let rlp = branch_node.rlp(&mut self.rlp_buf);
self.retain_proof_from_buf(&current.slice(..len));

// Clears the stack from the branch node elements
Expand All @@ -329,7 +329,7 @@ impl HashBuilder {
self.stack.resize(first_child_idx, vec![]);

trace!(target: "trie::hash_builder", "pushing branch node with {:?} mask from stack", state_mask);
trace!(target: "trie::hash_builder", rlp = alloy_primitives::hex::encode(&rlp), "branch node rlp");
trace!(target: "trie::hash_builder", rlp = hex::encode(&rlp), "branch node rlp");
self.stack.push(rlp);
children
}
Expand Down
4 changes: 2 additions & 2 deletions src/hash_builder/value.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use alloy_primitives::B256;
use alloy_primitives::{hex, B256};
use core::fmt;

#[allow(unused_imports)]
Expand All @@ -18,7 +18,7 @@ pub enum HashBuilderValue {
impl fmt::Debug for HashBuilderValue {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Bytes(bytes) => write!(f, "Bytes({:?})", alloy_primitives::hex::encode(bytes)),
Self::Bytes(bytes) => write!(f, "Bytes({:?})", hex::encode(bytes)),
Self::Hash(hash) => write!(f, "Hash({:?})", hash),
}
}
Expand Down
5 changes: 5 additions & 0 deletions src/mask.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,9 @@ impl TrieMask {
pub const fn is_empty(self) -> bool {
self.0 == 0
}

/// Sets the bit at the given `index` of the mask, leaving every other bit untouched.
///
/// `index` is expected to be below 16: the mask is shifted as a `u16`, so a larger
/// index overflows the shift.
pub fn set_bit(&mut self, index: u8) {
    self.0 |= 1u16 << index;
}
}
244 changes: 192 additions & 52 deletions src/nodes/branch.rs
Original file line number Diff line number Diff line change
@@ -1,76 +1,190 @@
use super::{super::TrieMask, rlp_node, CHILD_INDEX_RANGE};
use alloy_primitives::B256;
use alloy_rlp::{BufMut, EMPTY_STRING_CODE};
use alloy_primitives::{hex, B256};
use alloy_rlp::{length_of_length, Buf, BufMut, Decodable, Encodable, Header, EMPTY_STRING_CODE};
use core::fmt;

#[allow(unused_imports)]
use alloc::{collections::BTreeMap, vec::Vec};

/// A Branch node is only a pointer to the stack of nodes and is used to
/// create the RLP encoding of the node using masks which filter from
/// the stack of nodes.
#[derive(Clone, Debug)]
pub struct BranchNode<'a> {
/// Rlp encoded children
use alloc::vec::Vec;

/// A branch node in a Merkle Patricia Trie is a 17-element array consisting of 16 slots that
/// correspond to each hexadecimal character and an additional slot for a value. The node value
/// is excluded since all paths have a fixed size.
#[derive(PartialEq, Eq, Default)]
pub struct BranchNode {
/// The collection of RLP encoded children. Encoding and decoding treat it as holding one
/// entry per bit set in `state_mask`, ordered by child index.
pub stack: Vec<Vec<u8>>,
/// The bitmask indicating the presence of children at the respective nibble positions.
pub state_mask: TrieMask,
}

impl fmt::Debug for BranchNode {
    /// Writes the node as a `BranchNode` struct whose children are rendered as hex strings,
    /// followed by the state mask and the stack index of the first child.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let hex_children: Vec<_> = self.stack.iter().map(hex::encode).collect();

        let mut debug = f.debug_struct("BranchNode");
        debug.field("stack", &hex_children);
        debug.field("state_mask", &self.state_mask);
        debug.field("first_child_index", &self.as_ref().first_child_index());
        debug.finish()
    }
}

/// RLP encoding of an owned branch node; all work is delegated to the borrowed
/// [BranchNodeRef] view of the same stack and state mask.
impl Encodable for BranchNode {
    fn encode(&self, out: &mut dyn BufMut) {
        let node_ref = self.as_ref();
        node_ref.encode(out);
    }

    fn length(&self) -> usize {
        let node_ref = self.as_ref();
        node_ref.length()
    }
}

impl Decodable for BranchNode {
    /// Decodes a branch node from its RLP list encoding.
    ///
    /// The list payload is read one child slot at a time, for every index in
    /// `CHILD_INDEX_RANGE`. A slot holding `EMPTY_STRING_CODE` is an absent child; any other
    /// slot is consumed as a 32-byte child, pushed onto the stack, and its bit is set in the
    /// state mask. After the child slots, one more byte is consumed for the node value slot.
    ///
    /// # Errors
    ///
    /// Propagates any error from decoding the list header, and returns
    /// [`alloy_rlp::Error::InputTooShort`] if the payload ends before every child slot and
    /// the trailing value slot have been read.
    fn decode(buf: &mut &[u8]) -> alloy_rlp::Result<Self> {
        let mut bytes = Header::decode_bytes(buf, true)?;

        let mut stack = Vec::new();
        let mut state_mask = TrieMask::default();
        for index in CHILD_INDEX_RANGE {
            if bytes.is_empty() {
                return Err(alloy_rlp::Error::InputTooShort);
            }

            if bytes[0] == EMPTY_STRING_CODE {
                bytes.advance(1);
                continue;
            }

            // NOTE(review): children are consumed as fixed 32-byte chunks. Elsewhere in this
            // file `children` strips a leading byte from each stack entry before building a
            // `B256`, which suggests entries may be longer than 32 bytes on other paths —
            // confirm that this decoder matches the encoding produced in practice.
            if bytes.len() < 32 {
                return Err(alloy_rlp::Error::InputTooShort);
            }

            state_mask.set_bit(index);
            stack.push(Vec::from(&bytes[..32]));
            bytes.advance(32);
        }

        // Consume empty string code for branch node value. The slot must actually be present:
        // advancing past the end of the slice would panic on truncated input.
        if bytes.is_empty() {
            return Err(alloy_rlp::Error::InputTooShort);
        }
        bytes.advance(1);

        Ok(Self { stack, state_mask })
    }
}

impl BranchNode {
/// Creates a new branch node with the given stack and state mask.
pub fn new(stack: Vec<Vec<u8>>, state_mask: TrieMask) -> Self {
Self { stack, state_mask }
}
/// Return branch node as [BranchNodeRef].
pub fn as_ref(&self) -> BranchNodeRef<'_> {
BranchNodeRef::new(&self.stack, &self.state_mask)
}
}

/// A reference to [BranchNode] and its state mask.
/// NOTE: The stack may contain more items than specified in the state mask.
#[derive(Clone)]
pub struct BranchNodeRef<'a> {
/// Reference to the collection of RLP encoded nodes.
/// NOTE: The referenced stack might have more items than the number of children
/// for this node. We should only ever access items starting from
/// [BranchNodeRef::first_child_index].
pub stack: &'a [Vec<u8>],
/// Reference to the bitmask indicating the presence of children at
/// the respective nibble positions.
pub state_mask: &'a TrieMask,
}

impl fmt::Debug for BranchNodeRef<'_> {
    /// Writes the reference as a `BranchNodeRef` struct: every stack entry rendered as a hex
    /// string, then the state mask and the stack index of the first child.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let hex_stack: Vec<_> = self.stack.iter().map(hex::encode).collect();

        let mut debug = f.debug_struct("BranchNodeRef");
        debug.field("stack", &hex_stack);
        debug.field("state_mask", &self.state_mask);
        debug.field("first_child_index", &self.first_child_index());
        debug.finish()
    }
}

/// RLP encoding of a branch node in the Ethereum Merkle Patricia Trie.
/// The node is written as a 17-element list: one slot per child (0-f) followed by
/// a slot for the value, which is always written as the empty string.
impl Encodable for BranchNodeRef<'_> {
    fn encode(&self, out: &mut dyn BufMut) {
        let header = Header { list: true, payload_length: self.rlp_payload_length() };
        header.encode(out);

        // Walk the child slots in order. Present children are copied from the stack,
        // starting at the first child of this node; absent ones become empty strings.
        let mut next_child = self.first_child_index();
        for nibble in CHILD_INDEX_RANGE {
            if !self.state_mask.is_bit_set(nibble) {
                out.put_u8(EMPTY_STRING_CODE);
                continue;
            }

            out.put_slice(&self.stack[next_child]);
            next_child += 1;
        }

        // The value slot.
        out.put_u8(EMPTY_STRING_CODE);
    }

    fn length(&self) -> usize {
        let payload = self.rlp_payload_length();
        length_of_length(payload) + payload
    }
}

impl<'a> BranchNode<'a> {
impl<'a> BranchNodeRef<'a> {
/// Create a new branch node from the stack of nodes.
pub const fn new(stack: &'a [Vec<u8>]) -> Self {
Self { stack }
pub fn new(stack: &'a [Vec<u8>], state_mask: &'a TrieMask) -> Self {
Self { stack, state_mask }
}

/// Returns the position in the stack at which this node's children begin, i.e. the stack
/// length minus the number of bits set in the state mask.
///
/// # Panics
///
/// If the stack holds fewer items than the number of children specified in the state mask,
/// which means the node is in an inconsistent state.
pub fn first_child_index(&self) -> usize {
    let child_count = self.state_mask.count_ones() as usize;
    self.stack.len().checked_sub(child_count).unwrap()
}

/// Given the hash mask of children present, returns the hashes of the children whose bit is
/// set in both the state mask and the hash mask, in `CHILD_INDEX_RANGE` order.
pub fn children(&self, state_mask: TrieMask, hash_mask: TrieMask) -> Vec<B256> {
let mut index = self.stack.len() - state_mask.count_ones() as usize;
let mut children = Vec::with_capacity(CHILD_INDEX_RANGE.len());
for digit in CHILD_INDEX_RANGE {
if state_mask.is_bit_set(digit) {
if hash_mask.is_bit_set(digit) {
children.push(B256::from_slice(&self.stack[index][1..]));
pub fn children(&self, hash_mask: TrieMask) -> Vec<B256> {
let mut stack_ptr = self.first_child_index();
let mut children = Vec::with_capacity(hash_mask.count_ones() as usize);
for index in CHILD_INDEX_RANGE {
if self.state_mask.is_bit_set(index) {
if hash_mask.is_bit_set(index) {
children.push(B256::from_slice(&self.stack[stack_ptr][1..]));
}
index += 1;
stack_ptr += 1;
}
}
children
}

/// Returns the RLP encoding of the branch node given the state mask of children present.
pub fn rlp(&self, state_mask: TrieMask, buf: &mut Vec<u8>) -> Vec<u8> {
let first_child_idx = self.stack.len() - state_mask.count_ones() as usize;

// Create the RLP header from the mask elements present.
let mut i = first_child_idx;
let header = CHILD_INDEX_RANGE.fold(
alloy_rlp::Header { list: true, payload_length: 1 },
|mut header, digit| {
if state_mask.is_bit_set(digit) {
header.payload_length += self.stack[i].len();
i += 1;
} else {
header.payload_length += 1;
}
header
},
);
header.encode(buf);
pub fn rlp(&self, out: &mut Vec<u8>) -> Vec<u8> {
self.encode(out);
rlp_node(out)
}

// Extend the RLP buffer with the present children
let mut i = first_child_idx;
CHILD_INDEX_RANGE.for_each(|idx| {
if state_mask.is_bit_set(idx) {
buf.extend_from_slice(&self.stack[i]);
i += 1;
/// Returns the length of RLP encoded fields of branch node.
fn rlp_payload_length(&self) -> usize {
let mut payload_length = 1;

let mut stack_ptr = self.first_child_index();
for digit in CHILD_INDEX_RANGE {
if self.state_mask.is_bit_set(digit) {
payload_length += self.stack[stack_ptr].len();
// Advance the pointer to the next child.
stack_ptr += 1;
} else {
buf.put_u8(EMPTY_STRING_CODE)
payload_length += 1;
}
});

// Is this needed?
buf.put_u8(EMPTY_STRING_CODE);

rlp_node(buf)
}
payload_length
}
}

Expand Down Expand Up @@ -133,3 +247,29 @@ impl BranchNodeCompact {
self.hashes[index as usize]
}
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `node`, decodes the produced bytes and asserts that the result equals `node`.
    fn assert_roundtrip(node: &BranchNode) {
        let mut encoded = vec![];
        node.encode(&mut encoded);
        let decoded = BranchNode::decode(&mut encoded.as_slice()).unwrap();
        assert_eq!(decoded, *node);
    }

    #[test]
    fn rlp_branch_node_roundtrip() {
        // No children at all.
        assert_roundtrip(&BranchNode::default());

        // Two children, at nibbles 2 and 6.
        let sparse_children = vec![vec![1; 32], vec![2; 32]];
        assert_roundtrip(&BranchNode::new(sparse_children, TrieMask::new(0b1000100)));

        // Every child slot occupied.
        let full_children = vec![vec![0x23; 32]; 16];
        assert_roundtrip(&BranchNode::new(full_children, TrieMask::new(u16::MAX)));
    }
}
Loading

0 comments on commit 7939aa3

Please sign in to comment.