Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: compute storage root #1294

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions src/utils/mpt/nibbles.cairo
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add docstrings

Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from utils.utils import Helpers
from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.math import unsigned_div_rem

struct Nibbles {
nibbles_len: felt,
nibbles: felt*,
}

namespace NibblesImpl {
func from_bytes{range_check_ptr}(bytes_len: felt, bytes: felt*) -> Nibbles* {
alloc_locals;
local nibbles_len = bytes_len * 2;
let (local output: felt*) = alloc();

if (nibbles_len == 0) {
tempvar res = new Nibbles(nibbles_len, output);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in done label: new Nibbles(nibbles_len, nibbles);
Maybe standardize the naming of what Nibbles takes by renaming output to nibbles or changing the naming in done label?

return res;
}

tempvar range_check_ptr = range_check_ptr;
tempvar output = output;
tempvar value = bytes[0];

unpack_byte:
let range_check_ptr = [ap - 3];
let output = cast([ap - 2], felt*);
let value = [ap - 1];
let base = 0x10;
let bound = 0x10;
let (high, _) = unsigned_div_rem(value, base);
assert [output] = high;
let output = output + 1;
%{
memory[ids.output] = res = (int(ids.value) % PRIME) % ids.base
assert res < ids.bound, f'split_int(): Limb {res} is out of range.'
%}
let output = output + 1;
Comment on lines +29 to +38
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not just ?

Suggested change
let base = 0x10;
let bound = 0x10;
let (high, _) = unsigned_div_rem(value, base);
assert [output] = high;
let output = output + 1;
%{
memory[ids.output] = res = (int(ids.value) % PRIME) % ids.base
assert res < ids.bound, f'split_int(): Limb {res} is out of range.'
%}
let output = output + 1;
let (high, low) = unsigned_div_rem(value, 0x10);
assert [output] = high;
assert [output + 1] = low;
let output = output + 2;


let nibbles_len = [fp];
let output_start = cast([fp + 1], felt*);
let count = output - output_start;
let is_done = Helpers.is_zero(nibbles_len - count);
jmp done if is_done != 0;

let next_byte_index = count / 2;
let bytes = cast([fp - 3], felt*);
tempvar value = bytes[next_byte_index];
let range_check_ptr = range_check_ptr + 1;
[ap] = range_check_ptr, ap++;
[ap] = output, ap++;
[ap] = value, ap++;

Comment on lines +40 to +53
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this can be simplified to avoid the is_zero calls especially

jmp unpack_byte;

done:
let nibbles_len = [fp];
let nibbles = cast([fp + 1], felt*);

tempvar res = new Nibbles(nibbles_len, nibbles);
return res;
}

func pack_nibbles{range_check_ptr}(self: Nibbles*, bytes: felt*) -> felt {
alloc_locals;
let (local bytes_len, r) = unsigned_div_rem(self.nibbles_len, 2);
with_attr error_message("nibbles_len must be even") {
assert r = 0;
}
local range_check_ptr = range_check_ptr;

if (self.nibbles_len == 0) {
return 0;
}

tempvar count = 0;

body:
tempvar count = [ap - 1];
let self = cast([fp - 4], Nibbles*);
let bytes = cast([fp - 3], felt*);
let nib_index = 2 * count;
let nib_high = self.nibbles[nib_index];
let nib_low = self.nibbles[nib_index + 1];

let res = nib_high * 0x10 + nib_low;
assert bytes[count] = res;

let count = count + 1;
let is_done = Helpers.is_zero(self.nibbles_len - (nib_index + 2));

tempvar count = count;
jmp done if is_done != 0;
jmp body;

done:
let bytes = cast([fp - 3], felt*);
let bytes_len = [fp];
let range_check_ptr = [fp + 1];
return bytes_len;
}

func pack_with_prefix{range_check_ptr}(self: Nibbles*, is_leaf: felt) -> (
bytes_len: felt, bytes: felt*
) {
alloc_locals;
let (encoded) = alloc();
let (_, is_odd) = unsigned_div_rem(self.nibbles_len, 2);

// Case odd number of nibbles
if (is_odd != 0) {
let prefix = ((2 * is_leaf) + 1) * 16 + self.nibbles[0];
assert encoded[0] = prefix;
tempvar to_pack = new Nibbles(self.nibbles_len - 1, self.nibbles + 1);
let bytes_len = NibblesImpl.pack_nibbles(to_pack, encoded + 1);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
let bytes_len = NibblesImpl.pack_nibbles(to_pack, encoded + 1);
let bytes_len = pack_nibbles(to_pack, encoded + 1);

let total_len = bytes_len + 1;
return (total_len, encoded);
}

// Case even number of nibbles
let prefix = 2 * is_leaf * 16;
assert encoded[0] = prefix;
let bytes_len = NibblesImpl.pack_nibbles(self, encoded + 1);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
let bytes_len = NibblesImpl.pack_nibbles(self, encoded + 1);
let bytes_len = pack_nibbles(self, encoded + 1);

let total_len = bytes_len + 1;
return (total_len, encoded);
}
}
176 changes: 176 additions & 0 deletions src/utils/mpt/nodes.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
from starkware.cairo.common.memcpy import memcpy
from starkware.cairo.common.alloc import alloc

from utils.mpt.nibbles import Nibbles, NibblesImpl

namespace Node {
const LEAF = 0;
const EXTENSION = 1;
const BRANCH = 2;
}

struct Bytes {
data_len: felt,
data: felt*,
}

struct LeafNode {
type: felt,
key: Nibbles*,
value: Bytes,
}

namespace LeafNodeImpl {
const EVEN_FLAG = 0x20;
const ODD_FLAG = 0x30;

func init(key: Nibbles*, value: Bytes) -> LeafNode* {
tempvar res = new LeafNode(Node.LEAF, key, value);
return res;
}

// TODO: keccak(rlp_encode(key)) if len(rlp_encoded) > 32
func encode{range_check_ptr}(self: LeafNode*) -> Bytes* {
alloc_locals;
let (path_key_len, local unencoded) = NibblesImpl.pack_with_prefix(self.key, 1);
memcpy(dst=unencoded + path_key_len, src=self.value.data, len=self.value.data_len);

let (output) = alloc();
%{
import rlp
from ethereum.crypto.hash import keccak256
unencoded = serde.serialize_list(ids.unencoded, list_len=ids.path_key_len +ids.self.value.data_len)
encoded = rlp.encode(unencoded)

if len(encoded) < 32:
segments.write_arg(ids.output, unencoded)
else:
segments.write_arg(ids.output, keccak256(encoded))
%}

tempvar leaf_encoding = new Bytes(data_len=path_key_len + self.value.data_len, data=output);

// TODO: rlp encoding of value if required
return leaf_encoding;
}
}

struct ExtensionNode {
type: felt,
key: Nibbles*,
child: Bytes*,
}

namespace ExtensionNodeImpl {
const EVEN_FLAG = 0x00;
const ODD_FLAG = 0x10;

func init(key: Nibbles*, child: Bytes*) -> ExtensionNode* {
tempvar res = new ExtensionNode(type=Node.EXTENSION, key=key, child=child);
return res;
}

// TODO: keccak(rlp_encode(key)) if len(rlp_encoded) > 32
func encode{range_check_ptr}(self: ExtensionNode*) -> Bytes* {
alloc_locals;
let (local path_key_len, local unencoded) = NibblesImpl.pack_with_prefix(self.key, 0);

assert unencoded[path_key_len] = self.child.data_len;
assert unencoded[path_key_len + 1] = cast(self.child.data, felt);

let (output) = alloc();
tempvar output_len;
%{
import rlp
from ethereum.crypto.hash import keccak256
key_bytes = bytes(serde.serialize_list(ids.unencoded, list_len=ids.path_key_len))
child_len = ids.self.child.data_len
child_data = ids.self.child.data
if child_len == 17: # branch node
child_list = serde.serialize_list(child_data, item_scope="Bytes", list_len=child_len)
child_bytes = [bytes(elem) for elem in child_list]
else:
child_bytes = bytes(serde.serialize_list(child_data, list_len=child_len))
unencoded = [key_bytes, child_bytes]
encoded = rlp.encode(unencoded)

if len(encoded) < 32:
segments.write_arg(ids.output, unencoded)
ids.output_len = len(unencoded)
else:
hashed_rlp = keccak256(encoded)
segments.write_arg(ids.output, hashed_rlp)
ids.output_len = len(hashed_rlp)
%}

tempvar extension_encoding = new Bytes(data_len=output_len, data=output);

return extension_encoding;
}
}

struct BranchNode {
type: felt,
children_len: felt,
children: Bytes*,
value: Bytes,
}

namespace BranchNodeImpl {
func init(children: Bytes*, value: Bytes) -> BranchNode* {
tempvar res = new BranchNode(
type=Node.BRANCH, children_len=16, children=children, value=value
);
return res;
}

func encode{range_check_ptr}(self: BranchNode*) -> Bytes* {
alloc_locals;
let (unencoded) = alloc();
let subnodes_len = 16;

memcpy(dst=unencoded, src=self.children, len=self.children_len * Bytes.SIZE);
assert unencoded[subnodes_len * Bytes.SIZE] = self.value.data_len;
assert unencoded[subnodes_len * Bytes.SIZE + 1] = cast(self.value.data, felt);

let (branch_encoding: Bytes*) = alloc();
%{
import rlp
from ethereum.crypto.hash import keccak256
unencoded = [bytes(elem) for elem in serde.serialize_list(ids.unencoded, item_scope="Bytes", list_len=17*2)]
encoded = rlp.encode(unencoded)

if len(encoded) < 32:
bytes_data_len, bytes_data = serde.deserialize_bytes_list(unencoded)
ids.branch_encoding.data_len = bytes_data_len
ids.branch_encoding.data = bytes_data
else:
hashed_rlp = keccak256(encoded)
bytes_data_len, bytes_data = serde.deserialize_bytes(hashed_rlp)
ids.branch_encoding.data_len = bytes_data_len
ids.branch_encoding.data = bytes_data
%}

return branch_encoding;
}

func _encode_child(self: BranchNode*, index: felt, unencoded: felt*) {
if (index == 16) {
return ();
}

let child = self.children[index];
let child_data_len = child.data_len;
let child_data = child.data;

if (child_data_len == 0) {
let subnode = alloc();
assert unencoded[index] = subnode;
return _encode_child(self, index + 1, unencoded);
}

assert unencoded[index] = child_data;

return _encode_child(self, index + 1, unencoded);
}
}
Loading
Loading