From 770b508b6e815fe85a786773cf532ba7960162eb Mon Sep 17 00:00:00 2001 From: Robin Salen Date: Sat, 20 Jul 2024 10:57:15 -0400 Subject: [PATCH] feat: Implement Columns view for MemoryStark --- evm_arithmetization/src/memory/columns.rs | 108 +++++---- .../src/memory/memory_stark.rs | 223 ++++++++++-------- 2 files changed, 182 insertions(+), 149 deletions(-) diff --git a/evm_arithmetization/src/memory/columns.rs b/evm_arithmetization/src/memory/columns.rs index 28fc7943f..5af4125ff 100644 --- a/evm_arithmetization/src/memory/columns.rs +++ b/evm_arithmetization/src/memory/columns.rs @@ -1,50 +1,66 @@ //! Memory registers. -use crate::memory::VALUE_LIMBS; - -// Columns for memory operations, ordered by (addr, timestamp). -/// 1 if this is an actual memory operation, or 0 if it's a padding row. -pub(crate) const FILTER: usize = 0; -/// Each memory operation is associated to a unique timestamp. -/// For a given memory operation `op_i`, its timestamp is computed as `C * N + -/// i` where `C` is the CPU clock at that time, `N` is the number of general -/// memory channels, and `i` is the index of the memory channel at which the -/// memory operation is performed. -pub(crate) const TIMESTAMP: usize = FILTER + 1; -/// 1 if this is a read operation, 0 if it is a write one. -pub(crate) const IS_READ: usize = TIMESTAMP + 1; -/// The execution context of this address. -pub(crate) const ADDR_CONTEXT: usize = IS_READ + 1; -/// The segment section of this address. -pub(crate) const ADDR_SEGMENT: usize = ADDR_CONTEXT + 1; -/// The virtual address within the given context and segment. -pub(crate) const ADDR_VIRTUAL: usize = ADDR_SEGMENT + 1; - -// Eight 32-bit limbs hold a total of 256 bits. -// If a value represents an integer, it is little-endian encoded. -const VALUE_START: usize = ADDR_VIRTUAL + 1; -pub(crate) const fn value_limb(i: usize) -> usize { - debug_assert!(i < VALUE_LIMBS); - VALUE_START + i +use std::mem::transmute; + +use zk_evm_proc_macro::{Columns, DerefColumns}; + +use crate::{memory::VALUE_LIMBS, util::indices_arr}; + +/// Columns for the `MemoryStark`. +#[repr(C)] +#[derive(Columns, DerefColumns, Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) struct MemoryColumnsView { + // Columns for memory operations, ordered by (addr, timestamp). + /// 1 if this is an actual memory operation, or 0 if it's a padding row. + pub filter: T, + /// Each memory operation is associated to a unique timestamp. + /// For a given memory operation `op_i`, its timestamp is computed as `C * N + /// + i` where `C` is the CPU clock at that time, `N` is the number of + /// general memory channels, and `i` is the index of the memory channel + /// at which the memory operation is performed. + pub timestamp: T, + /// 1 if this is a read operation, 0 if it is a write one. + pub is_read: T, + /// The execution context of this address. + pub addr_context: T, + /// The segment section of this address. + pub addr_segment: T, + /// The virtual address within the given context and segment. + pub addr_virtual: T, + + // Eight 32-bit limbs hold a total of 256 bits. + // If a value represents an integer, it is little-endian encoded. + pub value_limbs: [T; VALUE_LIMBS], + + // Flags to indicate whether this part of the address differs from the next row, + // and the previous parts do not differ. + // That is, e.g., `SEGMENT_FIRST_CHANGE` is `F::ONE` iff `ADDR_CONTEXT` is the + // same in this row and the next, but `ADDR_SEGMENT` is not. + pub context_first_change: T, + pub segment_first_change: T, + pub virtual_first_change: T, + + // Used to lower the degree of the zero-initializing constraints. + // Contains `next_segment * addr_changed * next_is_read`. + pub initialize_aux: T, + + // We use a range check to enforce the ordering. + pub range_check: T, + /// The counter column (used for the range check) starts from 0 and + /// increments. + pub counter: T, + /// The frequencies column used in logUp. + pub frequencies: T, } -// Flags to indicate whether this part of the address differs from the next row, -// and the previous parts do not differ. -// That is, e.g., `SEGMENT_FIRST_CHANGE` is `F::ONE` iff `ADDR_CONTEXT` is the -// same in this row and the next, but `ADDR_SEGMENT` is not. -pub(crate) const CONTEXT_FIRST_CHANGE: usize = VALUE_START + VALUE_LIMBS; -pub(crate) const SEGMENT_FIRST_CHANGE: usize = CONTEXT_FIRST_CHANGE + 1; -pub(crate) const VIRTUAL_FIRST_CHANGE: usize = SEGMENT_FIRST_CHANGE + 1; - -// Used to lower the degree of the zero-initializing constraints. -// Contains `next_segment * addr_changed * next_is_read`. -pub(crate) const INITIALIZE_AUX: usize = VIRTUAL_FIRST_CHANGE + 1; - -// We use a range check to enforce the ordering. -pub(crate) const RANGE_CHECK: usize = INITIALIZE_AUX + 1; -/// The counter column (used for the range check) starts from 0 and increments. -pub(crate) const COUNTER: usize = RANGE_CHECK + 1; -/// The frequencies column used in logUp. -pub(crate) const FREQUENCIES: usize = COUNTER + 1; - -pub(crate) const NUM_COLUMNS: usize = FREQUENCIES + 1; +/// Total number of columns in `MemoryStark`. +/// `u8` is guaranteed to have a `size_of` of 1. +pub(crate) const NUM_COLUMNS: usize = core::mem::size_of::>(); + +/// Mapping between [0..NUM_COLUMNS-1] and the memory columns. +pub(crate) const MEMORY_COL_MAP: MemoryColumnsView = make_col_map(); + +const fn make_col_map() -> MemoryColumnsView { + let indices_arr = indices_arr::(); + unsafe { transmute::<[usize; NUM_COLUMNS], MemoryColumnsView>(indices_arr) } +} diff --git a/evm_arithmetization/src/memory/memory_stark.rs b/evm_arithmetization/src/memory/memory_stark.rs index f788024fd..ba4f1255f 100644 --- a/evm_arithmetization/src/memory/memory_stark.rs +++ b/evm_arithmetization/src/memory/memory_stark.rs @@ -1,4 +1,5 @@ use core::marker::PhantomData; +use std::borrow::Borrow; use ethereum_types::U256; use itertools::Itertools; @@ -17,13 +18,10 @@ use starky::evaluation_frame::StarkEvaluationFrame; use starky::lookup::{Column, Filter, Lookup}; use starky::stark::Stark; +use super::columns::{MemoryColumnsView, MEMORY_COL_MAP}; use super::segments::Segment; use crate::all_stark::EvmStarkFrame; -use crate::memory::columns::{ - value_limb, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, CONTEXT_FIRST_CHANGE, COUNTER, FILTER, - FREQUENCIES, INITIALIZE_AUX, IS_READ, NUM_COLUMNS, RANGE_CHECK, SEGMENT_FIRST_CHANGE, - TIMESTAMP, VIRTUAL_FIRST_CHANGE, -}; +use crate::memory::columns::NUM_COLUMNS; use crate::memory::VALUE_LIMBS; use crate::witness::memory::MemoryOpKind::Read; use crate::witness::memory::{MemoryAddress, MemoryOp}; @@ -34,16 +32,23 @@ use crate::witness::memory::{MemoryAddress, MemoryOp}; /// - the value being read/written, /// - the timestamp at which the element is read/written. pub(crate) fn ctl_data() -> Vec> { - let mut res = - Column::singles([IS_READ, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL]).collect_vec(); - res.extend(Column::singles((0..8).map(value_limb))); - res.push(Column::single(TIMESTAMP)); + let mut res = Column::singles([ + MEMORY_COL_MAP.is_read, + MEMORY_COL_MAP.addr_context, + MEMORY_COL_MAP.addr_segment, + MEMORY_COL_MAP.addr_virtual, + ]) + .collect_vec(); + res.extend(Column::singles( + (0..8).map(|i| MEMORY_COL_MAP.value_limbs[i]), + )); + res.push(Column::single(MEMORY_COL_MAP.timestamp)); res } /// CTL filter for memory operations. pub(crate) fn ctl_filter() -> Filter { - Filter::new_simple(Column::single(FILTER)) + Filter::new_simple(Column::single(MEMORY_COL_MAP.filter)) } #[derive(Copy, Clone, Default)] @@ -57,22 +62,23 @@ impl MemoryOp { /// `CONTEXT_FIRST_CHANGE`; those are generated later. It also does not /// generate columns such as `COUNTER`, which are generated later, after the /// trace has been transposed into column-major form. - fn into_row(self) -> [F; NUM_COLUMNS] { - let mut row = [F::ZERO; NUM_COLUMNS]; - row[FILTER] = F::from_bool(self.filter); - row[TIMESTAMP] = F::from_canonical_usize(self.timestamp); - row[IS_READ] = F::from_bool(self.kind == Read); + fn into_row(self) -> MemoryColumnsView { + let mut row = MemoryColumnsView::default(); + row.filter = F::from_bool(self.filter); + row.timestamp = F::from_canonical_usize(self.timestamp); + row.is_read = F::from_bool(self.kind == Read); let MemoryAddress { context, segment, virt, } = self.address; - row[ADDR_CONTEXT] = F::from_canonical_usize(context); - row[ADDR_SEGMENT] = F::from_canonical_usize(segment); - row[ADDR_VIRTUAL] = F::from_canonical_usize(virt); + row.addr_context = F::from_canonical_usize(context); + row.addr_segment = F::from_canonical_usize(segment); + row.addr_virtual = F::from_canonical_usize(virt); for j in 0..VALUE_LIMBS { - row[value_limb(j)] = F::from_canonical_u32((self.value >> (j * 32)).low_u32()); + row.value_limbs[j] = F::from_canonical_u32((self.value >> (j * 32)).low_u32()); } + row } } @@ -80,22 +86,22 @@ impl MemoryOp { /// Generates the `_FIRST_CHANGE` columns and the `RANGE_CHECK` column in the /// trace. pub(crate) fn generate_first_change_flags_and_rc( - trace_rows: &mut [[F; NUM_COLUMNS]], + trace_rows: &mut [MemoryColumnsView], ) { let num_ops = trace_rows.len(); for idx in 0..num_ops - 1 { - let row = trace_rows[idx].as_slice(); - let next_row = trace_rows[idx + 1].as_slice(); - - let context = row[ADDR_CONTEXT]; - let segment = row[ADDR_SEGMENT]; - let virt = row[ADDR_VIRTUAL]; - let timestamp = row[TIMESTAMP]; - let next_context = next_row[ADDR_CONTEXT]; - let next_segment = next_row[ADDR_SEGMENT]; - let next_virt = next_row[ADDR_VIRTUAL]; - let next_timestamp = next_row[TIMESTAMP]; - let next_is_read = next_row[IS_READ]; + let row = &trace_rows[idx]; + let next_row = &trace_rows[idx + 1]; + + let context = row.addr_context; + let segment = row.addr_segment; + let virt = row.addr_virtual; + let timestamp = row.timestamp; + let next_context = next_row.addr_context; + let next_segment = next_row.addr_segment; + let next_virt = next_row.addr_virtual; + let next_timestamp = next_row.timestamp; + let next_is_read = next_row.is_read; let context_changed = context != next_context; let segment_changed = segment != next_segment; @@ -106,12 +112,12 @@ pub(crate) fn generate_first_change_flags_and_rc( let virtual_first_change = virtual_changed && !segment_first_change && !context_first_change; - let row = trace_rows[idx].as_mut_slice(); - row[CONTEXT_FIRST_CHANGE] = F::from_bool(context_first_change); - row[SEGMENT_FIRST_CHANGE] = F::from_bool(segment_first_change); - row[VIRTUAL_FIRST_CHANGE] = F::from_bool(virtual_first_change); + let row = &mut trace_rows[idx]; + row.context_first_change = F::from_bool(context_first_change); + row.segment_first_change = F::from_bool(segment_first_change); + row.virtual_first_change = F::from_bool(virtual_first_change); - row[RANGE_CHECK] = if context_first_change { + row.range_check = if context_first_change { next_context - context - F::ONE } else if segment_first_change { next_segment - segment - F::ONE @@ -122,21 +128,21 @@ pub(crate) fn generate_first_change_flags_and_rc( }; assert!( - row[RANGE_CHECK].to_canonical_u64() < num_ops as u64, + row.range_check.to_canonical_u64() < num_ops as u64, "Range check of {} is too large. Bug in fill_gaps?", - row[RANGE_CHECK] + row.range_check ); let address_changed = - row[CONTEXT_FIRST_CHANGE] + row[SEGMENT_FIRST_CHANGE] + row[VIRTUAL_FIRST_CHANGE]; - row[INITIALIZE_AUX] = next_segment * address_changed * next_is_read; + row.context_first_change + row.segment_first_change + row.virtual_first_change; + row.initialize_aux = next_segment * address_changed * next_is_read; } } impl, const D: usize> MemoryStark { /// Generate most of the trace rows. Excludes a few columns like `COUNTER`, /// which are generated later, after transposing to column-major form. - fn generate_trace_row_major(&self, mut memory_ops: Vec) -> Vec<[F; NUM_COLUMNS]> { + fn generate_trace_row_major(&self, mut memory_ops: Vec) -> Vec> { // fill_gaps expects an ordered list of operations. memory_ops.sort_by_key(MemoryOp::sorting_key); Self::fill_gaps(&mut memory_ops); @@ -151,7 +157,7 @@ impl, const D: usize> MemoryStark { .into_par_iter() .map(|op| op.into_row()) .collect::>(); - generate_first_change_flags_and_rc(trace_rows.as_mut_slice()); + generate_first_change_flags_and_rc(&mut trace_rows); trace_rows } @@ -159,18 +165,20 @@ impl, const D: usize> MemoryStark { /// a trace in column-major form. fn generate_trace_col_major(trace_col_vecs: &mut [Vec]) { let height = trace_col_vecs[0].len(); - trace_col_vecs[COUNTER] = (0..height).map(|i| F::from_canonical_usize(i)).collect(); + trace_col_vecs[MEMORY_COL_MAP.counter] = + (0..height).map(|i| F::from_canonical_usize(i)).collect(); for i in 0..height { - let x_rc = trace_col_vecs[RANGE_CHECK][i].to_canonical_u64() as usize; - trace_col_vecs[FREQUENCIES][x_rc] += F::ONE; - if (trace_col_vecs[CONTEXT_FIRST_CHANGE][i] == F::ONE) - || (trace_col_vecs[SEGMENT_FIRST_CHANGE][i] == F::ONE) + let x_rc = trace_col_vecs[MEMORY_COL_MAP.range_check][i].to_canonical_u64() as usize; + trace_col_vecs[MEMORY_COL_MAP.frequencies][x_rc] += F::ONE; + if (trace_col_vecs[MEMORY_COL_MAP.context_first_change][i] == F::ONE) + || (trace_col_vecs[MEMORY_COL_MAP.segment_first_change][i] == F::ONE) { // CONTEXT_FIRST_CHANGE and SEGMENT_FIRST_CHANGE should be 0 at the last row, so // the index should never be out of bounds. - let x_fo = trace_col_vecs[ADDR_VIRTUAL][i + 1].to_canonical_u64() as usize; - trace_col_vecs[FREQUENCIES][x_fo] += F::ONE; + let x_fo = + trace_col_vecs[MEMORY_COL_MAP.addr_virtual][i + 1].to_canonical_u64() as usize; + trace_col_vecs[MEMORY_COL_MAP.frequencies][x_fo] += F::ONE; } } } @@ -292,24 +300,27 @@ impl, const D: usize> Stark for MemoryStark, { let one = P::from(FE::ONE); - let local_values = vars.get_local_values(); - let next_values = vars.get_next_values(); - - let timestamp = local_values[TIMESTAMP]; - let addr_context = local_values[ADDR_CONTEXT]; - let addr_segment = local_values[ADDR_SEGMENT]; - let addr_virtual = local_values[ADDR_VIRTUAL]; - let value_limbs: Vec<_> = (0..8).map(|i| local_values[value_limb(i)]).collect(); - - let next_timestamp = next_values[TIMESTAMP]; - let next_is_read = next_values[IS_READ]; - let next_addr_context = next_values[ADDR_CONTEXT]; - let next_addr_segment = next_values[ADDR_SEGMENT]; - let next_addr_virtual = next_values[ADDR_VIRTUAL]; - let next_values_limbs: Vec<_> = (0..8).map(|i| next_values[value_limb(i)]).collect(); + + let lv: &[P; NUM_COLUMNS] = vars.get_local_values().try_into().unwrap(); + let lv: &MemoryColumnsView

= lv.borrow(); + let nv: &[P; NUM_COLUMNS] = vars.get_next_values().try_into().unwrap(); + let nv: &MemoryColumnsView

= nv.borrow(); + + let timestamp = lv.timestamp; + let addr_context = lv.addr_context; + let addr_segment = lv.addr_segment; + let addr_virtual = lv.addr_virtual; + let value_limbs: Vec<_> = (0..8).map(|i| lv.value_limbs[i]).collect(); + + let next_timestamp = nv.timestamp; + let next_is_read = nv.is_read; + let next_addr_context = nv.addr_context; + let next_addr_segment = nv.addr_segment; + let next_addr_virtual = nv.addr_virtual; + let next_values_limbs: Vec<_> = (0..8).map(|i| nv.value_limbs[i]).collect(); // The filter must be 0 or 1. - let filter = local_values[FILTER]; + let filter = lv.filter; yield_constr.constraint(filter * (filter - P::ONES)); // IS_READ must be 0 or 1. @@ -320,16 +331,16 @@ impl, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark, ) { let one = builder.one_extension(); - let local_values = vars.get_local_values(); - let next_values = vars.get_next_values(); - - let addr_context = local_values[ADDR_CONTEXT]; - let addr_segment = local_values[ADDR_SEGMENT]; - let addr_virtual = local_values[ADDR_VIRTUAL]; - let value_limbs: Vec<_> = (0..8).map(|i| local_values[value_limb(i)]).collect(); - let timestamp = local_values[TIMESTAMP]; - - let next_addr_context = next_values[ADDR_CONTEXT]; - let next_addr_segment = next_values[ADDR_SEGMENT]; - let next_addr_virtual = next_values[ADDR_VIRTUAL]; - let next_values_limbs: Vec<_> = (0..8).map(|i| next_values[value_limb(i)]).collect(); - let next_is_read = next_values[IS_READ]; - let next_timestamp = next_values[TIMESTAMP]; + + let lv: &[ExtensionTarget; NUM_COLUMNS] = vars.get_local_values().try_into().unwrap(); + let lv: &MemoryColumnsView> = lv.borrow(); + let nv: &[ExtensionTarget; NUM_COLUMNS] = vars.get_next_values().try_into().unwrap(); + let nv: &MemoryColumnsView> = nv.borrow(); + + let addr_context = lv.addr_context; + let addr_segment = lv.addr_segment; + let addr_virtual = lv.addr_virtual; + let value_limbs: Vec<_> = (0..8).map(|i| lv.value_limbs[i]).collect(); + let timestamp = lv.timestamp; + + let next_addr_context = nv.addr_context; + let next_addr_segment = nv.addr_segment; + let next_addr_virtual = nv.addr_virtual; + let next_values_limbs: Vec<_> = (0..8).map(|i| nv.value_limbs[i]).collect(); + let next_is_read = nv.is_read; + let next_timestamp = nv.timestamp; // The filter must be 0 or 1. - let filter = local_values[FILTER]; + let filter = lv.filter; let constraint = builder.mul_sub_extension(filter, filter, filter); yield_constr.constraint(builder, constraint); @@ -436,20 +450,20 @@ impl, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark Vec> { vec![Lookup { columns: vec![ - Column::single(RANGE_CHECK), - Column::single_next_row(ADDR_VIRTUAL), + Column::single(MEMORY_COL_MAP.range_check), + Column::single_next_row(MEMORY_COL_MAP.addr_virtual), ], - table_column: Column::single(COUNTER), - frequencies_column: Column::single(FREQUENCIES), + table_column: Column::single(MEMORY_COL_MAP.counter), + frequencies_column: Column::single(MEMORY_COL_MAP.frequencies), filter_columns: vec![ Default::default(), - Filter::new_simple(Column::sum([CONTEXT_FIRST_CHANGE, SEGMENT_FIRST_CHANGE])), + Filter::new_simple(Column::sum([ + MEMORY_COL_MAP.context_first_change, + MEMORY_COL_MAP.segment_first_change, + ])), ], }] }