From 82e6c20fbd31c9b2eb602309335009be631127f5 Mon Sep 17 00:00:00 2001 From: Ohad Nir <141617878+ohad-nir-starkware@users.noreply.github.com> Date: Mon, 20 Jan 2025 16:20:11 +0200 Subject: [PATCH] fill memory segments of builtins to the next power of 2 instances (#340) --- .../prover/src/input/builtin_segments.rs | 180 ++++++++++++++- .../crates/prover/src/input/memory.rs | 17 +- .../crates/prover/src/input/vm_import/mod.rs | 215 +++++++++--------- 3 files changed, 301 insertions(+), 111 deletions(-) diff --git a/stwo_cairo_prover/crates/prover/src/input/builtin_segments.rs b/stwo_cairo_prover/crates/prover/src/input/builtin_segments.rs index 488cd72e..ebce19bd 100644 --- a/stwo_cairo_prover/crates/prover/src/input/builtin_segments.rs +++ b/stwo_cairo_prover/crates/prover/src/input/builtin_segments.rs @@ -3,6 +3,9 @@ use cairo_vm::stdlib::collections::HashMap; use cairo_vm::types::builtin_name::BuiltinName; use serde::{Deserialize, Serialize}; +use super::memory::MemoryBuilder; + +// TODO(ohadn): change field types in MemorySegmentAddresses to match address type. /// This struct holds the builtins used in a Cairo program. #[derive(Debug, Default, Serialize, Deserialize)] pub struct BuiltinSegments { @@ -19,11 +22,9 @@ pub struct BuiltinSegments { } impl BuiltinSegments { - pub fn add_segment( - &mut self, - builtin_name: BuiltinName, - segment: Option, - ) { + /// Sets a segment in the builtin segments. + /// If a segment already exists for the given builtin name, it will be overwritten. + fn set_segment(&mut self, builtin_name: BuiltinName, segment: Option) { match builtin_name { BuiltinName::range_check => self.range_check_bits_128 = segment, BuiltinName::pedersen => self.pedersen = segment, @@ -40,6 +41,85 @@ impl BuiltinSegments { } } + // TODO(ohadn): change return type to non reference once MemorySegmentAddresses implements + // clone. + // TODO(ohadn): change output type to match address type. + /// Returns the segment for a given builtin name. + fn get_segment(&self, builtin_name: BuiltinName) -> &Option { + match builtin_name { + BuiltinName::range_check => &self.range_check_bits_128, + BuiltinName::pedersen => &self.pedersen, + BuiltinName::ecdsa => &self.ecdsa, + BuiltinName::keccak => &self.keccak, + BuiltinName::bitwise => &self.bitwise, + BuiltinName::ec_op => &self.ec_op, + BuiltinName::poseidon => &self.poseidon, + BuiltinName::range_check96 => &self.range_check_bits_96, + BuiltinName::add_mod => &self.add_mod, + BuiltinName::mul_mod => &self.mul_mod, + // Not builtins. + BuiltinName::output | BuiltinName::segment_arena => &None, + } + } + + /// Returns the number of memory cells per instance for a given builtin name. + pub fn builtin_memory_cells_per_instance(builtin_name: BuiltinName) -> usize { + match builtin_name { + BuiltinName::range_check => 1, + BuiltinName::pedersen => 3, + BuiltinName::ecdsa => 2, + BuiltinName::keccak => 16, + BuiltinName::bitwise => 5, + BuiltinName::ec_op => 7, + BuiltinName::poseidon => 6, + BuiltinName::range_check96 => 1, + BuiltinName::add_mod => 7, + BuiltinName::mul_mod => 7, + // Not builtins. + BuiltinName::output | BuiltinName::segment_arena => 0, + } + } + + /// Pads a builtin segment with copies of its last instance if that segment isn't None, in + /// which case at least one instance is guaranteed to exist. + /// The segment is padded to the next power of 2 number of instances. + /// Note: the last instance was already verified as valid by the VM and in the case of add_mod + /// and mul_mod, security checks have verified that instance has n=1. Thus the padded segment + /// satisfies all the AIR constraints. + // TODO (ohadn): relocate this function if a more appropriate place is found. + pub fn fill_builtin_segment(&mut self, memory: &mut MemoryBuilder, builtin_name: BuiltinName) { + let &Some(MemorySegmentAddresses { + begin_addr, + stop_ptr, + }) = self.get_segment(builtin_name) + else { + return; + }; + let initial_length = stop_ptr - begin_addr; + assert!(initial_length > 0); + let cells_per_instance = Self::builtin_memory_cells_per_instance(builtin_name); + assert!(initial_length % cells_per_instance == 0); + let num_instances = initial_length / cells_per_instance; + let next_power_of_two = num_instances.next_power_of_two(); + let mut instance_to_fill_start = stop_ptr as u32; + let last_instance_start = (stop_ptr - cells_per_instance) as u32; + for _ in num_instances..next_power_of_two { + memory.copy_block( + last_instance_start, + instance_to_fill_start, + cells_per_instance as u32, + ); + instance_to_fill_start += cells_per_instance as u32; + } + self.set_segment( + builtin_name, + Some(MemorySegmentAddresses { + begin_addr, + stop_ptr: begin_addr + cells_per_instance * next_power_of_two, + }), + ); + } + /// Creates a new `BuiltinSegments` struct from a map of memory segment names to addresses. pub fn from_memory_segments(memory_segments: &HashMap<&str, MemorySegmentAddresses>) -> Self { let mut res = BuiltinSegments::default(); @@ -57,7 +137,7 @@ impl BuiltinSegments { ); Some((value.begin_addr, value.stop_ptr).into()) }; - res.add_segment(builtin_name, segment); + res.set_segment(builtin_name, segment); }; } res @@ -69,10 +149,25 @@ impl BuiltinSegments { mod test_builtin_segments { use std::path::PathBuf; - use cairo_vm::air_public_input::PublicInput; + use cairo_vm::air_public_input::{MemorySegmentAddresses, PublicInput}; + use cairo_vm::types::builtin_name::BuiltinName; + use crate::input::memory::{u128_to_4_limbs, Memory, MemoryBuilder, MemoryConfig, MemoryValue}; use crate::input::BuiltinSegments; + /// Asserts that the values at addresses start_addr1 to start_addr1 + segment_length - 1 + /// are equal to values at the addresses start_addr2 to start_addr2 + segment_length - 1. + pub fn assert_identical_blocks( + memory: &Memory, + start_addr1: u32, + start_addr2: u32, + segment_length: u32, + ) { + for i in 0..segment_length { + assert_eq!(memory.get(start_addr1 + i), memory.get(start_addr2 + i)); + } + } + #[test] fn test_builtin_segments() { let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) @@ -97,4 +192,75 @@ mod test_builtin_segments { Some((7069, 7187).into()) ); } + + /// Initializes a memory builder with the given u128 values. + /// Places the value instance_example[i] at the address memory_write_start + i. + fn initialize_memory(memory_write_start: u64, instance_example: &[u128]) -> MemoryBuilder { + let memory_config = MemoryConfig::default(); + let mut memory_builder = MemoryBuilder::new(memory_config.clone()); + for (i, &value) in instance_example.iter().enumerate() { + let memory_value = if value <= memory_config.small_max { + MemoryValue::Small(value) + } else { + let x = u128_to_4_limbs(value); + MemoryValue::F252([x[0], x[1], x[2], x[3], 0, 0, 0, 0]) + }; + memory_builder.set(memory_write_start + i as u64, memory_value); + } + memory_builder + } + + #[test] + fn test_fill_builtin_segment() { + let builtin_name = BuiltinName::bitwise; + let instance_example = [ + 123456789, + 4385067362534966725237889432551, + 50448645, + 4385067362534966725237911992050, + 4385067362534966725237962440695, + ]; + let mut builtin_segments = BuiltinSegments::default(); + let cells_per_instance = BuiltinSegments::builtin_memory_cells_per_instance(builtin_name); + assert_eq!(cells_per_instance, instance_example.len()); + let num_instances = 71; + let begin_addr = 23581; + let stop_ptr = begin_addr + cells_per_instance * num_instances; + builtin_segments.set_segment( + builtin_name, + Some(MemorySegmentAddresses { + begin_addr, + stop_ptr, + }), + ); + let memory_write_start = (stop_ptr - cells_per_instance) as u64; + let mut memory_builder = initialize_memory(memory_write_start, &instance_example); + + builtin_segments.fill_builtin_segment(&mut memory_builder, builtin_name); + + let &MemorySegmentAddresses { + begin_addr: new_begin_addr, + stop_ptr: new_stop_ptr, + } = builtin_segments.get_segment(builtin_name).as_ref().unwrap(); + assert_eq!(new_begin_addr, begin_addr); + let segment_length = new_stop_ptr - new_begin_addr; + assert_eq!(segment_length % cells_per_instance, 0); + let new_num_instances = segment_length / cells_per_instance; + assert_eq!(new_num_instances, 128); + + let memory = memory_builder.build(); + assert_eq!(memory.address_to_id.len(), new_stop_ptr); + + let mut instance_to_verify_start = stop_ptr as u32; + let last_instance_start = (stop_ptr - cells_per_instance) as u32; + for _ in num_instances..new_num_instances { + assert_identical_blocks( + &memory, + last_instance_start, + instance_to_verify_start, + cells_per_instance as u32, + ); + instance_to_verify_start += cells_per_instance as u32; + } + } } diff --git a/stwo_cairo_prover/crates/prover/src/input/memory.rs b/stwo_cairo_prover/crates/prover/src/input/memory.rs index 5d814b23..ad733f18 100644 --- a/stwo_cairo_prover/crates/prover/src/input/memory.rs +++ b/stwo_cairo_prover/crates/prover/src/input/memory.rs @@ -30,7 +30,7 @@ pub const P_MIN_2: [u32; 8] = [ 0x0800_0000, ]; -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct MemoryConfig { pub small_max: u128, } @@ -99,6 +99,7 @@ impl Memory { } } +// TODO(ohadn): derive or impl a default for MemoryBuilder. pub struct MemoryBuilder { memory: Memory, felt252_id_cache: HashMap<[u32; 8], usize>, @@ -143,6 +144,7 @@ impl MemoryBuilder { res } + // TODO(ohadn): settle on an address integer type, and use it consistently. pub fn set(&mut self, addr: u64, value: MemoryValue) { if addr as usize >= self.address_to_id.len() { self.address_to_id @@ -168,6 +170,19 @@ impl MemoryBuilder { }); self.address_to_id[addr as usize] = res; } + + /// Copies a block of memory from one location to another. + /// The values at addresses src_start_addr to src_start_addr + segment_length - 1 are copied to + /// the addresses dst_start_addr to dst_start_addr + segment_length - 1. + pub fn copy_block(&mut self, src_start_addr: u32, dst_start_addr: u32, segment_length: u32) { + for i in 0..segment_length { + self.set( + (dst_start_addr + i) as u64, + self.memory.get(src_start_addr + i), + ); + } + } + pub fn build(self) -> Memory { self.memory } diff --git a/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs b/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs index 25ad6379..bec62de7 100644 --- a/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs +++ b/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs @@ -6,6 +6,7 @@ use std::path::Path; use bytemuck::{bytes_of_mut, Pod, Zeroable}; use cairo_vm::air_public_input::{MemorySegmentAddresses, PublicInput}; use cairo_vm::stdlib::collections::HashMap; +use cairo_vm::types::builtin_name::BuiltinName; use cairo_vm::vm::trace::trace_entry::RelocatedTraceEntry; use json::PrivateInput; use thiserror::Error; @@ -88,12 +89,15 @@ pub fn adapt_to_stwo_input( ) -> Result { let (state_transitions, instruction_by_pc) = StateTransitions::from_iter(trace_iter, &mut memory, dev_mode); + let mut builtins_segments = BuiltinSegments::from_memory_segments(memory_segments); + // TODO (ohadn): fill in the memory segments of the rest of the builtins. + builtins_segments.fill_builtin_segment(&mut memory, BuiltinName::range_check); Ok(ProverInput { state_transitions, instruction_by_pc, memory: memory.build(), public_memory_addresses, - builtins_segments: BuiltinSegments::from_memory_segments(memory_segments), + builtins_segments, }) } @@ -187,113 +191,118 @@ pub mod tests { ) } - #[test] + #[cfg(test)] #[cfg(feature = "slow-tests")] - fn test_read_from_large_files() { - let input = large_cairo_input(); + pub mod slow_tests { - // Test opcode components. - let components = input.state_transitions.casm_states_by_opcode; - assert_eq!(components.generic_opcode.len(), 0); - assert_eq!(components.add_ap_opcode.len(), 0); - assert_eq!(components.add_ap_opcode_imm.len(), 36895); - assert_eq!(components.add_ap_opcode_op_1_base_fp.len(), 33); - assert_eq!(components.add_opcode_small_imm.len(), 84732); - assert_eq!(components.add_opcode.len(), 189425); - assert_eq!(components.add_opcode_small.len(), 36623); - assert_eq!(components.add_opcode_imm.len(), 22089); - assert_eq!(components.assert_eq_opcode.len(), 233432); - assert_eq!(components.assert_eq_opcode_double_deref.len(), 811061); - assert_eq!(components.assert_eq_opcode_imm.len(), 43184); - assert_eq!(components.call_opcode.len(), 0); - assert_eq!(components.call_opcode_rel.len(), 49439); - assert_eq!(components.call_opcode_op_1_base_fp.len(), 33); - assert_eq!(components.jnz_opcode_taken_dst_base_fp.len(), 11235); - assert_eq!(components.jnz_opcode.len(), 27032); - assert_eq!(components.jnz_opcode_taken.len(), 51060); - assert_eq!(components.jnz_opcode_dst_base_fp.len(), 5100); - assert_eq!(components.jump_opcode_rel_imm.len(), 31873865); - assert_eq!(components.jump_opcode_rel.len(), 500); - assert_eq!(components.jump_opcode_double_deref.len(), 32); - assert_eq!(components.jump_opcode.len(), 0); - assert_eq!(components.mul_opcode_small_imm.len(), 7234); - assert_eq!(components.mul_opcode_small.len(), 7203); - assert_eq!(components.mul_opcode.len(), 3943); - assert_eq!(components.mul_opcode_imm.len(), 10809); - assert_eq!(components.ret_opcode.len(), 49472); + use super::*; - // Test builtins. - let builtins_segments = input.builtins_segments; - assert_eq!(builtins_segments.add_mod, None); - assert_eq!(builtins_segments.bitwise, None); - assert_eq!(builtins_segments.ec_op, Some((16428600, 16428747).into())); - assert_eq!(builtins_segments.ecdsa, None); - assert_eq!(builtins_segments.keccak, None); - assert_eq!(builtins_segments.mul_mod, None); - assert_eq!(builtins_segments.pedersen, Some((1322552, 1337489).into())); - assert_eq!( - builtins_segments.poseidon, - Some((16920120, 17444532).into()) - ); - assert_eq!(builtins_segments.range_check_bits_96, None); - assert_eq!( - builtins_segments.range_check_bits_128, - Some((1715768, 1757348).into()) - ); - } + #[test] + fn test_read_from_large_files() { + let input = large_cairo_input(); - #[cfg(feature = "slow-tests")] - #[test] - fn test_read_from_small_files() { - let input = small_cairo_input(); + // Test opcode components. + let components = input.state_transitions.casm_states_by_opcode; + assert_eq!(components.generic_opcode.len(), 0); + assert_eq!(components.add_ap_opcode.len(), 0); + assert_eq!(components.add_ap_opcode_imm.len(), 36895); + assert_eq!(components.add_ap_opcode_op_1_base_fp.len(), 33); + assert_eq!(components.add_opcode_small_imm.len(), 84732); + assert_eq!(components.add_opcode.len(), 189425); + assert_eq!(components.add_opcode_small.len(), 36623); + assert_eq!(components.add_opcode_imm.len(), 22089); + assert_eq!(components.assert_eq_opcode.len(), 233432); + assert_eq!(components.assert_eq_opcode_double_deref.len(), 811061); + assert_eq!(components.assert_eq_opcode_imm.len(), 43184); + assert_eq!(components.call_opcode.len(), 0); + assert_eq!(components.call_opcode_rel.len(), 49439); + assert_eq!(components.call_opcode_op_1_base_fp.len(), 33); + assert_eq!(components.jnz_opcode_taken_dst_base_fp.len(), 11235); + assert_eq!(components.jnz_opcode.len(), 27032); + assert_eq!(components.jnz_opcode_taken.len(), 51060); + assert_eq!(components.jnz_opcode_dst_base_fp.len(), 5100); + assert_eq!(components.jump_opcode_rel_imm.len(), 31873865); + assert_eq!(components.jump_opcode_rel.len(), 500); + assert_eq!(components.jump_opcode_double_deref.len(), 32); + assert_eq!(components.jump_opcode.len(), 0); + assert_eq!(components.mul_opcode_small_imm.len(), 7234); + assert_eq!(components.mul_opcode_small.len(), 7203); + assert_eq!(components.mul_opcode.len(), 3943); + assert_eq!(components.mul_opcode_imm.len(), 10809); + assert_eq!(components.ret_opcode.len(), 49472); + + // Test builtins. + let builtins_segments = input.builtins_segments; + assert_eq!(builtins_segments.add_mod, None); + assert_eq!(builtins_segments.bitwise, None); + assert_eq!(builtins_segments.ec_op, Some((16428600, 16428747).into())); + assert_eq!(builtins_segments.ecdsa, None); + assert_eq!(builtins_segments.keccak, None); + assert_eq!(builtins_segments.mul_mod, None); + assert_eq!(builtins_segments.pedersen, Some((1322552, 1337489).into())); + assert_eq!( + builtins_segments.poseidon, + Some((16920120, 17444532).into()) + ); + assert_eq!(builtins_segments.range_check_bits_96, None); + assert_eq!( + builtins_segments.range_check_bits_128, + Some((1715768, 1781304).into()) + ); + } + + #[test] + fn test_read_from_small_files() { + let input = small_cairo_input(); - // Test opcode components. - let components = input.state_transitions.casm_states_by_opcode; - assert_eq!(components.generic_opcode.len(), 0); - assert_eq!(components.add_ap_opcode.len(), 0); - assert_eq!(components.add_ap_opcode_imm.len(), 2); - assert_eq!(components.add_ap_opcode_op_1_base_fp.len(), 1); - assert_eq!(components.add_opcode_small_imm.len(), 500); - assert_eq!(components.add_opcode.len(), 0); - assert_eq!(components.add_opcode_small.len(), 0); - assert_eq!(components.add_opcode_imm.len(), 450); - assert_eq!(components.assert_eq_opcode.len(), 55); - assert_eq!(components.assert_eq_opcode_double_deref.len(), 2100); - assert_eq!(components.assert_eq_opcode_imm.len(), 1952); - assert_eq!(components.call_opcode.len(), 0); - assert_eq!(components.call_opcode_rel.len(), 462); - assert_eq!(components.call_opcode_op_1_base_fp.len(), 0); - assert_eq!(components.jnz_opcode_taken_dst_base_fp.len(), 450); - assert_eq!(components.jnz_opcode.len(), 0); - assert_eq!(components.jnz_opcode_taken.len(), 0); - assert_eq!(components.jnz_opcode_dst_base_fp.len(), 11); - assert_eq!(components.jump_opcode_rel_imm.len(), 124626); - assert_eq!(components.jump_opcode_rel.len(), 0); - assert_eq!(components.jump_opcode_double_deref.len(), 0); - assert_eq!(components.jump_opcode.len(), 0); - assert_eq!(components.mul_opcode_small_imm.len(), 0); - assert_eq!(components.mul_opcode_small.len(), 0); - assert_eq!(components.mul_opcode.len(), 0); - assert_eq!(components.mul_opcode_imm.len(), 0); - assert_eq!(components.ret_opcode.len(), 462); + // Test opcode components. + let components = input.state_transitions.casm_states_by_opcode; + assert_eq!(components.generic_opcode.len(), 0); + assert_eq!(components.add_ap_opcode.len(), 0); + assert_eq!(components.add_ap_opcode_imm.len(), 2); + assert_eq!(components.add_ap_opcode_op_1_base_fp.len(), 1); + assert_eq!(components.add_opcode_small_imm.len(), 500); + assert_eq!(components.add_opcode.len(), 0); + assert_eq!(components.add_opcode_small.len(), 0); + assert_eq!(components.add_opcode_imm.len(), 450); + assert_eq!(components.assert_eq_opcode.len(), 55); + assert_eq!(components.assert_eq_opcode_double_deref.len(), 2100); + assert_eq!(components.assert_eq_opcode_imm.len(), 1952); + assert_eq!(components.call_opcode.len(), 0); + assert_eq!(components.call_opcode_rel.len(), 462); + assert_eq!(components.call_opcode_op_1_base_fp.len(), 0); + assert_eq!(components.jnz_opcode_taken_dst_base_fp.len(), 450); + assert_eq!(components.jnz_opcode.len(), 0); + assert_eq!(components.jnz_opcode_taken.len(), 0); + assert_eq!(components.jnz_opcode_dst_base_fp.len(), 11); + assert_eq!(components.jump_opcode_rel_imm.len(), 124626); + assert_eq!(components.jump_opcode_rel.len(), 0); + assert_eq!(components.jump_opcode_double_deref.len(), 0); + assert_eq!(components.jump_opcode.len(), 0); + assert_eq!(components.mul_opcode_small_imm.len(), 0); + assert_eq!(components.mul_opcode_small.len(), 0); + assert_eq!(components.mul_opcode.len(), 0); + assert_eq!(components.mul_opcode_imm.len(), 0); + assert_eq!(components.ret_opcode.len(), 462); - // Test builtins. - let builtins_segments = input.builtins_segments; - assert_eq!(builtins_segments.add_mod, None); - assert_eq!(builtins_segments.bitwise, Some((22512, 22762).into())); - assert_eq!(builtins_segments.ec_op, Some((63472, 63822).into())); - assert_eq!(builtins_segments.ecdsa, Some((22384, 22484).into())); - assert_eq!(builtins_segments.keccak, Some((64368, 65168).into())); - assert_eq!(builtins_segments.mul_mod, None); - assert_eq!(builtins_segments.pedersen, Some((4464, 4614).into())); - assert_eq!(builtins_segments.poseidon, Some((65392, 65692).into())); - assert_eq!( - builtins_segments.range_check_bits_96, - Some((68464, 68514).into()) - ); - assert_eq!( - builtins_segments.range_check_bits_128, - Some((6000, 6050).into()) - ); + // Test builtins. + let builtins_segments = input.builtins_segments; + assert_eq!(builtins_segments.add_mod, None); + assert_eq!(builtins_segments.bitwise, Some((22512, 22762).into())); + assert_eq!(builtins_segments.ec_op, Some((63472, 63822).into())); + assert_eq!(builtins_segments.ecdsa, Some((22384, 22484).into())); + assert_eq!(builtins_segments.keccak, Some((64368, 65168).into())); + assert_eq!(builtins_segments.mul_mod, None); + assert_eq!(builtins_segments.pedersen, Some((4464, 4614).into())); + assert_eq!(builtins_segments.poseidon, Some((65392, 65692).into())); + assert_eq!( + builtins_segments.range_check_bits_96, + Some((68464, 68514).into()) + ); + assert_eq!( + builtins_segments.range_check_bits_128, + Some((6000, 6064).into()) + ); + } } }