Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Search in linked lists using a BTree #603

Merged
merged 10 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ insert_new_account:
/// Returns 0 if the account was not found or `original_ptr` if it was already present.
global search_account:
// stack: addr_key, retdest
PROVER_INPUT(linked_list::insert_account)
PROVER_INPUT(linked_list::search_account)
// stack: pred_ptr/4, addr_key, retdest
%get_valid_account_ptr
// stack: pred_ptr, addr_key, retdest
Expand Down Expand Up @@ -685,7 +685,7 @@ next_node_ok:
/// Returns `value` if the storage key was inserted, `old_value` if it was already present.
global search_slot:
// stack: addr_key, key, value, retdest
PROVER_INPUT(linked_list::insert_slot)
PROVER_INPUT(linked_list::search_slot)
// stack: pred_ptr/5, addr_key, key, value, retdest
%get_valid_slot_ptr

Expand Down
16 changes: 13 additions & 3 deletions evm_arithmetization/src/cpu/kernel/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
//! the future execution and generate nondeterministically the corresponding
//! jumpdest table, before the actual CPU carries on with contract execution.

use std::collections::{BTreeSet, HashMap};
use std::collections::{BTreeMap, BTreeSet, HashMap};

use anyhow::anyhow;
use ethereum_types::{BigEndianHash, U256};
Expand Down Expand Up @@ -115,6 +115,8 @@ pub(crate) struct ExtraSegmentData {
pub(crate) ger_prover_inputs: Vec<U256>,
pub(crate) trie_root_ptrs: TrieRootPtrs,
pub(crate) jumpdest_table: Option<HashMap<usize, Vec<usize>>>,
pub(crate) accounts: BTreeMap<U256, usize>,
pub(crate) storage: BTreeMap<(U256, U256), usize>,
pub(crate) next_txn_index: usize,
}

Expand Down Expand Up @@ -232,8 +234,12 @@ impl<F: RichField> Interpreter<F> {

// Initialize the MPT's pointers.
let (trie_root_ptrs, state_leaves, storage_leaves, trie_data) =
load_linked_lists_and_txn_and_receipt_mpts(&inputs.tries)
.expect("Invalid MPT data for preinitialization");
load_linked_lists_and_txn_and_receipt_mpts(
&mut self.generation_state.accounts_pointers,
&mut self.generation_state.storage_pointers,
&inputs.tries,
)
.expect("Invalid MPT data for preinitialization");

let trie_roots_after = &inputs.trie_roots_after;
self.generation_state.trie_root_ptrs = trie_root_ptrs;
Expand All @@ -253,6 +259,10 @@ impl<F: RichField> Interpreter<F> {
);
self.insert_preinitialized_segment(Segment::StorageLinkedList, preinit_storage_ll_segment);

// Initialize the accounts and storage BTrees.
self.generation_state.insert_all_slots_in_memory();
self.generation_state.insert_all_accounts_in_memory();

// Update the RLP and withdrawal prover inputs.
let rlp_prover_inputs = all_rlp_prover_inputs_reversed(&inputs.signed_txns);
let withdrawal_prover_inputs = all_withdrawals_prover_inputs_reversed(&inputs.withdrawals);
Expand Down
11 changes: 9 additions & 2 deletions evm_arithmetization/src/cpu/kernel/tests/account_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,12 @@ pub(crate) fn initialize_mpts<F: RichField>(
) {
// Load all MPTs.
let (mut trie_root_ptrs, state_leaves, storage_leaves, trie_data) =
load_linked_lists_and_txn_and_receipt_mpts(trie_inputs)
.expect("Invalid MPT data for preinitialization");
load_linked_lists_and_txn_and_receipt_mpts(
&mut interpreter.generation_state.accounts_pointers,
&mut interpreter.generation_state.storage_pointers,
trie_inputs,
)
.expect("Invalid MPT data for preinitialization");

interpreter.generation_state.memory.contexts[0].segments
[Segment::AccountsLinkedList.unscale()]
Expand All @@ -44,6 +48,9 @@ pub(crate) fn initialize_mpts<F: RichField>(
trie_data.clone();
interpreter.generation_state.trie_root_ptrs = trie_root_ptrs.clone();

interpreter.generation_state.insert_all_slots_in_memory();
interpreter.generation_state.insert_all_accounts_in_memory();

if trie_root_ptrs.state_root_ptr.is_none() {
trie_root_ptrs.state_root_ptr = Some(
load_state_mpt(
Expand Down
76 changes: 56 additions & 20 deletions evm_arithmetization/src/cpu/kernel/tests/mpt/linked_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ use rand::{thread_rng, Rng};
use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::constants::global_metadata::GlobalMetadata;
use crate::cpu::kernel::interpreter::Interpreter;
use crate::generation::linked_list::LinkedList;
use crate::generation::linked_list::AccountsLinkedList;
use crate::generation::linked_list::StorageLinkedList;
use crate::memory::segments::Segment;
use crate::witness::memory::MemoryAddress;
use crate::witness::memory::MemorySegmentState;

fn init_logger() {
let _ = try_init_from_env(Env::default().filter_or(DEFAULT_FILTER_ENV, "debug"));
Expand Down Expand Up @@ -80,7 +82,8 @@ fn test_list_iterator() -> Result<()> {
.memory
.get_preinit_memory(Segment::AccountsLinkedList);
let mut accounts_list =
LinkedList::from_mem_and_segment(&accounts_mem, Segment::AccountsLinkedList).unwrap();
AccountsLinkedList::from_mem_and_segment(&accounts_mem, Segment::AccountsLinkedList)
.unwrap();

let Some([addr, ptr, ptr_cpy, scaled_pos_1]) = accounts_list.next() else {
return Err(anyhow::Error::msg("Couldn't get value"));
Expand All @@ -102,7 +105,7 @@ fn test_list_iterator() -> Result<()> {
.memory
.get_preinit_memory(Segment::StorageLinkedList);
let mut storage_list =
LinkedList::from_mem_and_segment(&accounts_mem, Segment::StorageLinkedList).unwrap();
StorageLinkedList::from_mem_and_segment(&accounts_mem, Segment::StorageLinkedList).unwrap();
let Some([addr, key, ptr, ptr_cpy, scaled_pos_1]) = storage_list.next() else {
return Err(anyhow::Error::msg("Couldn't get value"));
};
Expand Down Expand Up @@ -171,8 +174,16 @@ fn test_insert_account() -> Result<()> {
.memory
.get_preinit_memory(Segment::AccountsLinkedList);
let mut list =
LinkedList::from_mem_and_segment(&accounts_mem, Segment::AccountsLinkedList).unwrap();
AccountsLinkedList::from_mem_and_segment(&accounts_mem, Segment::AccountsLinkedList)
.unwrap();

let Some([addr, ptr, ptr_cpy, _]) = list.next() else {
return Err(anyhow::Error::msg("Couldn't get value"));
};
// This is the dummy node
assert_eq!(addr, U256::MAX);
assert_eq!(ptr, U256::zero());
assert_eq!(ptr_cpy, U256::zero());
let Some([addr, ptr, ptr_cpy, scaled_next_pos]) = list.next() else {
return Err(anyhow::Error::msg("Couldn't get value"));
};
Expand Down Expand Up @@ -251,7 +262,16 @@ fn test_insert_storage() -> Result<()> {
.memory
.get_preinit_memory(Segment::StorageLinkedList);
let mut list =
LinkedList::from_mem_and_segment(&accounts_mem, Segment::StorageLinkedList).unwrap();
StorageLinkedList::from_mem_and_segment(&accounts_mem, Segment::StorageLinkedList).unwrap();

let Some([inserted_addr, inserted_key, ptr, ptr_cpy, _]) = list.next() else {
return Err(anyhow::Error::msg("Couldn't get value"));
};
// This is the dummy node.
assert_eq!(inserted_addr, U256::MAX);
assert_eq!(inserted_key, U256::zero());
assert_eq!(ptr, U256::zero());
assert_eq!(ptr_cpy, U256::zero());

let Some([inserted_addr, inserted_key, ptr, ptr_cpy, scaled_next_pos]) = list.next() else {
return Err(anyhow::Error::msg("Couldn't get value"));
Expand Down Expand Up @@ -292,9 +312,17 @@ fn test_insert_and_delete_accounts() -> Result<()> {
Some((Segment::AccountsLinkedList as usize).into()),
];
let init_len = init_accounts_ll.len();
interpreter.generation_state.memory.contexts[0].segments
[Segment::AccountsLinkedList.unscale()]
.content = init_accounts_ll;

interpreter
.generation_state
.memory
.insert_preinitialized_segment(
Segment::AccountsLinkedList,
MemorySegmentState {
content: init_accounts_ll,
},
);

interpreter.set_global_metadata_field(
GlobalMetadata::AccountsLinkedListNextAvailable,
(Segment::AccountsLinkedList as usize + init_len).into(),
Expand Down Expand Up @@ -433,19 +461,22 @@ fn test_insert_and_delete_accounts() -> Result<()> {
.generation_state
.memory
.get_preinit_memory(Segment::AccountsLinkedList);
let list =
LinkedList::from_mem_and_segment(&accounts_mem, Segment::AccountsLinkedList).unwrap();
let list = AccountsLinkedList::from_mem_and_segment(&accounts_mem, Segment::AccountsLinkedList)
.unwrap();

for (i, [addr, ptr, ptr_cpy, _]) in list.enumerate() {
if addr == U256::MAX {
assert_eq!(addr, U256::MAX);
assert_eq!(ptr, U256::zero());
assert_eq!(ptr_cpy, U256::zero());
break;
if i > 0 {
break;
}
} else {
let addr_in_list = U256::from(new_addresses[i - 1].0.as_slice());
assert_eq!(addr, addr_in_list);
assert_eq!(ptr, addr + delta_ptr);
}
let addr_in_list = U256::from(new_addresses[i].0.as_slice());
assert_eq!(addr, addr_in_list);
assert_eq!(ptr, addr + delta_ptr);
}

Ok(())
Expand Down Expand Up @@ -640,20 +671,25 @@ fn test_insert_and_delete_storage() -> Result<()> {
.generation_state
.memory
.get_preinit_memory(Segment::StorageLinkedList);
let list = LinkedList::from_mem_and_segment(&accounts_mem, Segment::StorageLinkedList).unwrap();
let list =
StorageLinkedList::from_mem_and_segment(&accounts_mem, Segment::StorageLinkedList).unwrap();

for (i, [addr, key, ptr, ptr_cpy, _]) in list.enumerate() {
if addr == U256::MAX {
assert_eq!(addr, U256::MAX);
assert_eq!(key, U256::zero());
assert_eq!(ptr, U256::zero());
assert_eq!(ptr_cpy, U256::zero());
break;
if i > 0 {
break;
}
} else {
let [addr_in_list, key_in_list] =
new_addresses[i - 1].map(|x| U256::from(x.0.as_slice()));
assert_eq!(addr, addr_in_list);
assert_eq!(key, key_in_list);
assert_eq!(ptr, addr + delta_ptr);
}
let [addr_in_list, key_in_list] = new_addresses[i].map(|x| U256::from(x.0.as_slice()));
assert_eq!(addr, addr_in_list);
assert_eq!(key, key_in_list);
assert_eq!(ptr, addr + delta_ptr);
}

Ok(())
Expand Down
75 changes: 64 additions & 11 deletions evm_arithmetization/src/generation/linked_list.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::fmt;
use std::marker::PhantomData;

use anyhow::Result;
use ethereum_types::U256;
Expand All @@ -8,15 +9,37 @@ use crate::util::u256_to_usize;
use crate::witness::errors::ProgramError;
use crate::witness::errors::ProverInputError::InvalidInput;

pub const ACCOUNTS_LINKED_LIST_NODE_SIZE: usize = 4;
pub const STORAGE_LINKED_LIST_NODE_SIZE: usize = 5;

pub(crate) trait LinkedListType {}
#[derive(Clone)]
/// A linked list that starts from the first node after the special node and
/// iterates forever.
pub(crate) struct Cyclic;
#[derive(Clone)]
/// A linked list that starts from the special node and iterates until the last
/// node.
pub(crate) struct Bounded;
impl LinkedListType for Cyclic {}
impl LinkedListType for Bounded {}

pub(crate) type AccountsLinkedList<'a> = LinkedList<'a, ACCOUNTS_LINKED_LIST_NODE_SIZE>;
pub(crate) type StorageLinkedList<'a> = LinkedList<'a, STORAGE_LINKED_LIST_NODE_SIZE>;

// A linked list implemented using a vector `access_list_mem`.
// In this representation, the values of nodes are stored in the range
// `access_list_mem[i..i + node_size - 1]`, and `access_list_mem[i + node_size -
// 1]` holds the address of the next node, where i = node_size * j.
#[derive(Clone)]
pub(crate) struct LinkedList<'a, const N: usize> {
pub(crate) struct LinkedList<'a, const N: usize, T = Cyclic>
where
T: LinkedListType,
{
mem: &'a [Option<U256>],
offset: usize,
pos: usize,
_marker: PhantomData<T>,
}

pub(crate) fn empty_list_mem<const N: usize>(segment: Segment) -> [Option<U256>; N] {
Expand All @@ -31,15 +54,15 @@ pub(crate) fn empty_list_mem<const N: usize>(segment: Segment) -> [Option<U256>;
})
}

impl<'a, const N: usize> LinkedList<'a, N> {
pub const fn from_mem_and_segment(
impl<'a, const N: usize, T: LinkedListType> LinkedList<'a, N, T> {
pub fn from_mem_and_segment(
Nashtare marked this conversation as resolved.
Show resolved Hide resolved
mem: &'a [Option<U256>],
segment: Segment,
) -> Result<Self, ProgramError> {
Self::from_mem_len_and_segment(mem, segment)
}

pub const fn from_mem_len_and_segment(
pub fn from_mem_len_and_segment(
mem: &'a [Option<U256>],
segment: Segment,
) -> Result<Self, ProgramError> {
Expand All @@ -50,6 +73,7 @@ impl<'a, const N: usize> LinkedList<'a, N> {
mem,
offset: segment as usize,
pos: 0,
_marker: PhantomData,
})
}
}
Expand All @@ -58,9 +82,8 @@ impl<'a, const N: usize> fmt::Debug for LinkedList<'a, N> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "Linked List {{")?;
let cloned_list = self.clone();
for node in cloned_list {
if node[0] == U256::MAX {
writeln!(f, "{:?}", node)?;
for (i, node) in cloned_list.enumerate() {
if i > 0 && node[0] == U256::MAX {
break;
}
writeln!(f, "{:?} ->", node)?;
Expand All @@ -69,17 +92,47 @@ impl<'a, const N: usize> fmt::Debug for LinkedList<'a, N> {
}
}

impl<'a, const N: usize> fmt::Debug for LinkedList<'a, N, Bounded> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "Linked List {{")?;
let cloned_list = self.clone();
for node in cloned_list {
writeln!(f, "{:?} ->", node)?;
}
write!(f, "}}")
}
}

impl<'a, const N: usize> Iterator for LinkedList<'a, N> {
type Item = [U256; N];

fn next(&mut self) -> Option<Self::Item> {
// The first node is always the special node, so we skip it in the first
// iteration.
let node = Some(std::array::from_fn(|i| {
self.mem[self.pos + i].unwrap_or_default()
}));
if let Ok(new_pos) = u256_to_usize(self.mem[self.pos + N - 1].unwrap_or_default()) {
self.pos = new_pos - self.offset;
Some(std::array::from_fn(|i| {
node
} else {
None
}
}
}

impl<'a, const N: usize> Iterator for LinkedList<'a, N, Bounded> {
type Item = [U256; N];

fn next(&mut self) -> Option<Self::Item> {
if self.mem[self.pos] != Some(U256::MAX) {
let node = Some(std::array::from_fn(|i| {
self.mem[self.pos + i].unwrap_or_default()
}))
}));
if let Ok(new_pos) = u256_to_usize(self.mem[self.pos + N - 1].unwrap_or_default()) {
self.pos = new_pos - self.offset;
node
} else {
None
}
} else {
None
}
Expand Down
Loading
Loading