Skip to content

Commit

Permalink
fill memory segments of builtins to the next power of 2 instances (#340)
Browse files Browse the repository at this point in the history
  • Loading branch information
ohad-nir-starkware authored Jan 20, 2025
1 parent 76b9d66 commit 82e6c20
Show file tree
Hide file tree
Showing 3 changed files with 301 additions and 111 deletions.
180 changes: 173 additions & 7 deletions stwo_cairo_prover/crates/prover/src/input/builtin_segments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ use cairo_vm::stdlib::collections::HashMap;
use cairo_vm::types::builtin_name::BuiltinName;
use serde::{Deserialize, Serialize};

use super::memory::MemoryBuilder;

// TODO(ohadn): change field types in MemorySegmentAddresses to match address type.
/// This struct holds the builtins used in a Cairo program.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct BuiltinSegments {
Expand All @@ -19,11 +22,9 @@ pub struct BuiltinSegments {
}

impl BuiltinSegments {
pub fn add_segment(
&mut self,
builtin_name: BuiltinName,
segment: Option<MemorySegmentAddresses>,
) {
/// Sets a segment in the builtin segments.
/// If a segment already exists for the given builtin name, it will be overwritten.
fn set_segment(&mut self, builtin_name: BuiltinName, segment: Option<MemorySegmentAddresses>) {
match builtin_name {
BuiltinName::range_check => self.range_check_bits_128 = segment,
BuiltinName::pedersen => self.pedersen = segment,
Expand All @@ -40,6 +41,85 @@ impl BuiltinSegments {
}
}

// TODO(ohadn): change return type to non reference once MemorySegmentAddresses implements
// clone.
// TODO(ohadn): change output type to match address type.
/// Returns the segment for a given builtin name.
fn get_segment(&self, builtin_name: BuiltinName) -> &Option<MemorySegmentAddresses> {
match builtin_name {
BuiltinName::range_check => &self.range_check_bits_128,
BuiltinName::pedersen => &self.pedersen,
BuiltinName::ecdsa => &self.ecdsa,
BuiltinName::keccak => &self.keccak,
BuiltinName::bitwise => &self.bitwise,
BuiltinName::ec_op => &self.ec_op,
BuiltinName::poseidon => &self.poseidon,
BuiltinName::range_check96 => &self.range_check_bits_96,
BuiltinName::add_mod => &self.add_mod,
BuiltinName::mul_mod => &self.mul_mod,
// Not builtins.
BuiltinName::output | BuiltinName::segment_arena => &None,
}
}

/// Returns the number of memory cells per instance for a given builtin name.
pub fn builtin_memory_cells_per_instance(builtin_name: BuiltinName) -> usize {
match builtin_name {
BuiltinName::range_check => 1,
BuiltinName::pedersen => 3,
BuiltinName::ecdsa => 2,
BuiltinName::keccak => 16,
BuiltinName::bitwise => 5,
BuiltinName::ec_op => 7,
BuiltinName::poseidon => 6,
BuiltinName::range_check96 => 1,
BuiltinName::add_mod => 7,
BuiltinName::mul_mod => 7,
// Not builtins.
BuiltinName::output | BuiltinName::segment_arena => 0,
}
}

/// Pads a builtin segment with copies of its last instance if that segment isn't None, in
/// which case at least one instance is guaranteed to exist.
/// The segment is padded to the next power of 2 number of instances.
/// Note: the last instance was already verified as valid by the VM and in the case of add_mod
/// and mul_mod, security checks have verified that instance has n=1. Thus the padded segment
/// satisfies all the AIR constraints.
// TODO (ohadn): relocate this function if a more appropriate place is found.
pub fn fill_builtin_segment(&mut self, memory: &mut MemoryBuilder, builtin_name: BuiltinName) {
let &Some(MemorySegmentAddresses {
begin_addr,
stop_ptr,
}) = self.get_segment(builtin_name)
else {
return;
};
let initial_length = stop_ptr - begin_addr;
assert!(initial_length > 0);
let cells_per_instance = Self::builtin_memory_cells_per_instance(builtin_name);
assert!(initial_length % cells_per_instance == 0);
let num_instances = initial_length / cells_per_instance;
let next_power_of_two = num_instances.next_power_of_two();
let mut instance_to_fill_start = stop_ptr as u32;
let last_instance_start = (stop_ptr - cells_per_instance) as u32;
for _ in num_instances..next_power_of_two {
memory.copy_block(
last_instance_start,
instance_to_fill_start,
cells_per_instance as u32,
);
instance_to_fill_start += cells_per_instance as u32;
}
self.set_segment(
builtin_name,
Some(MemorySegmentAddresses {
begin_addr,
stop_ptr: begin_addr + cells_per_instance * next_power_of_two,
}),
);
}

/// Creates a new `BuiltinSegments` struct from a map of memory segment names to addresses.
pub fn from_memory_segments(memory_segments: &HashMap<&str, MemorySegmentAddresses>) -> Self {
let mut res = BuiltinSegments::default();
Expand All @@ -57,7 +137,7 @@ impl BuiltinSegments {
);
Some((value.begin_addr, value.stop_ptr).into())
};
res.add_segment(builtin_name, segment);
res.set_segment(builtin_name, segment);
};
}
res
Expand All @@ -69,10 +149,25 @@ impl BuiltinSegments {
mod test_builtin_segments {
use std::path::PathBuf;

use cairo_vm::air_public_input::PublicInput;
use cairo_vm::air_public_input::{MemorySegmentAddresses, PublicInput};
use cairo_vm::types::builtin_name::BuiltinName;

use crate::input::memory::{u128_to_4_limbs, Memory, MemoryBuilder, MemoryConfig, MemoryValue};
use crate::input::BuiltinSegments;

/// Asserts that the values at addresses start_addr1 to start_addr1 + segment_length - 1
/// are equal to values at the addresses start_addr2 to start_addr2 + segment_length - 1.
pub fn assert_identical_blocks(
memory: &Memory,
start_addr1: u32,
start_addr2: u32,
segment_length: u32,
) {
for i in 0..segment_length {
assert_eq!(memory.get(start_addr1 + i), memory.get(start_addr2 + i));
}
}

#[test]
fn test_builtin_segments() {
let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
Expand All @@ -97,4 +192,75 @@ mod test_builtin_segments {
Some((7069, 7187).into())
);
}

/// Initializes a memory builder with the given u128 values.
/// Places the value instance_example[i] at the address memory_write_start + i.
fn initialize_memory(memory_write_start: u64, instance_example: &[u128]) -> MemoryBuilder {
let memory_config = MemoryConfig::default();
let mut memory_builder = MemoryBuilder::new(memory_config.clone());
for (i, &value) in instance_example.iter().enumerate() {
let memory_value = if value <= memory_config.small_max {
MemoryValue::Small(value)
} else {
let x = u128_to_4_limbs(value);
MemoryValue::F252([x[0], x[1], x[2], x[3], 0, 0, 0, 0])
};
memory_builder.set(memory_write_start + i as u64, memory_value);
}
memory_builder
}

#[test]
fn test_fill_builtin_segment() {
let builtin_name = BuiltinName::bitwise;
let instance_example = [
123456789,
4385067362534966725237889432551,
50448645,
4385067362534966725237911992050,
4385067362534966725237962440695,
];
let mut builtin_segments = BuiltinSegments::default();
let cells_per_instance = BuiltinSegments::builtin_memory_cells_per_instance(builtin_name);
assert_eq!(cells_per_instance, instance_example.len());
let num_instances = 71;
let begin_addr = 23581;
let stop_ptr = begin_addr + cells_per_instance * num_instances;
builtin_segments.set_segment(
builtin_name,
Some(MemorySegmentAddresses {
begin_addr,
stop_ptr,
}),
);
let memory_write_start = (stop_ptr - cells_per_instance) as u64;
let mut memory_builder = initialize_memory(memory_write_start, &instance_example);

builtin_segments.fill_builtin_segment(&mut memory_builder, builtin_name);

let &MemorySegmentAddresses {
begin_addr: new_begin_addr,
stop_ptr: new_stop_ptr,
} = builtin_segments.get_segment(builtin_name).as_ref().unwrap();
assert_eq!(new_begin_addr, begin_addr);
let segment_length = new_stop_ptr - new_begin_addr;
assert_eq!(segment_length % cells_per_instance, 0);
let new_num_instances = segment_length / cells_per_instance;
assert_eq!(new_num_instances, 128);

let memory = memory_builder.build();
assert_eq!(memory.address_to_id.len(), new_stop_ptr);

let mut instance_to_verify_start = stop_ptr as u32;
let last_instance_start = (stop_ptr - cells_per_instance) as u32;
for _ in num_instances..new_num_instances {
assert_identical_blocks(
&memory,
last_instance_start,
instance_to_verify_start,
cells_per_instance as u32,
);
instance_to_verify_start += cells_per_instance as u32;
}
}
}
17 changes: 16 additions & 1 deletion stwo_cairo_prover/crates/prover/src/input/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub const P_MIN_2: [u32; 8] = [
0x0800_0000,
];

#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct MemoryConfig {
pub small_max: u128,
}
Expand Down Expand Up @@ -99,6 +99,7 @@ impl Memory {
}
}

// TODO(ohadn): derive or impl a default for MemoryBuilder.
pub struct MemoryBuilder {
memory: Memory,
felt252_id_cache: HashMap<[u32; 8], usize>,
Expand Down Expand Up @@ -143,6 +144,7 @@ impl MemoryBuilder {
res
}

// TODO(ohadn): settle on an address integer type, and use it consistently.
pub fn set(&mut self, addr: u64, value: MemoryValue) {
if addr as usize >= self.address_to_id.len() {
self.address_to_id
Expand All @@ -168,6 +170,19 @@ impl MemoryBuilder {
});
self.address_to_id[addr as usize] = res;
}

/// Copies a block of memory from one location to another.
/// The values at addresses src_start_addr to src_start_addr + segment_length - 1 are copied to
/// the addresses dst_start_addr to dst_start_addr + segment_length - 1.
pub fn copy_block(&mut self, src_start_addr: u32, dst_start_addr: u32, segment_length: u32) {
for i in 0..segment_length {
self.set(
(dst_start_addr + i) as u64,
self.memory.get(src_start_addr + i),
);
}
}

pub fn build(self) -> Memory {
self.memory
}
Expand Down
Loading

0 comments on commit 82e6c20

Please sign in to comment.