Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not propagate zero values #484

Merged
merged 2 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions evm_arithmetization/src/generation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::constants::global_metadata::GlobalMetadata;
use crate::generation::state::{GenerationState, State};
use crate::generation::trie_extractor::{get_receipt_trie, get_state_trie, get_txn_trie};
use crate::memory::segments::Segment;
use crate::memory::columns::PREINITIALIZED_SEGMENTS;
use crate::memory::segments::{Segment, PREINITIALIZED_SEGMENTS_INDICES};
use crate::proof::{
BlockHashes, BlockMetadata, ExtraBlockData, MemCap, PublicValues, RegistersData, TrieRoots,
};
Expand Down Expand Up @@ -380,20 +381,24 @@ fn initialize_kernel_code_and_shift_table(memory: &mut MemoryState) {

/// Returns the memory addresses and values that should comprise the state at
/// the start of the segment's execution.
/// Ignores zero values in non-preinitialized segments.
fn get_all_memory_address_and_values(memory_before: &MemoryState) -> Vec<(MemoryAddress, U256)> {
let mut res = vec![];
for (ctx_idx, ctx) in memory_before.contexts.iter().enumerate() {
for (segment_idx, segment) in ctx.segments.iter().enumerate() {
for (virt, value) in segment.content.iter().enumerate() {
if let &Some(val) = value {
res.push((
MemoryAddress {
context: ctx_idx,
segment: segment_idx,
virt,
},
val,
));
// We skip zero values in non-preinitialized segments.
if !val.is_zero() || PREINITIALIZED_SEGMENTS_INDICES.contains(&segment_idx) {
res.push((
MemoryAddress {
context: ctx_idx,
segment: segment_idx,
virt,
},
val,
));
}
}
}
}
Expand Down
9 changes: 7 additions & 2 deletions evm_arithmetization/src/memory/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,13 @@ pub(crate) const STALE_CONTEXTS_FREQUENCIES: usize = IS_PRUNED + 1;
// `ADDR_CONTEXT` + 1 is in `STALE_CONTEXTS`.
pub(crate) const IS_STALE: usize = STALE_CONTEXTS_FREQUENCIES + 1;

// Filter for the `MemAfter` CTL.
pub(crate) const MEM_AFTER_FILTER: usize = IS_STALE + 1;
// Flag indicating that a value can potentially be propagated.
// Contains `filter * address_changed * is_not_stale`.
pub(crate) const MAYBE_IN_MEM_AFTER: usize = IS_STALE + 1;

// Filter for the `MemAfter` CTL. Is equal to `MAYBE_IN_MEM_AFTER` if segment is
// preinitialized or the value is non-zero, is 0 otherwise.
pub(crate) const MEM_AFTER_FILTER: usize = MAYBE_IN_MEM_AFTER + 1;

// We use a range check to enforce the ordering.
pub(crate) const RANGE_CHECK: usize = MEM_AFTER_FILTER + 1;
Expand Down
103 changes: 75 additions & 28 deletions evm_arithmetization/src/memory/memory_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@ use starky::lookup::{Column, Filter, Lookup};
use starky::stark::Stark;

use super::columns::{
MEM_AFTER_FILTER, PREINITIALIZED_SEGMENTS, PREINITIALIZED_SEGMENTS_AUX, STALE_CONTEXTS,
MAYBE_IN_MEM_AFTER, MEM_AFTER_FILTER, PREINITIALIZED_SEGMENTS, STALE_CONTEXTS,
STALE_CONTEXTS_FREQUENCIES,
};
use super::segments::Segment;
use super::segments::{Segment, PREINITIALIZED_SEGMENTS_INDICES};
use crate::all_stark::{EvmStarkFrame, Table};
use crate::memory::columns::{
value_limb, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, CONTEXT_FIRST_CHANGE, COUNTER, FILTER,
FREQUENCIES, INITIALIZE_AUX, IS_PRUNED, IS_READ, IS_STALE, NUM_COLUMNS, RANGE_CHECK,
SEGMENT_FIRST_CHANGE, TIMESTAMP, TIMESTAMP_INV, VIRTUAL_FIRST_CHANGE,
FREQUENCIES, INITIALIZE_AUX, IS_PRUNED, IS_READ, IS_STALE, NUM_COLUMNS,
PREINITIALIZED_SEGMENTS_AUX, RANGE_CHECK, SEGMENT_FIRST_CHANGE, TIMESTAMP, TIMESTAMP_INV,
VIRTUAL_FIRST_CHANGE,
};
use crate::memory::VALUE_LIMBS;
use crate::witness::memory::MemoryOpKind::{self, Read};
Expand All @@ -42,7 +43,7 @@ use crate::witness::memory::{MemoryAddress, MemoryOp};
pub(crate) fn ctl_data<F: Field>() -> Vec<Column<F>> {
let mut res =
Column::singles([IS_READ, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL]).collect_vec();
res.extend(Column::singles((0..8).map(value_limb)));
res.extend(Column::singles((0..VALUE_LIMBS).map(value_limb)));
res.push(Column::single(TIMESTAMP));
res
}
Expand All @@ -57,7 +58,7 @@ pub(crate) fn ctl_filter<F: Field>() -> Filter<F> {
/// - the value in u32 limbs.
pub(crate) fn ctl_looking_mem<F: Field>() -> Vec<Column<F>> {
let mut res = Column::singles([ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL]).collect_vec();
res.extend(Column::singles((0..8).map(value_limb)));
res.extend(Column::singles((0..VALUE_LIMBS).map(value_limb)));
res
}

Expand Down Expand Up @@ -108,11 +109,7 @@ impl MemoryOp {
let mut row = [F::ZERO; NUM_COLUMNS];
row[FILTER] = F::from_bool(self.filter);
row[TIMESTAMP] = F::from_canonical_usize(self.timestamp);
if self.timestamp != 0 {
row[TIMESTAMP_INV] = row[TIMESTAMP].inverse();
} else {
row[TIMESTAMP_INV] = F::ZERO;
}
row[TIMESTAMP_INV] = row[TIMESTAMP].try_inverse().unwrap_or_default();
row[IS_READ] = F::from_bool(self.kind == Read);
let MemoryAddress {
context,
Expand Down Expand Up @@ -231,8 +228,8 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {

/// Generates the `COUNTER`, `RANGE_CHECK` and `FREQUENCIES` columns, given
/// a trace in column-major form.
/// Also generates the `STALE_CONTEXTS`, `STALE_CONTEXTS_FREQUENCIES` and
/// `MEM_AFTER_FILTER` columns.
/// Also generates the `STALE_CONTEXTS`, `STALE_CONTEXTS_FREQUENCIES`,
/// `MAYBE_IN_MEM_AFTER` and `MEM_AFTER_FILTER` columns.
fn generate_trace_col_major(trace_col_vecs: &mut [Vec<F>]) {
let height = trace_col_vecs[0].len();
trace_col_vecs[COUNTER] = (0..height).map(|i| F::from_canonical_usize(i)).collect();
Expand Down Expand Up @@ -261,8 +258,19 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
|| trace_col_vecs[SEGMENT_FIRST_CHANGE][i].is_one()
|| trace_col_vecs[VIRTUAL_FIRST_CHANGE][i].is_one())
{
// `mem_after_filter = filter * address_changed * (1 - is_stale)`
trace_col_vecs[MEM_AFTER_FILTER][i] = F::ONE;
// `maybe_in_mem_after = filter * address_changed * (1 - is_stale)`
trace_col_vecs[MAYBE_IN_MEM_AFTER][i] = F::ONE;

let addr_segment = trace_col_vecs[ADDR_SEGMENT][i];
let is_non_zero_value =
(0..VALUE_LIMBS).any(|limb| trace_col_vecs[value_limb(limb)][i].is_nonzero());
// We filter out zero values in non-preinitialized segments.
if is_non_zero_value
|| PREINITIALIZED_SEGMENTS_INDICES
.contains(&(addr_segment.to_canonical_u64() as usize))
{
trace_col_vecs[MEM_AFTER_FILTER][i] = F::ONE;
}
}
}
}
Expand Down Expand Up @@ -472,14 +480,18 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
let addr_context = local_values[ADDR_CONTEXT];
let addr_segment = local_values[ADDR_SEGMENT];
let addr_virtual = local_values[ADDR_VIRTUAL];
let value_limbs: Vec<_> = (0..8).map(|i| local_values[value_limb(i)]).collect();
let value_limbs: Vec<_> = (0..VALUE_LIMBS)
.map(|i| local_values[value_limb(i)])
.collect();

let next_timestamp = next_values[TIMESTAMP];
let next_is_read = next_values[IS_READ];
let next_addr_context = next_values[ADDR_CONTEXT];
let next_addr_segment = next_values[ADDR_SEGMENT];
let next_addr_virtual = next_values[ADDR_VIRTUAL];
let next_values_limbs: Vec<_> = (0..8).map(|i| next_values[value_limb(i)]).collect();
let next_values_limbs: Vec<_> = (0..VALUE_LIMBS)
.map(|i| next_values[value_limb(i)])
.collect();

// The filter must be 0 or 1.
let filter = local_values[FILTER];
Expand Down Expand Up @@ -561,7 +573,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
initialize_aux - preinitialized_segments * not_address_unchanged * next_is_read,
);

for i in 0..8 {
for i in 0..VALUE_LIMBS {
// Enumerate purportedly-ordered log.
yield_constr.constraint_transition(
next_is_read * address_unchanged * (next_values_limbs[i] - value_limbs[i]),
Expand All @@ -573,14 +585,27 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
yield_constr.constraint_transition(initialize_aux * next_values_limbs[i]);
}

// Validate `mem_after_filter`.
let mem_after_filter = local_values[MEM_AFTER_FILTER];
// Validate `maybe_in_mem_after`.
let maybe_in_mem_after = local_values[MAYBE_IN_MEM_AFTER];
let is_stale = local_values[IS_STALE];
yield_constr.constraint_transition(
mem_after_filter + filter * not_address_unchanged * (is_stale - P::ONES),
maybe_in_mem_after + filter * not_address_unchanged * (is_stale - P::ONES),
);

// Validate `timestamp_inv`. Since it's used as a CTL filter, its value must be
let mem_after_filter = local_values[MEM_AFTER_FILTER];
// `mem_after_filter` must be binary.
yield_constr.constraint(mem_after_filter * (mem_after_filter - P::ONES));

// `mem_after_filter` is equal to `maybe_in_mem_after` if:
// - segment is not preinitialized OR
// - value is not zero.
for i in 0..VALUE_LIMBS {
yield_constr.constraint(
(mem_after_filter - maybe_in_mem_after) * preinitialized_segments * value_limbs[i],
);
}

// Validate timestamp_inv. Since it's used as a CTL filter, its value must be
// checked.
let timestamp_inv = local_values[TIMESTAMP_INV];
yield_constr.constraint(timestamp * (timestamp * timestamp_inv - P::ONES));
Expand All @@ -607,13 +632,17 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
let addr_context = local_values[ADDR_CONTEXT];
let addr_segment = local_values[ADDR_SEGMENT];
let addr_virtual = local_values[ADDR_VIRTUAL];
let value_limbs: Vec<_> = (0..8).map(|i| local_values[value_limb(i)]).collect();
let value_limbs: Vec<_> = (0..VALUE_LIMBS)
.map(|i| local_values[value_limb(i)])
.collect();
let timestamp = local_values[TIMESTAMP];

let next_addr_context = next_values[ADDR_CONTEXT];
let next_addr_segment = next_values[ADDR_SEGMENT];
let next_addr_virtual = next_values[ADDR_VIRTUAL];
let next_values_limbs: Vec<_> = (0..8).map(|i| next_values[value_limb(i)]).collect();
let next_values_limbs: Vec<_> = (0..VALUE_LIMBS)
.map(|i| next_values[value_limb(i)])
.collect();
let next_is_read = next_values[IS_READ];
let next_timestamp = next_values[TIMESTAMP];

Expand Down Expand Up @@ -758,7 +787,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
builder.sub_extension(initialize_aux, computed_initialize_aux);
yield_constr.constraint_transition(builder, new_first_read_constraint);

for i in 0..8 {
for i in 0..VALUE_LIMBS {
// Enumerate purportedly-ordered log.
let value_diff = builder.sub_extension(next_values_limbs[i], value_limbs[i]);
let zero_if_read = builder.mul_extension(address_unchanged, value_diff);
Expand All @@ -772,16 +801,34 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
yield_constr.constraint_transition(builder, zero_init_constraint);
}

// Validate `mem_after_filter`.
let mem_after_filter = local_values[MEM_AFTER_FILTER];
// Validate `maybe_in_mem_after`.
let maybe_in_mem_after = local_values[MAYBE_IN_MEM_AFTER];
let is_stale = local_values[IS_STALE];
{
let rhs = builder.mul_extension(filter, not_address_unchanged);
let rhs = builder.mul_sub_extension(rhs, is_stale, rhs);
let constr = builder.add_extension(mem_after_filter, rhs);
let constr = builder.add_extension(maybe_in_mem_after, rhs);
yield_constr.constraint_transition(builder, constr);
}

let mem_after_filter = local_values[MEM_AFTER_FILTER];
// `mem_after_filter` must be binary.
{
let constr =
builder.mul_sub_extension(mem_after_filter, mem_after_filter, mem_after_filter);
yield_constr.constraint(builder, constr);
}

// `mem_after_filter` is equal to `maybe_in_mem_after` if:
// - segment is not preinitialized OR
// - value is not zero.
let mem_after_filter_diff = builder.sub_extension(mem_after_filter, maybe_in_mem_after);
for i in 0..VALUE_LIMBS {
let prod = builder.mul_extension(preinitialized_segments, value_limbs[i]);
let constr = builder.mul_extension(mem_after_filter_diff, prod);
yield_constr.constraint(builder, constr);
}

// Validate timestamp_inv. Since it's used as a CTL filter, its value must be
// checked.
let timestamp_inv = local_values[TIMESTAMP_INV];
Expand Down
8 changes: 8 additions & 0 deletions evm_arithmetization/src/memory/segments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,14 @@ pub(crate) enum Segment {
StorageLinkedList = 35 << SEGMENT_SCALING_FACTOR,
}

// These segments are not zero-initialized.
pub(crate) const PREINITIALIZED_SEGMENTS_INDICES: [usize; 4] = [
Nashtare marked this conversation as resolved.
Show resolved Hide resolved
Segment::Code.unscale(),
Segment::TrieData.unscale(),
Segment::AccountsLinkedList.unscale(),
Segment::StorageLinkedList.unscale(),
];

impl Segment {
pub(crate) const COUNT: usize = 36;

Expand Down
9 changes: 3 additions & 6 deletions evm_arithmetization/src/witness/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use MemoryChannel::{Code, GeneralPurpose, PartialChannel};

use super::operation::CONTEXT_SCALING_FACTOR;
use crate::cpu::kernel::constants::global_metadata::GlobalMetadata;
use crate::memory::segments::{Segment, SEGMENT_SCALING_FACTOR};
use crate::memory::segments::{Segment, PREINITIALIZED_SEGMENTS_INDICES, SEGMENT_SCALING_FACTOR};
use crate::witness::errors::MemoryError::{ContextTooLarge, SegmentTooLarge, VirtTooLarge};
use crate::witness::errors::ProgramError;
use crate::witness::errors::ProgramError::MemoryError;
Expand Down Expand Up @@ -225,11 +225,7 @@ impl MemoryState {
/// need a specific behaviour here, since the values can be stored either in
/// `preinitialized_segments` or in the memory itself.
pub(crate) fn get_preinit_memory(&self, segment: Segment) -> Vec<Option<U256>> {
assert!(
segment == Segment::AccountsLinkedList
|| segment == Segment::StorageLinkedList
|| segment == Segment::TrieData
);
assert!(PREINITIALIZED_SEGMENTS_INDICES.contains(&segment.unscale()));
let len = self
.preinitialized_segments
.get(&segment)
Expand Down Expand Up @@ -301,6 +297,7 @@ impl MemoryState {
segment: Segment,
values: MemorySegmentState,
) {
assert!(PREINITIALIZED_SEGMENTS_INDICES.contains(&segment.unscale()));
self.preinitialized_segments.insert(segment, values);
}

Expand Down
Loading