Skip to content

Commit

Permalink
Do not propagate zero values (#484)
Browse files Browse the repository at this point in the history
* Do not propagate zero values

* Address comments
  • Loading branch information
hratoanina authored Aug 14, 2024
1 parent f47590f commit 74d499d
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 45 deletions.
23 changes: 14 additions & 9 deletions evm_arithmetization/src/generation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::constants::global_metadata::GlobalMetadata;
use crate::generation::state::{GenerationState, State};
use crate::generation::trie_extractor::{get_receipt_trie, get_state_trie, get_txn_trie};
use crate::memory::segments::Segment;
use crate::memory::columns::PREINITIALIZED_SEGMENTS;
use crate::memory::segments::{Segment, PREINITIALIZED_SEGMENTS_INDICES};
use crate::proof::{
BlockHashes, BlockMetadata, ExtraBlockData, MemCap, PublicValues, RegistersData, TrieRoots,
};
Expand Down Expand Up @@ -380,20 +381,24 @@ fn initialize_kernel_code_and_shift_table(memory: &mut MemoryState) {

/// Returns the memory addresses and values that should comprise the state at
/// the start of the segment's execution.
/// Ignores zero values in non-preinitialized segments.
fn get_all_memory_address_and_values(memory_before: &MemoryState) -> Vec<(MemoryAddress, U256)> {
let mut res = vec![];
for (ctx_idx, ctx) in memory_before.contexts.iter().enumerate() {
for (segment_idx, segment) in ctx.segments.iter().enumerate() {
for (virt, value) in segment.content.iter().enumerate() {
if let &Some(val) = value {
res.push((
MemoryAddress {
context: ctx_idx,
segment: segment_idx,
virt,
},
val,
));
// We skip zero values in non-preinitialized segments.
if !val.is_zero() || PREINITIALIZED_SEGMENTS_INDICES.contains(&segment_idx) {
res.push((
MemoryAddress {
context: ctx_idx,
segment: segment_idx,
virt,
},
val,
));
}
}
}
}
Expand Down
9 changes: 7 additions & 2 deletions evm_arithmetization/src/memory/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,13 @@ pub(crate) const STALE_CONTEXTS_FREQUENCIES: usize = IS_PRUNED + 1;
// `ADDR_CONTEXT` + 1 is in `STALE_CONTEXTS`.
pub(crate) const IS_STALE: usize = STALE_CONTEXTS_FREQUENCIES + 1;

// Filter for the `MemAfter` CTL.
pub(crate) const MEM_AFTER_FILTER: usize = IS_STALE + 1;
// Flag indicating that a value can potentially be propagated.
// Contains `filter * address_changed * is_not_stale`.
pub(crate) const MAYBE_IN_MEM_AFTER: usize = IS_STALE + 1;

// Filter for the `MemAfter` CTL. Is equal to `MAYBE_IN_MEM_AFTER` if segment is
// preinitialized or the value is non-zero, is 0 otherwise.
pub(crate) const MEM_AFTER_FILTER: usize = MAYBE_IN_MEM_AFTER + 1;

// We use a range check to enforce the ordering.
pub(crate) const RANGE_CHECK: usize = MEM_AFTER_FILTER + 1;
Expand Down
103 changes: 75 additions & 28 deletions evm_arithmetization/src/memory/memory_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@ use starky::lookup::{Column, Filter, Lookup};
use starky::stark::Stark;

use super::columns::{
MEM_AFTER_FILTER, PREINITIALIZED_SEGMENTS, PREINITIALIZED_SEGMENTS_AUX, STALE_CONTEXTS,
MAYBE_IN_MEM_AFTER, MEM_AFTER_FILTER, PREINITIALIZED_SEGMENTS, STALE_CONTEXTS,
STALE_CONTEXTS_FREQUENCIES,
};
use super::segments::Segment;
use super::segments::{Segment, PREINITIALIZED_SEGMENTS_INDICES};
use crate::all_stark::{EvmStarkFrame, Table};
use crate::memory::columns::{
value_limb, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, CONTEXT_FIRST_CHANGE, COUNTER, FILTER,
FREQUENCIES, INITIALIZE_AUX, IS_PRUNED, IS_READ, IS_STALE, NUM_COLUMNS, RANGE_CHECK,
SEGMENT_FIRST_CHANGE, TIMESTAMP, TIMESTAMP_INV, VIRTUAL_FIRST_CHANGE,
FREQUENCIES, INITIALIZE_AUX, IS_PRUNED, IS_READ, IS_STALE, NUM_COLUMNS,
PREINITIALIZED_SEGMENTS_AUX, RANGE_CHECK, SEGMENT_FIRST_CHANGE, TIMESTAMP, TIMESTAMP_INV,
VIRTUAL_FIRST_CHANGE,
};
use crate::memory::VALUE_LIMBS;
use crate::witness::memory::MemoryOpKind::{self, Read};
Expand All @@ -42,7 +43,7 @@ use crate::witness::memory::{MemoryAddress, MemoryOp};
pub(crate) fn ctl_data<F: Field>() -> Vec<Column<F>> {
let mut res =
Column::singles([IS_READ, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL]).collect_vec();
res.extend(Column::singles((0..8).map(value_limb)));
res.extend(Column::singles((0..VALUE_LIMBS).map(value_limb)));
res.push(Column::single(TIMESTAMP));
res
}
Expand All @@ -57,7 +58,7 @@ pub(crate) fn ctl_filter<F: Field>() -> Filter<F> {
/// - the value in u32 limbs.
pub(crate) fn ctl_looking_mem<F: Field>() -> Vec<Column<F>> {
let mut res = Column::singles([ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL]).collect_vec();
res.extend(Column::singles((0..8).map(value_limb)));
res.extend(Column::singles((0..VALUE_LIMBS).map(value_limb)));
res
}

Expand Down Expand Up @@ -108,11 +109,7 @@ impl MemoryOp {
let mut row = [F::ZERO; NUM_COLUMNS];
row[FILTER] = F::from_bool(self.filter);
row[TIMESTAMP] = F::from_canonical_usize(self.timestamp);
if self.timestamp != 0 {
row[TIMESTAMP_INV] = row[TIMESTAMP].inverse();
} else {
row[TIMESTAMP_INV] = F::ZERO;
}
row[TIMESTAMP_INV] = row[TIMESTAMP].try_inverse().unwrap_or_default();
row[IS_READ] = F::from_bool(self.kind == Read);
let MemoryAddress {
context,
Expand Down Expand Up @@ -231,8 +228,8 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {

/// Generates the `COUNTER`, `RANGE_CHECK` and `FREQUENCIES` columns, given
/// a trace in column-major form.
/// Also generates the `STALE_CONTEXTS`, `STALE_CONTEXTS_FREQUENCIES` and
/// `MEM_AFTER_FILTER` columns.
/// Also generates the `STALE_CONTEXTS`, `STALE_CONTEXTS_FREQUENCIES`,
/// `MAYBE_IN_MEM_AFTER` and `MEM_AFTER_FILTER` columns.
fn generate_trace_col_major(trace_col_vecs: &mut [Vec<F>]) {
let height = trace_col_vecs[0].len();
trace_col_vecs[COUNTER] = (0..height).map(|i| F::from_canonical_usize(i)).collect();
Expand Down Expand Up @@ -261,8 +258,19 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
|| trace_col_vecs[SEGMENT_FIRST_CHANGE][i].is_one()
|| trace_col_vecs[VIRTUAL_FIRST_CHANGE][i].is_one())
{
// `mem_after_filter = filter * address_changed * (1 - is_stale)`
trace_col_vecs[MEM_AFTER_FILTER][i] = F::ONE;
// `maybe_in_mem_after = filter * address_changed * (1 - is_stale)`
trace_col_vecs[MAYBE_IN_MEM_AFTER][i] = F::ONE;

let addr_segment = trace_col_vecs[ADDR_SEGMENT][i];
let is_non_zero_value =
(0..VALUE_LIMBS).any(|limb| trace_col_vecs[value_limb(limb)][i].is_nonzero());
// We filter out zero values in non-preinitialized segments.
if is_non_zero_value
|| PREINITIALIZED_SEGMENTS_INDICES
.contains(&(addr_segment.to_canonical_u64() as usize))
{
trace_col_vecs[MEM_AFTER_FILTER][i] = F::ONE;
}
}
}
}
Expand Down Expand Up @@ -472,14 +480,18 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
let addr_context = local_values[ADDR_CONTEXT];
let addr_segment = local_values[ADDR_SEGMENT];
let addr_virtual = local_values[ADDR_VIRTUAL];
let value_limbs: Vec<_> = (0..8).map(|i| local_values[value_limb(i)]).collect();
let value_limbs: Vec<_> = (0..VALUE_LIMBS)
.map(|i| local_values[value_limb(i)])
.collect();

let next_timestamp = next_values[TIMESTAMP];
let next_is_read = next_values[IS_READ];
let next_addr_context = next_values[ADDR_CONTEXT];
let next_addr_segment = next_values[ADDR_SEGMENT];
let next_addr_virtual = next_values[ADDR_VIRTUAL];
let next_values_limbs: Vec<_> = (0..8).map(|i| next_values[value_limb(i)]).collect();
let next_values_limbs: Vec<_> = (0..VALUE_LIMBS)
.map(|i| next_values[value_limb(i)])
.collect();

// The filter must be 0 or 1.
let filter = local_values[FILTER];
Expand Down Expand Up @@ -561,7 +573,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
initialize_aux - preinitialized_segments * not_address_unchanged * next_is_read,
);

for i in 0..8 {
for i in 0..VALUE_LIMBS {
// Enumerate purportedly-ordered log.
yield_constr.constraint_transition(
next_is_read * address_unchanged * (next_values_limbs[i] - value_limbs[i]),
Expand All @@ -573,14 +585,27 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
yield_constr.constraint_transition(initialize_aux * next_values_limbs[i]);
}

// Validate `mem_after_filter`.
let mem_after_filter = local_values[MEM_AFTER_FILTER];
// Validate `maybe_in_mem_after`.
let maybe_in_mem_after = local_values[MAYBE_IN_MEM_AFTER];
let is_stale = local_values[IS_STALE];
yield_constr.constraint_transition(
mem_after_filter + filter * not_address_unchanged * (is_stale - P::ONES),
maybe_in_mem_after + filter * not_address_unchanged * (is_stale - P::ONES),
);

// Validate `timestamp_inv`. Since it's used as a CTL filter, its value must be
let mem_after_filter = local_values[MEM_AFTER_FILTER];
// `mem_after_filter` must be binary.
yield_constr.constraint(mem_after_filter * (mem_after_filter - P::ONES));

// `mem_after_filter` is equal to `maybe_in_mem_after` if:
// - segment is not preinitialized OR
// - value is not zero.
for i in 0..VALUE_LIMBS {
yield_constr.constraint(
(mem_after_filter - maybe_in_mem_after) * preinitialized_segments * value_limbs[i],
);
}

// Validate timestamp_inv. Since it's used as a CTL filter, its value must be
// checked.
let timestamp_inv = local_values[TIMESTAMP_INV];
yield_constr.constraint(timestamp * (timestamp * timestamp_inv - P::ONES));
Expand All @@ -607,13 +632,17 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
let addr_context = local_values[ADDR_CONTEXT];
let addr_segment = local_values[ADDR_SEGMENT];
let addr_virtual = local_values[ADDR_VIRTUAL];
let value_limbs: Vec<_> = (0..8).map(|i| local_values[value_limb(i)]).collect();
let value_limbs: Vec<_> = (0..VALUE_LIMBS)
.map(|i| local_values[value_limb(i)])
.collect();
let timestamp = local_values[TIMESTAMP];

let next_addr_context = next_values[ADDR_CONTEXT];
let next_addr_segment = next_values[ADDR_SEGMENT];
let next_addr_virtual = next_values[ADDR_VIRTUAL];
let next_values_limbs: Vec<_> = (0..8).map(|i| next_values[value_limb(i)]).collect();
let next_values_limbs: Vec<_> = (0..VALUE_LIMBS)
.map(|i| next_values[value_limb(i)])
.collect();
let next_is_read = next_values[IS_READ];
let next_timestamp = next_values[TIMESTAMP];

Expand Down Expand Up @@ -758,7 +787,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
builder.sub_extension(initialize_aux, computed_initialize_aux);
yield_constr.constraint_transition(builder, new_first_read_constraint);

for i in 0..8 {
for i in 0..VALUE_LIMBS {
// Enumerate purportedly-ordered log.
let value_diff = builder.sub_extension(next_values_limbs[i], value_limbs[i]);
let zero_if_read = builder.mul_extension(address_unchanged, value_diff);
Expand All @@ -772,16 +801,34 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
yield_constr.constraint_transition(builder, zero_init_constraint);
}

// Validate `mem_after_filter`.
let mem_after_filter = local_values[MEM_AFTER_FILTER];
// Validate `maybe_in_mem_after`.
let maybe_in_mem_after = local_values[MAYBE_IN_MEM_AFTER];
let is_stale = local_values[IS_STALE];
{
let rhs = builder.mul_extension(filter, not_address_unchanged);
let rhs = builder.mul_sub_extension(rhs, is_stale, rhs);
let constr = builder.add_extension(mem_after_filter, rhs);
let constr = builder.add_extension(maybe_in_mem_after, rhs);
yield_constr.constraint_transition(builder, constr);
}

let mem_after_filter = local_values[MEM_AFTER_FILTER];
// `mem_after_filter` must be binary.
{
let constr =
builder.mul_sub_extension(mem_after_filter, mem_after_filter, mem_after_filter);
yield_constr.constraint(builder, constr);
}

// `mem_after_filter` is equal to `maybe_in_mem_after` if:
// - segment is not preinitialized OR
// - value is not zero.
let mem_after_filter_diff = builder.sub_extension(mem_after_filter, maybe_in_mem_after);
for i in 0..VALUE_LIMBS {
let prod = builder.mul_extension(preinitialized_segments, value_limbs[i]);
let constr = builder.mul_extension(mem_after_filter_diff, prod);
yield_constr.constraint(builder, constr);
}

// Validate timestamp_inv. Since it's used as a CTL filter, its value must be
// checked.
let timestamp_inv = local_values[TIMESTAMP_INV];
Expand Down
8 changes: 8 additions & 0 deletions evm_arithmetization/src/memory/segments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,14 @@ pub(crate) enum Segment {
StorageLinkedList = 35 << SEGMENT_SCALING_FACTOR,
}

// These segments are not zero-initialized.
pub(crate) const PREINITIALIZED_SEGMENTS_INDICES: [usize; 4] = [
Segment::Code.unscale(),
Segment::TrieData.unscale(),
Segment::AccountsLinkedList.unscale(),
Segment::StorageLinkedList.unscale(),
];

impl Segment {
pub(crate) const COUNT: usize = 36;

Expand Down
9 changes: 3 additions & 6 deletions evm_arithmetization/src/witness/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use MemoryChannel::{Code, GeneralPurpose, PartialChannel};

use super::operation::CONTEXT_SCALING_FACTOR;
use crate::cpu::kernel::constants::global_metadata::GlobalMetadata;
use crate::memory::segments::{Segment, SEGMENT_SCALING_FACTOR};
use crate::memory::segments::{Segment, PREINITIALIZED_SEGMENTS_INDICES, SEGMENT_SCALING_FACTOR};
use crate::witness::errors::MemoryError::{ContextTooLarge, SegmentTooLarge, VirtTooLarge};
use crate::witness::errors::ProgramError;
use crate::witness::errors::ProgramError::MemoryError;
Expand Down Expand Up @@ -225,11 +225,7 @@ impl MemoryState {
/// need a specific behaviour here, since the values can be stored either in
/// `preinitialized_segments` or in the memory itself.
pub(crate) fn get_preinit_memory(&self, segment: Segment) -> Vec<Option<U256>> {
assert!(
segment == Segment::AccountsLinkedList
|| segment == Segment::StorageLinkedList
|| segment == Segment::TrieData
);
assert!(PREINITIALIZED_SEGMENTS_INDICES.contains(&segment.unscale()));
let len = self
.preinitialized_segments
.get(&segment)
Expand Down Expand Up @@ -301,6 +297,7 @@ impl MemoryState {
segment: Segment,
values: MemorySegmentState,
) {
assert!(PREINITIALIZED_SEGMENTS_INDICES.contains(&segment.unscale()));
self.preinitialized_segments.insert(segment, values);
}

Expand Down

0 comments on commit 74d499d

Please sign in to comment.