diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 8242d3ca..cc5cfd28 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -135,7 +135,7 @@ jobs: with: repository: 0xPolygonHermez/pil2-proofman-js token: ${{ secrets.ZISK_CI_TOKEN }} - ref: 0.0.6 + ref: 0.0.7 path: pil2-proofman-js - name: Install pil2-proofman-js dependencies @@ -147,7 +147,7 @@ jobs: with: repository: 0xPolygonHermez/pil2-compiler token: ${{ secrets.GITHUB_TOKEN }} - ref: develop + ref: feature/custom-cols path: pil2-compiler - name: Install pil2-compiler dependencies diff --git a/cli/assets/templates/pil_helpers_trace.rs.tt b/cli/assets/templates/pil_helpers_trace.rs.tt index bd42773d..60432982 100644 --- a/cli/assets/templates/pil_helpers_trace.rs.tt +++ b/cli/assets/templates/pil_helpers_trace.rs.tt @@ -6,4 +6,10 @@ pub use proofman_macros::trace; trace!({ air.name }Row, { air.name }Trace \{ {{ for column in air.columns }} { column.name }: { column.type },{{ endfor }} }); -{{ endfor }}{{ endfor }} \ No newline at end of file +{{ endfor }}{{ endfor }} + +{{ for air_group in air_groups }}{{ for air in air_group.airs }}{{ for custom_commit in air.custom_columns }} +trace!({ air.name }{custom_commit.name}Row, { air.name }{custom_commit.name}Trace \{ +{{ for column in custom_commit.custom_columns }} { column.name }: { column.type },{{ endfor }} +}); +{{ endfor }}{{ endfor }}{{ endfor }} \ No newline at end of file diff --git a/cli/src/commands/pil_helpers.rs b/cli/src/commands/pil_helpers.rs index c7921dc1..85f24ec6 100644 --- a/cli/src/commands/pil_helpers.rs +++ b/cli/src/commands/pil_helpers.rs @@ -50,8 +50,14 @@ struct AirCtx { name: String, num_rows: u32, columns: Vec, + custom_columns: Vec, } +#[derive(Debug, Serialize)] +struct CustomCommitsCtx { + name: String, + custom_columns: Vec, +} #[derive(Debug, Serialize)] struct ColumnCtx { name: String, @@ -111,6 +117,7 @@ impl PilHelpersCmd { name: air.name.as_ref().unwrap().clone(), num_rows: air.num_rows.unwrap(), columns: Vec::new(), + custom_columns: Vec::new(), }) .collect(), }); @@ -139,6 +146,16 @@ impl PilHelpersCmd { // Build columns data for traces for (airgroup_id, airgroup) in pilout.air_groups.iter().enumerate() { for (air_id, _) in airgroup.airs.iter().enumerate() { + let air = wcctxs[airgroup_id].airs.get_mut(air_id).unwrap(); + air.custom_columns = pilout.air_groups[airgroup_id].airs[air_id] + .custom_commits + .iter() + .map(|commit| CustomCommitsCtx { + name: commit.name.clone().unwrap().to_case(Case::Pascal), + custom_columns: Vec::new(), + }) + .collect(); + // Search symbols where airgroup_id == airgroup_id && air_id == air_id && type == WitnessCol pilout .symbols @@ -149,8 +166,8 @@ impl PilHelpersCmd { && symbol.air_id.is_some() && symbol.air_id.unwrap() == air_id as u32 && symbol.stage.is_some() - && symbol.stage.unwrap() == 1 - && symbol.r#type == SymbolType::WitnessCol as i32 + && ((symbol.r#type == SymbolType::WitnessCol as i32 && symbol.stage.unwrap() == 1) + || (symbol.r#type == SymbolType::CustomCol as i32 && symbol.stage.unwrap() == 0)) }) .for_each(|symbol| { let air = wcctxs[airgroup_id].airs.get_mut(air_id).unwrap(); @@ -165,7 +182,13 @@ impl PilHelpersCmd { .rev() .fold("F".to_string(), |acc, &length| format!("[{}; {}]", acc, length)) }; - air.columns.push(ColumnCtx { name: name.to_owned(), r#type }); + if symbol.r#type == SymbolType::WitnessCol as i32 { + air.columns.push(ColumnCtx { name: name.to_owned(), r#type }); + } else { + air.custom_columns[symbol.commit_id.unwrap() as usize] + .custom_columns + .push(ColumnCtx { name: name.to_owned(), r#type }); + } }); } } diff --git a/common/src/air_instance.rs b/common/src/air_instance.rs index d0b119d4..a3bc17af 100644 --- a/common/src/air_instance.rs +++ b/common/src/air_instance.rs @@ -3,6 +3,7 @@ use std::{collections::HashMap, os::raw::c_void, sync::Arc}; use p3_field::Field; use proofman_starks_lib_c::{ get_airval_id_by_name_c, get_n_airgroupvals_c, get_n_airvals_c, get_n_evals_c, get_airgroupval_id_by_name_c, + get_n_custom_commits_c, }; use crate::SetupCtx; @@ -18,6 +19,7 @@ pub struct StepsParams { pub xdivxsub: *mut c_void, pub p_const_pols: *mut c_void, pub p_const_tree: *mut c_void, + pub custom_commits: [*mut c_void; 10], } impl From<&StepsParams> for *mut c_void { @@ -37,12 +39,14 @@ pub struct AirInstance { pub idx: Option, pub global_idx: Option, pub buffer: Vec, + pub custom_commits: Vec>, pub airgroup_values: Vec, pub airvalues: Vec, pub evals: Vec, pub commits_calculated: HashMap, pub airgroupvalue_calculated: HashMap, pub airvalue_calculated: HashMap, + pub custom_commits_calculated: Vec>, } impl AirInstance { @@ -55,6 +59,15 @@ impl AirInstance { ) -> Self { let ps = setup_ctx.get_setup(airgroup_id, air_id); + let custom_commits_calculated = vec![HashMap::new(); get_n_custom_commits_c(ps.p_setup.p_stark_info) as usize]; + + let mut custom_commits = Vec::new(); + + let n_custom_commits = get_n_custom_commits_c(ps.p_setup.p_stark_info); + for _ in 0..n_custom_commits { + custom_commits.push(Vec::new()); + } + AirInstance { airgroup_id, air_id, @@ -63,12 +76,14 @@ impl AirInstance { idx: None, global_idx: None, buffer, + custom_commits, airgroup_values: vec![F::zero(); get_n_airgroupvals_c(ps.p_setup.p_stark_info) as usize * 3], airvalues: vec![F::zero(); get_n_airvals_c(ps.p_setup.p_stark_info) as usize * 3], evals: vec![F::zero(); get_n_evals_c(ps.p_setup.p_stark_info) as usize * 3], commits_calculated: HashMap::new(), airgroupvalue_calculated: HashMap::new(), airvalue_calculated: HashMap::new(), + custom_commits_calculated, } } @@ -76,6 +91,18 @@ impl AirInstance { self.buffer.as_ptr() as *mut u8 } + pub fn get_custom_commits_ptr(&self) -> [*mut c_void; 10] { + let mut ptrs = [std::ptr::null_mut(); 10]; + for (i, custom_commit) in self.custom_commits.iter().enumerate() { + ptrs[i] = custom_commit.as_ptr() as *mut c_void; + } + ptrs + } + + pub fn set_custom_commit_id_buffer(&mut self, buffer: Vec, commit_id: u64) { + self.custom_commits[commit_id as usize] = buffer; + } + pub fn set_airvalue(&mut self, setup_ctx: &SetupCtx, name: &str, value: F) { let ps = setup_ctx.get_setup(self.airgroup_id, self.air_id); @@ -142,6 +169,10 @@ impl AirInstance { self.commits_calculated.insert(id, true); } + pub fn set_custom_commit_calculated(&mut self, commit_id: usize, id: usize) { + self.custom_commits_calculated[commit_id].insert(id, true); + } + pub fn set_air_instance_id(&mut self, air_instance_id: usize, idx: usize) { self.air_instance_id = Some(air_instance_id); self.idx = Some(idx); diff --git a/common/src/buffer_allocator.rs b/common/src/buffer_allocator.rs index feb905ec..d914dda8 100644 --- a/common/src/buffer_allocator.rs +++ b/common/src/buffer_allocator.rs @@ -10,4 +10,12 @@ pub trait BufferAllocator: Send + Sync { airgroup_id: usize, air_id: usize, ) -> Result<(u64, Vec), Box>; + + fn get_buffer_info_custom_commit( + &self, + sctx: &SetupCtx, + airgroup_id: usize, + air_id: usize, + custom_commit_name: &str, + ) -> Result<(u64, Vec, u64), Box>; } diff --git a/common/src/global_info.rs b/common/src/global_info.rs index 2a9239b4..695d340d 100644 --- a/common/src/global_info.rs +++ b/common/src/global_info.rs @@ -12,6 +12,14 @@ pub struct ProofValueMap { #[serde(default)] pub id: u64, } +#[derive(Clone, Deserialize)] +pub struct PublicMap { + pub name: String, + #[serde(default)] + pub stage: u64, + #[serde(default)] + pub lengths: Vec, +} #[derive(Clone, Deserialize)] pub struct GlobalInfo { @@ -35,6 +43,9 @@ pub struct GlobalInfo { #[serde(rename = "proofValuesMap")] pub proof_values_map: Option>, + + #[serde(rename = "publicsMap")] + pub publics_map: Option>, } #[derive(Clone, Deserialize)] diff --git a/common/src/proof_ctx.rs b/common/src/proof_ctx.rs index e7405676..a434d7cc 100644 --- a/common/src/proof_ctx.rs +++ b/common/src/proof_ctx.rs @@ -8,11 +8,21 @@ use crate::{AirInstancesRepository, GlobalInfo, VerboseMode, WitnessPilout}; pub struct PublicInputs { pub inputs: RwLock>, + pub inputs_set: RwLock>, } impl Default for PublicInputs { fn default() -> Self { - Self { inputs: RwLock::new(Vec::new()) } + Self { inputs: RwLock::new(Vec::new()), inputs_set: RwLock::new(Vec::new()) } + } +} + +impl PublicInputs { + pub fn new(n_publics: usize) -> Self { + Self { + inputs: RwLock::new(vec![0; n_publics * std::mem::size_of::()]), + inputs_set: RwLock::new(vec![false; n_publics]), + } } } @@ -83,11 +93,11 @@ impl ProofCtx { values: RwLock::new(vec![F::zero(); global_info.n_proof_values * 3]), values_set: RwLock::new(HashMap::new()), }; - + let n_publics = global_info.n_publics; Self { pilout, global_info, - public_inputs: PublicInputs::default(), + public_inputs: PublicInputs::new(n_publics), proof_values, challenges: Challenges::default(), buff_helper: BuffHelper::default(), @@ -146,4 +156,39 @@ impl ProofCtx { pub fn set_proof_value_calculated(&self, id: usize) { self.proof_values.values_set.write().unwrap().insert(id, true); } + + pub fn set_public_value(&self, value: u64, public_id: u64) { + self.public_inputs.inputs.write().unwrap()[(public_id as usize) * 8..(public_id as usize + 1) * 8] + .copy_from_slice(&value.to_le_bytes()); + + self.public_inputs.inputs_set.write().unwrap()[public_id as usize] = true; + } + + pub fn set_public_value_by_name(&self, value: u64, public_name: &str) { + let n_publics: usize = self.global_info.publics_map.as_ref().expect("REASON").len(); + let public_id = (0..n_publics) + .find(|&i| { + let public = self.global_info.publics_map.as_ref().expect("REASON").get(i).unwrap(); + public.name == public_name + }) + .unwrap_or_else(|| panic!("Name {} not found in publics_map", public_name)); + + self.set_public_value(value, public_id as u64); + } + + pub fn get_public_value(&self, public_name: &str) -> u64 { + let n_publics: usize = self.global_info.publics_map.as_ref().expect("REASON").len(); + let public_id = (0..n_publics) + .find(|&i| { + let public = self.global_info.publics_map.as_ref().expect("REASON").get(i).unwrap(); + public.name == public_name + }) + .unwrap_or_else(|| panic!("Name {} not found in publics_map", public_name)); + + u64::from_le_bytes( + self.public_inputs.inputs.read().unwrap()[public_id * 8..(public_id + 1) * 8] + .try_into() + .expect("Expected 8 bytes for u64"), + ) + } } diff --git a/common/src/prover.rs b/common/src/prover.rs index 869aeb76..6e5c720a 100644 --- a/common/src/prover.rs +++ b/common/src/prover.rs @@ -63,6 +63,7 @@ pub trait Prover { fn num_stages(&self) -> u32; fn get_challenges(&self, stage_id: u32, proof_ctx: Arc>, transcript: &FFITranscript); fn calculate_stage(&mut self, stage_id: u32, setup_ctx: Arc>, proof_ctx: Arc>); + fn check_stage(&self, stage_id: u32, proof_ctx: Arc>); fn commit_stage(&mut self, stage_id: u32, proof_ctx: Arc>) -> ProverStatus; fn calculate_xdivxsub(&mut self, proof_ctx: Arc>); fn calculate_lev(&mut self, proof_ctx: Arc>); diff --git a/common/src/setup.rs b/common/src/setup.rs index 993dd717..63166821 100644 --- a/common/src/setup.rs +++ b/common/src/setup.rs @@ -30,13 +30,13 @@ impl From<&SetupC> for *mut c_void { } #[derive(Debug)] -pub struct ConstPols { - pub const_pols: RwLock>>, +pub struct Pols { + pub values: RwLock>>, } -impl Default for ConstPols { +impl Default for Pols { fn default() -> Self { - Self { const_pols: RwLock::new(Vec::new()) } + Self { values: RwLock::new(Vec::new()) } } } @@ -47,8 +47,8 @@ pub struct Setup { pub airgroup_id: usize, pub air_id: usize, pub p_setup: SetupC, - pub const_pols: ConstPols, - pub const_tree: ConstPols, + pub const_pols: Pols, + pub const_tree: Pols, } impl Setup { @@ -80,8 +80,8 @@ impl Setup { air_id, airgroup_id, p_setup: SetupC { p_stark_info, p_expressions_bin, p_prover_helpers }, - const_pols: ConstPols::default(), - const_tree: ConstPols::default(), + const_pols: Pols::default(), + const_tree: Pols::default(), } } @@ -109,7 +109,7 @@ impl Setup { let p_const_pols_address = const_pols.as_ptr() as *mut c_void; load_const_pols_c(p_const_pols_address, const_pols_path.as_str(), const_size as u64); - *self.const_pols.const_pols.write().unwrap() = const_pols; + *self.const_pols.values.write().unwrap() = const_pols; } pub fn load_const_pols_tree(&self, global_info: &GlobalInfo, setup_type: &ProofType, save_file: bool) { @@ -132,11 +132,11 @@ impl Setup { if PathBuf::from(&const_pols_tree_path).exists() { load_const_tree_c(p_const_tree_address, const_pols_tree_path.as_str(), const_tree_size as u64); } else { - let const_pols = self.const_pols.const_pols.read().unwrap(); + let const_pols = self.const_pols.values.read().unwrap(); let p_const_pols_address = (*const_pols).as_ptr() as *mut c_void; let tree_filename = if save_file { const_pols_tree_path.as_str() } else { "" }; calculate_const_tree_c(p_stark_info, p_const_pols_address, p_const_tree_address, tree_filename); }; - *self.const_tree.const_pols.write().unwrap() = const_tree; + *self.const_tree.values.write().unwrap() = const_tree; } } diff --git a/common/src/setup_ctx.rs b/common/src/setup_ctx.rs index 02410744..157f5a4e 100644 --- a/common/src/setup_ctx.rs +++ b/common/src/setup_ctx.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use std::ffi::c_void; use std::sync::Arc; +use log::info; use proofman_starks_lib_c::expressions_bin_new_c; use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; @@ -19,17 +20,44 @@ pub struct SetupsVadcop { impl SetupsVadcop { pub fn new(global_info: &GlobalInfo, aggregation: bool) -> Self { + info!("Initializing setups"); + timer_start_debug!(INITIALIZING_SETUP); + let sctx: SetupCtx = SetupCtx::new(global_info, &ProofType::Basic); + timer_stop_and_log_debug!(INITIALIZING_SETUP); if aggregation { + timer_start_debug!(INITIALIZING_SETUP_AGGREGATION); + info!("Initializing setups aggregation"); + + timer_start_debug!(INITIALIZING_SETUP_COMPRESSOR); + info!(" ··· Initializing setups compressor"); + let sctx_compressor: SetupCtx = SetupCtx::new(global_info, &ProofType::Compressor); + timer_stop_and_log_debug!(INITIALIZING_SETUP_COMPRESSOR); + + timer_start_debug!(INITIALIZING_SETUP_RECURSIVE1); + info!(" ··· Initializing setups recursive1"); + let sctx_recursive1: SetupCtx = SetupCtx::new(global_info, &ProofType::Recursive1); + timer_stop_and_log_debug!(INITIALIZING_SETUP_RECURSIVE1); + + timer_start_debug!(INITIALIZING_SETUP_RECURSIVE2); + info!(" ··· Initializing setups recursive2"); + let sctx_recursive2: SetupCtx = SetupCtx::new(global_info, &ProofType::Recursive2); + timer_stop_and_log_debug!(INITIALIZING_SETUP_RECURSIVE2); + + timer_start_debug!(INITIALIZING_SETUP_FINAL); + info!(" ··· Initializing setups final"); + let sctx_final: SetupCtx = SetupCtx::new(global_info, &ProofType::Final); + timer_stop_and_log_debug!(INITIALIZING_SETUP_FINAL); + timer_stop_and_log_debug!(INITIALIZING_SETUP_AGGREGATION); SetupsVadcop { - sctx: Arc::new(SetupCtx::new(global_info, &ProofType::Basic)), - sctx_compressor: Some(Arc::new(SetupCtx::new(global_info, &ProofType::Compressor))), - sctx_recursive1: Some(Arc::new(SetupCtx::new(global_info, &ProofType::Recursive1))), - sctx_recursive2: Some(Arc::new(SetupCtx::new(global_info, &ProofType::Recursive2))), - sctx_final: Some(Arc::new(SetupCtx::new(global_info, &ProofType::Final))), + sctx: Arc::new(sctx), + sctx_compressor: Some(Arc::new(sctx_compressor)), + sctx_recursive1: Some(Arc::new(sctx_recursive1)), + sctx_recursive2: Some(Arc::new(sctx_recursive2)), + sctx_final: Some(Arc::new(sctx_final)), } } else { SetupsVadcop { - sctx: Arc::new(SetupCtx::new(global_info, &ProofType::Basic)), + sctx: Arc::new(sctx), sctx_compressor: None, sctx_recursive1: None, sctx_recursive2: None, @@ -41,11 +69,6 @@ impl SetupsVadcop { #[derive(Debug)] pub struct SetupRepository { - // We store the setup in two stages: a partial setup in the first cell and a full setup in the second cell. - // This allows for loading only the partial setup when constant polynomials are not needed, improving performance. - // In C++, same SetupCtx structure is used to store either the partial or full setup for each instance. - // A full setup can be loaded in one or two steps: partial first, then full (which includes constant polynomial data). - // Since the setup is referenced immutably in the repository, we use OnceCell for both the partial and full setups. setups: HashMap<(usize, usize), Setup>, global_bin: Option<*mut c_void>, } diff --git a/examples/fibonacci-square/pil/fibonaccisq.pil b/examples/fibonacci-square/pil/fibonaccisq.pil index 73ffbb81..38ea9a5e 100644 --- a/examples/fibonacci-square/pil/fibonaccisq.pil +++ b/examples/fibonacci-square/pil/fibonaccisq.pil @@ -4,6 +4,7 @@ require "module.pil"; public in1; public in2; public out; +public rom_root[4]; proofval value1; proofval value2; @@ -17,6 +18,7 @@ private function checkProofValues() on final proof checkProofValues(); airtemplate FibonacciSquare(const int N = 2**8) { + commit stage(0) public(rom_root) rom; airval fibo1; airval fibo2; @@ -25,6 +27,9 @@ airtemplate FibonacciSquare(const int N = 2**8) { col fixed L1 = [1,0...]; col witness a,b; + col rom line; + col rom flags; + // Inputs/Outputs L1 * (a - in1) === 0; L1 * (b - in2) === 0; @@ -34,5 +39,7 @@ airtemplate FibonacciSquare(const int N = 2**8) { 2*fibo1 - fibo2 === 0; + line - (flags + 1) === 0; + permutation_assumes(MODULE_ID, [a*a + b*b, b'], 1 - L1'); } \ No newline at end of file diff --git a/examples/fibonacci-square/src/fibonacci.rs b/examples/fibonacci-square/src/fibonacci.rs index 15aefe6b..21a49579 100644 --- a/examples/fibonacci-square/src/fibonacci.rs +++ b/examples/fibonacci-square/src/fibonacci.rs @@ -5,7 +5,7 @@ use proofman::{WitnessManager, WitnessComponent}; use p3_field::PrimeField; -use crate::{FibonacciSquareTrace, FibonacciSquarePublics, Module, FIBONACCI_SQUARE_AIRGROUP_ID, FIBONACCI_SQUARE_AIR_IDS}; +use crate::{FibonacciSquareTrace, FibonacciSquareRomTrace, Module, FIBONACCI_SQUARE_AIRGROUP_ID, FIBONACCI_SQUARE_AIR_IDS}; pub struct FibonacciSquare { module: Arc>, @@ -41,8 +41,9 @@ impl FibonacciSquare { ) -> Result> { log::debug!("{} ··· Starting witness computation stage {}", Self::MY_NAME, 1); - let public_inputs: FibonacciSquarePublics = pctx.public_inputs.inputs.read().unwrap().as_slice().into(); - let (module, mut a, mut b, _out) = public_inputs.inner(); + let module = pctx.get_public_value("mod"); + let mut a = pctx.get_public_value("in1"); + let mut b = pctx.get_public_value("in2"); let (buffer_size, offsets) = ectx.buffer_allocator.as_ref().get_buffer_info( &sctx, @@ -67,7 +68,23 @@ impl FibonacciSquare { trace[i].b = F::from_canonical_u64(b); } - pctx.public_inputs.inputs.write().unwrap()[24..32].copy_from_slice(&b.to_le_bytes()); + let (buffer_size_rom, offsets_rom, commit_id) = ectx.buffer_allocator.as_ref().get_buffer_info_custom_commit( + &sctx, + FIBONACCI_SQUARE_AIRGROUP_ID, + FIBONACCI_SQUARE_AIR_IDS[0], + "rom", + )?; + + let mut buffer_rom = vec![F::zero(); buffer_size_rom as usize]; + + let mut trace_custom_commits = + FibonacciSquareRomTrace::map_buffer(&mut buffer_rom, num_rows, offsets_rom[0] as usize)?; + for i in 0..num_rows { + trace_custom_commits[i].line = F::from_canonical_u64(3 + i as u64); + trace_custom_commits[i].flags = F::from_canonical_u64(2 + i as u64); + } + + pctx.set_public_value_by_name(b, "out"); pctx.set_proof_value("value1", F::from_canonical_u64(5)); pctx.set_proof_value("value2", F::from_canonical_u64(125)); @@ -84,6 +101,7 @@ impl FibonacciSquare { air_instance.set_airvalue(&sctx, "FibonacciSquare.fibo1", F::from_canonical_u64(1)); air_instance.set_airvalue(&sctx, "FibonacciSquare.fibo2", F::from_canonical_u64(2)); air_instance.set_airvalue_ext(&sctx, "FibonacciSquare.fibo3", vec![F::from_canonical_u64(5); 3]); + air_instance.set_custom_commit_id_buffer(buffer_rom, commit_id); let (is_myne, gid) = ectx.dctx.write().unwrap().add_instance(FIBONACCI_SQUARE_AIRGROUP_ID, FIBONACCI_SQUARE_AIR_IDS[0], 1); diff --git a/examples/fibonacci-square/src/fibonacci_lib.rs b/examples/fibonacci-square/src/fibonacci_lib.rs index d0e7fdf1..80d9279d 100644 --- a/examples/fibonacci-square/src/fibonacci_lib.rs +++ b/examples/fibonacci-square/src/fibonacci_lib.rs @@ -57,8 +57,9 @@ impl WitnessLibrary for FibonacciWitness { FibonacciSquarePublics::default() }; - let pi: Vec = public_inputs.into(); - *pctx.public_inputs.inputs.write().unwrap() = pi; + pctx.set_public_value_by_name(public_inputs.module, "mod"); + pctx.set_public_value_by_name(public_inputs.a, "in1"); + pctx.set_public_value_by_name(public_inputs.b, "in2"); wcm.start_proof(pctx, ectx, sctx); } diff --git a/examples/fibonacci-square/src/module.rs b/examples/fibonacci-square/src/module.rs index 5c8118e5..e0f712c4 100644 --- a/examples/fibonacci-square/src/module.rs +++ b/examples/fibonacci-square/src/module.rs @@ -6,7 +6,7 @@ use pil_std_lib::Std; use p3_field::{AbstractField, PrimeField}; use num_bigint::BigInt; -use crate::{FibonacciSquarePublics, ModuleTrace, FIBONACCI_SQUARE_AIRGROUP_ID, MODULE_AIR_IDS}; +use crate::{ModuleTrace, FIBONACCI_SQUARE_AIRGROUP_ID, MODULE_AIR_IDS}; pub struct Module { inputs: Mutex>, @@ -44,8 +44,7 @@ impl Module fn calculate_trace(&self, pctx: Arc>, ectx: Arc>, sctx: Arc>) { log::debug!("{} ··· Starting witness computation stage {}", Self::MY_NAME, 1); - let pi: FibonacciSquarePublics = pctx.public_inputs.inputs.read().unwrap().as_slice().into(); - let module = pi.module; + let module = pctx.get_public_value("mod"); let (buffer_size, offsets) = ectx .buffer_allocator diff --git a/examples/fibonacci-square/src/pil_helpers/traces.rs b/examples/fibonacci-square/src/pil_helpers/traces.rs index b0dd9076..920d9f9f 100644 --- a/examples/fibonacci-square/src/pil_helpers/traces.rs +++ b/examples/fibonacci-square/src/pil_helpers/traces.rs @@ -14,3 +14,7 @@ trace!(ModuleRow, ModuleTrace { trace!(U8AirRow, U8AirTrace { mul: F, }); + +trace!(FibonacciSquareRomRow, FibonacciSquareRomTrace { + line: F, flags: F, +}); diff --git a/examples/fibonacci-square/src/public_inputs.rs b/examples/fibonacci-square/src/public_inputs.rs index 31ac0624..60d7b5c4 100644 --- a/examples/fibonacci-square/src/public_inputs.rs +++ b/examples/fibonacci-square/src/public_inputs.rs @@ -5,36 +5,4 @@ pub struct FibonacciSquarePublics { pub module: u64, pub a: u64, pub b: u64, - pub out: Option, -} - -impl FibonacciSquarePublics { - pub fn inner(&self) -> (u64, u64, u64, Option) { - (self.module, self.a, self.b, self.out) - } -} - -impl From<&[u8]> for FibonacciSquarePublics { - fn from(input_bytes: &[u8]) -> Self { - const U64_SIZE: usize = std::mem::size_of::(); - assert_eq!(input_bytes.len(), U64_SIZE * 4, "Input bytes length must be 4 * size_of::()"); - - FibonacciSquarePublics { - module: u64::from_le_bytes(input_bytes[0..U64_SIZE].try_into().unwrap()), - a: u64::from_le_bytes(input_bytes[U64_SIZE..2 * U64_SIZE].try_into().unwrap()), - b: u64::from_le_bytes(input_bytes[2 * U64_SIZE..3 * U64_SIZE].try_into().unwrap()), - out: Some(u64::from_le_bytes(input_bytes[3 * U64_SIZE..4 * U64_SIZE].try_into().unwrap())), - } - } -} - -impl From for Vec { - fn from(val: FibonacciSquarePublics) -> Self { - let mut bytes = Vec::with_capacity(4 * std::mem::size_of::()); - bytes.extend_from_slice(&val.module.to_le_bytes()); - bytes.extend_from_slice(&val.a.to_le_bytes()); - bytes.extend_from_slice(&val.b.to_le_bytes()); - bytes.extend_from_slice(&val.out.unwrap_or(0).to_le_bytes()); - bytes - } } diff --git a/hints/src/hints.rs b/hints/src/hints.rs index f078365d..5e2e2e5a 100644 --- a/hints/src/hints.rs +++ b/hints/src/hints.rs @@ -681,8 +681,8 @@ pub fn mul_hint_fields( let public_inputs_ptr = (*proof_ctx.public_inputs.inputs.read().unwrap()).as_ptr() as *mut c_void; let challenges_ptr = (*proof_ctx.challenges.challenges.read().unwrap()).as_ptr() as *mut c_void; - let const_pols_ptr = (*setup.const_pols.const_pols.read().unwrap()).as_ptr() as *mut c_void; - let const_tree_ptr = (*setup.const_tree.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let const_pols_ptr = (*setup.const_pols.values.read().unwrap()).as_ptr() as *mut c_void; + let const_tree_ptr = (*setup.const_tree.values.read().unwrap()).as_ptr() as *mut c_void; let steps_params = StepsParams { buffer: air_instance.get_buffer_ptr() as *mut c_void, @@ -694,6 +694,7 @@ pub fn mul_hint_fields( xdivxsub: std::ptr::null_mut(), p_const_pols: const_pols_ptr, p_const_tree: const_tree_ptr, + custom_commits: air_instance.get_custom_commits_ptr(), }; mul_hint_fields_c( @@ -722,8 +723,8 @@ pub fn acc_hint_field( let public_inputs_ptr = (*proof_ctx.public_inputs.inputs.read().unwrap()).as_ptr() as *mut c_void; let challenges_ptr = (*proof_ctx.challenges.challenges.read().unwrap()).as_ptr() as *mut c_void; - let const_pols_ptr = (*setup.const_pols.const_pols.read().unwrap()).as_ptr() as *mut c_void; - let const_tree_ptr = (*setup.const_tree.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let const_pols_ptr = (*setup.const_pols.values.read().unwrap()).as_ptr() as *mut c_void; + let const_tree_ptr = (*setup.const_tree.values.read().unwrap()).as_ptr() as *mut c_void; let steps_params = StepsParams { buffer: air_instance.get_buffer_ptr() as *mut c_void, @@ -735,6 +736,7 @@ pub fn acc_hint_field( xdivxsub: std::ptr::null_mut(), p_const_pols: const_pols_ptr, p_const_tree: const_tree_ptr, + custom_commits: air_instance.get_custom_commits_ptr(), }; let raw_ptr = acc_hint_field_c( @@ -771,8 +773,8 @@ pub fn acc_mul_hint_fields( let public_inputs_ptr = (*proof_ctx.public_inputs.inputs.read().unwrap()).as_ptr() as *mut c_void; let challenges_ptr = (*proof_ctx.challenges.challenges.read().unwrap()).as_ptr() as *mut c_void; - let const_pols_ptr = (*setup.const_pols.const_pols.read().unwrap()).as_ptr() as *mut c_void; - let const_tree_ptr = (*setup.const_tree.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let const_pols_ptr = (*setup.const_pols.values.read().unwrap()).as_ptr() as *mut c_void; + let const_tree_ptr = (*setup.const_tree.values.read().unwrap()).as_ptr() as *mut c_void; let steps_params = StepsParams { buffer: air_instance.get_buffer_ptr() as *mut c_void, @@ -784,6 +786,7 @@ pub fn acc_mul_hint_fields( xdivxsub: std::ptr::null_mut(), p_const_pols: const_pols_ptr, p_const_tree: const_tree_ptr, + custom_commits: air_instance.get_custom_commits_ptr(), }; let raw_ptr = acc_mul_hint_fields_c( @@ -818,8 +821,8 @@ pub fn get_hint_field( let public_inputs_ptr = (*proof_ctx.public_inputs.inputs.read().unwrap()).as_ptr() as *mut c_void; let challenges_ptr = (*proof_ctx.challenges.challenges.read().unwrap()).as_ptr() as *mut c_void; - let const_pols_ptr = (*setup.const_pols.const_pols.read().unwrap()).as_ptr() as *mut c_void; - let const_tree_ptr = (*setup.const_tree.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let const_pols_ptr = (*setup.const_pols.values.read().unwrap()).as_ptr() as *mut c_void; + let const_tree_ptr = (*setup.const_tree.values.read().unwrap()).as_ptr() as *mut c_void; let steps_params = StepsParams { buffer: air_instance.get_buffer_ptr() as *mut c_void, @@ -831,6 +834,7 @@ pub fn get_hint_field( xdivxsub: std::ptr::null_mut(), p_const_pols: const_pols_ptr, p_const_tree: const_tree_ptr, + custom_commits: air_instance.get_custom_commits_ptr(), }; let raw_ptr = get_hint_field_c( @@ -864,8 +868,8 @@ pub fn get_hint_field_a( let public_inputs_ptr = (*proof_ctx.public_inputs.inputs.read().unwrap()).as_ptr() as *mut c_void; let challenges_ptr = (*proof_ctx.challenges.challenges.read().unwrap()).as_ptr() as *mut c_void; - let const_pols_ptr = (*setup.const_pols.const_pols.read().unwrap()).as_ptr() as *mut c_void; - let const_tree_ptr = (*setup.const_tree.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let const_pols_ptr = (*setup.const_pols.values.read().unwrap()).as_ptr() as *mut c_void; + let const_tree_ptr = (*setup.const_tree.values.read().unwrap()).as_ptr() as *mut c_void; let steps_params = StepsParams { buffer: air_instance.get_buffer_ptr() as *mut c_void, @@ -877,6 +881,7 @@ pub fn get_hint_field_a( xdivxsub: std::ptr::null_mut(), p_const_pols: const_pols_ptr, p_const_tree: const_tree_ptr, + custom_commits: air_instance.get_custom_commits_ptr(), }; let raw_ptr = get_hint_field_c( @@ -916,8 +921,8 @@ pub fn get_hint_field_m( let public_inputs_ptr = (*proof_ctx.public_inputs.inputs.read().unwrap()).as_ptr() as *mut c_void; let challenges_ptr = (*proof_ctx.challenges.challenges.read().unwrap()).as_ptr() as *mut c_void; - let const_pols_ptr = (*setup.const_pols.const_pols.read().unwrap()).as_ptr() as *mut c_void; - let const_tree_ptr = (*setup.const_tree.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let const_pols_ptr = (*setup.const_pols.values.read().unwrap()).as_ptr() as *mut c_void; + let const_tree_ptr = (*setup.const_tree.values.read().unwrap()).as_ptr() as *mut c_void; let steps_params = StepsParams { buffer: air_instance.get_buffer_ptr() as *mut c_void, @@ -929,6 +934,7 @@ pub fn get_hint_field_m( xdivxsub: std::ptr::null_mut(), p_const_pols: const_pols_ptr, p_const_tree: const_tree_ptr, + custom_commits: air_instance.get_custom_commits_ptr(), }; let raw_ptr = get_hint_field_c( @@ -980,6 +986,7 @@ pub fn get_hint_field_constant( xdivxsub: std::ptr::null_mut(), p_const_pols: std::ptr::null_mut(), p_const_tree: std::ptr::null_mut(), + custom_commits: [std::ptr::null_mut(); 10], }; let raw_ptr = get_hint_field_c( @@ -1020,6 +1027,7 @@ pub fn get_hint_field_constant_a( xdivxsub: std::ptr::null_mut(), p_const_pols: std::ptr::null_mut(), p_const_tree: std::ptr::null_mut(), + custom_commits: [std::ptr::null_mut(); 10], }; let raw_ptr = get_hint_field_c( @@ -1066,6 +1074,7 @@ pub fn get_hint_field_constant_m( xdivxsub: std::ptr::null_mut(), p_const_pols: std::ptr::null_mut(), p_const_tree: std::ptr::null_mut(), + custom_commits: [std::ptr::null_mut(); 10], }; let raw_ptr = get_hint_field_c( @@ -1117,6 +1126,7 @@ pub fn set_hint_field( xdivxsub: std::ptr::null_mut(), p_const_pols: std::ptr::null_mut(), p_const_tree: std::ptr::null_mut(), + custom_commits: [std::ptr::null_mut(); 10], }; let setup = setup_ctx.get_setup(air_instance.airgroup_id, air_instance.air_id); @@ -1149,6 +1159,7 @@ pub fn set_hint_field_val( xdivxsub: std::ptr::null_mut(), p_const_pols: std::ptr::null_mut(), p_const_tree: std::ptr::null_mut(), + custom_commits: [std::ptr::null_mut(); 10], }; let setup = setup_ctx.get_setup(air_instance.airgroup_id, air_instance.air_id); @@ -1232,8 +1243,8 @@ pub fn print_by_name( let public_inputs_ptr = (*proof_ctx.public_inputs.inputs.read().unwrap()).as_ptr() as *mut c_void; let challenges_ptr = (*proof_ctx.challenges.challenges.read().unwrap()).as_ptr() as *mut c_void; - let const_pols_ptr = (*setup.const_pols.const_pols.read().unwrap()).as_ptr() as *mut c_void; - let const_tree_ptr = (*setup.const_tree.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let const_pols_ptr = (*setup.const_pols.values.read().unwrap()).as_ptr() as *mut c_void; + let const_tree_ptr = (*setup.const_tree.values.read().unwrap()).as_ptr() as *mut c_void; let steps_params = StepsParams { buffer: air_instance.get_buffer_ptr() as *mut c_void, @@ -1245,6 +1256,7 @@ pub fn print_by_name( xdivxsub: std::ptr::null_mut(), p_const_pols: const_pols_ptr, p_const_tree: const_tree_ptr, + custom_commits: air_instance.get_custom_commits_ptr(), }; let mut lengths_vec = lengths.unwrap_or_default(); diff --git a/pil2-stark/lib/include/starks_lib.h b/pil2-stark/lib/include/starks_lib.h index 8afdf5d6..746e1874 100644 --- a/pil2-stark/lib/include/starks_lib.h +++ b/pil2-stark/lib/include/starks_lib.h @@ -28,10 +28,13 @@ uint64_t get_stark_info_n(void *pStarkInfo); uint64_t get_stark_info_n_publics(void *pStarkInfo); uint64_t get_map_total_n(void *pStarkInfo); + uint64_t get_custom_commit_id(void *pStarkInfo, char* name); + uint64_t get_map_total_n_custom_commits(void *pStarkInfo, uint64_t commit_id); uint64_t get_map_offsets(void *pStarkInfo, char *stage, bool flag); uint64_t get_n_airvals(void *pStarkInfo); uint64_t get_n_airgroupvals(void *pStarkInfo); uint64_t get_n_evals(void *pStarkInfo); + uint64_t get_n_custom_commits(void *pStarkInfo); int64_t get_airvalue_id_by_name(void *pStarkInfo, char* airValueName); int64_t get_airgroupvalue_id_by_name(void *pStarkInfo, char* airValueName); void stark_info_free(void *pStarkInfo); @@ -68,6 +71,7 @@ void starks_free(void *pStarks); void treesGL_get_root(void *pStarks, uint64_t index, void *root); + void treesGL_set_root(void *pStarks, uint64_t index, void *pProof); void calculate_xdivxsub(void *pStarks, void* xiChallenge, void *xDivXSub); void *get_fri_pol(void *pStarkInfo, void *buffer); @@ -76,6 +80,8 @@ void calculate_quotient_polynomial(void *pStarks, void* stepsParams); void calculate_impols_expressions(void *pStarks, uint64_t step, void* stepsParams); + void extend_and_merkelize_custom_commit(void *pStarks, uint64_t commitId, uint64_t step, void *buffer, void *pProof, void *pBuffHelper); + void commit_stage(void *pStarks, uint32_t elementType, uint64_t step, void *buffer, void *pProof, void *pBuffHelper); void compute_lev(void *pStarks, void *xiChallenge, void* LEv); diff --git a/pil2-stark/src/api/starks_api.cpp b/pil2-stark/src/api/starks_api.cpp index e29fcfa2..2e47f012 100644 --- a/pil2-stark/src/api/starks_api.cpp +++ b/pil2-stark/src/api/starks_api.cpp @@ -160,6 +160,27 @@ uint64_t get_map_total_n(void *pStarkInfo) return ((StarkInfo *)pStarkInfo)->mapTotalN; } +uint64_t get_custom_commit_id(void *pStarkInfo, char* name) { + auto starkInfo = *(StarkInfo *)pStarkInfo; + + auto commitId = std::find_if(starkInfo.customCommits.begin(), starkInfo.customCommits.end(), [name](const CustomCommits& customCommit) { + return customCommit.name == string(name); + }); + + if(commitId == starkInfo.customCommits.end()) { + zklog.error("Custom commit " + string(name) + " not found in custom commits."); + exitProcess(); + exit(-1); + } + + return std::distance(starkInfo.customCommits.begin(), commitId); +}; + +uint64_t get_map_total_n_custom_commits(void *pStarkInfo, uint64_t commit_id) { + auto starkInfo = *(StarkInfo *)pStarkInfo; + return starkInfo.mapTotalNcustomCommits[starkInfo.customCommits[commit_id].name]; +} + uint64_t get_n_airvals(void *pStarkInfo) { return ((StarkInfo *)pStarkInfo)->airValuesMap.size(); } @@ -172,6 +193,11 @@ uint64_t get_n_evals(void *pStarkInfo) { return ((StarkInfo *)pStarkInfo)->evMap.size(); } +uint64_t get_n_custom_commits(void *pStarkInfo) { + auto starkInfo = *(StarkInfo *)pStarkInfo; + return starkInfo.customCommitsMap.size(); +} + int64_t get_airvalue_id_by_name(void *pStarkInfo, char* airValueName) { auto starkInfo = *(StarkInfo *)pStarkInfo; for(uint64_t i = 0; i < starkInfo.airValuesMap.size(); ++i) { @@ -312,6 +338,14 @@ void treesGL_get_root(void *pStarks, uint64_t index, void *dst) starks->ffi_treesGL_get_root(index, (Goldilocks::Element *)dst); } +void treesGL_set_root(void *pStarks, uint64_t index, void *pProof) +{ + Starks *starks = (Starks *)pStarks; + + starks->ffi_treesGL_set_root(index, *(FRIProof *)pProof); +} + + void calculate_fri_polynomial(void *pStarks, void* stepsParams) { Starks *starks = (Starks *)pStarks; @@ -331,6 +365,12 @@ void calculate_impols_expressions(void *pStarks, uint64_t step, void* stepsParam starks->calculateImPolsExpressions(step, *(StepsParams *)stepsParams); } +void extend_and_merkelize_custom_commit(void *pStarks, uint64_t commitId, uint64_t step, void *buffer, void *pProof, void *pBuffHelper) +{ + Starks *starks = (Starks *)pStarks; + starks->extendAndMerkelizeCustomCommit(commitId, step, (Goldilocks::Element *)buffer, *(FRIProof *)pProof, (Goldilocks::Element *)pBuffHelper); +} + void commit_stage(void *pStarks, uint32_t elementType, uint64_t step, void *buffer, void *pProof, void *pBuffHelper) { // type == 1 => Goldilocks // type == 2 => BN128 diff --git a/pil2-stark/src/api/starks_api.hpp b/pil2-stark/src/api/starks_api.hpp index 8afdf5d6..746e1874 100644 --- a/pil2-stark/src/api/starks_api.hpp +++ b/pil2-stark/src/api/starks_api.hpp @@ -28,10 +28,13 @@ uint64_t get_stark_info_n(void *pStarkInfo); uint64_t get_stark_info_n_publics(void *pStarkInfo); uint64_t get_map_total_n(void *pStarkInfo); + uint64_t get_custom_commit_id(void *pStarkInfo, char* name); + uint64_t get_map_total_n_custom_commits(void *pStarkInfo, uint64_t commit_id); uint64_t get_map_offsets(void *pStarkInfo, char *stage, bool flag); uint64_t get_n_airvals(void *pStarkInfo); uint64_t get_n_airgroupvals(void *pStarkInfo); uint64_t get_n_evals(void *pStarkInfo); + uint64_t get_n_custom_commits(void *pStarkInfo); int64_t get_airvalue_id_by_name(void *pStarkInfo, char* airValueName); int64_t get_airgroupvalue_id_by_name(void *pStarkInfo, char* airValueName); void stark_info_free(void *pStarkInfo); @@ -68,6 +71,7 @@ void starks_free(void *pStarks); void treesGL_get_root(void *pStarks, uint64_t index, void *root); + void treesGL_set_root(void *pStarks, uint64_t index, void *pProof); void calculate_xdivxsub(void *pStarks, void* xiChallenge, void *xDivXSub); void *get_fri_pol(void *pStarkInfo, void *buffer); @@ -76,6 +80,8 @@ void calculate_quotient_polynomial(void *pStarks, void* stepsParams); void calculate_impols_expressions(void *pStarks, uint64_t step, void* stepsParams); + void extend_and_merkelize_custom_commit(void *pStarks, uint64_t commitId, uint64_t step, void *buffer, void *pProof, void *pBuffHelper); + void commit_stage(void *pStarks, uint32_t elementType, uint64_t step, void *buffer, void *pProof, void *pBuffHelper); void compute_lev(void *pStarks, void *xiChallenge, void* LEv); diff --git a/pil2-stark/src/starkpil/expressions_avx.hpp b/pil2-stark/src/starkpil/expressions_avx.hpp index 62190c3b..b0130cc0 100644 --- a/pil2-stark/src/starkpil/expressions_avx.hpp +++ b/pil2-stark/src/starkpil/expressions_avx.hpp @@ -15,33 +15,47 @@ class ExpressionsAvx : public ExpressionsCtx { void setBufferTInfo(bool domainExtended, int64_t expId) { uint64_t nOpenings = setupCtx.starkInfo.openingPoints.size(); - offsetsStages.resize((setupCtx.starkInfo.nStages + 2)*nOpenings + 1); - nColsStages.resize((setupCtx.starkInfo.nStages + 2)*nOpenings + 1); - nColsStagesAcc.resize((setupCtx.starkInfo.nStages + 2)*nOpenings + 1); + uint64_t ns = 2 + setupCtx.starkInfo.nStages + setupCtx.starkInfo.customCommits.size(); + offsetsStages.resize(ns*nOpenings + 1); + nColsStages.resize(ns*nOpenings + 1); + nColsStagesAcc.resize(ns*nOpenings + 1); nCols = setupCtx.starkInfo.nConstants; - uint64_t ns = setupCtx.starkInfo.nStages + 2; + for(uint64_t o = 0; o < nOpenings; ++o) { - for(uint64_t stage = 0; stage <= ns; ++stage) { - std::string section = stage == 0 ? "const" : "cm" + to_string(stage); - offsetsStages[(setupCtx.starkInfo.nStages + 2)*o + stage] = setupCtx.starkInfo.mapOffsets[std::make_pair(section, domainExtended)]; - nColsStages[(setupCtx.starkInfo.nStages + 2)*o + stage] = setupCtx.starkInfo.mapSectionsN[section]; - nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*o + stage] = stage == 0 && o == 0 ? 0 : nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*o + stage - 1] + nColsStages[stage - 1]; + for(uint64_t stage = 0; stage < ns; ++stage) { + if(stage == 0) { + offsetsStages[ns*o] = 0; + nColsStages[ns*o] = setupCtx.starkInfo.mapSectionsN["const"]; + nColsStagesAcc[ns*o] = o == 0 ? 0 : nColsStagesAcc[ns*o + stage - 1] + nColsStages[stage - 1]; + } else if(stage < 2 + setupCtx.starkInfo.nStages) { + std::string section = "cm" + to_string(stage); + offsetsStages[ns*o + stage] = setupCtx.starkInfo.mapOffsets[std::make_pair(section, domainExtended)]; + nColsStages[ns*o + stage] = setupCtx.starkInfo.mapSectionsN[section]; + nColsStagesAcc[ns*o + stage] = nColsStagesAcc[ns*o + stage - 1] + nColsStages[stage - 1]; + } else { + uint64_t index = stage - setupCtx.starkInfo.nStages - 2; + std::string section = setupCtx.starkInfo.customCommits[index].name + "0"; + offsetsStages[ns*o + stage] = setupCtx.starkInfo.mapOffsets[std::make_pair(section, domainExtended)]; + nColsStages[ns*o + stage] = setupCtx.starkInfo.mapSectionsN[section]; + nColsStagesAcc[ns*o + stage] = nColsStagesAcc[ns*o + stage - 1] + nColsStages[stage - 1]; + } } } - nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings] = nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings - 1] + nColsStages[(setupCtx.starkInfo.nStages + 2)*nOpenings - 1]; + nColsStagesAcc[ns*nOpenings] = nColsStagesAcc[ns*nOpenings - 1] + nColsStages[ns*nOpenings - 1]; if(expId == int64_t(setupCtx.starkInfo.cExpId)) { - nCols = nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings] + setupCtx.starkInfo.boundaries.size() + 1; + nCols = nColsStagesAcc[ns*nOpenings] + setupCtx.starkInfo.boundaries.size() + 1; } else if(expId == int64_t(setupCtx.starkInfo.friExpId)) { - nCols = nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings] + nOpenings*FIELD_EXTENSION; + nCols = nColsStagesAcc[ns*nOpenings] + nOpenings*FIELD_EXTENSION; } else { - nCols = nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings] + 1; + nCols = nColsStagesAcc[ns*nOpenings] + 1; } } inline void loadPolynomials(StepsParams& params, ParserArgs &parserArgs, std::vector &dests, __m256i *bufferT_, uint64_t row, uint64_t domainSize) { uint64_t nOpenings = setupCtx.starkInfo.openingPoints.size(); + uint64_t ns = 2 + setupCtx.starkInfo.nStages + setupCtx.starkInfo.customCommits.size(); bool domainExtended = domainSize == uint64_t(1 << setupCtx.starkInfo.starkStruct.nBitsExt) ? true : false; uint64_t extendBits = (setupCtx.starkInfo.starkStruct.nBitsExt - setupCtx.starkInfo.starkStruct.nBits); @@ -56,6 +70,10 @@ class ExpressionsAvx : public ExpressionsCtx { std::vector constPolsUsed(setupCtx.starkInfo.constPolsMap.size(), false); std::vector cmPolsUsed(setupCtx.starkInfo.cmPolsMap.size(), false); + std::vector> customCommitsUsed(setupCtx.starkInfo.customCommits.size()); + for(uint64_t i = 0; i < setupCtx.starkInfo.customCommits.size(); ++i) { + customCommitsUsed[i] = std::vector(setupCtx.starkInfo.customCommits[i].stageWidths[0], false); + } for(uint64_t i = 0; i < dests.size(); ++i) { for(uint64_t j = 0; j < dests[i].params.size(); ++j) { @@ -74,6 +92,13 @@ class ExpressionsAvx : public ExpressionsCtx { for(uint64_t k = 0; k < dests[i].params[j].parserParams.nCmPolsUsed; ++k) { cmPolsUsed[cmUsed[k]] = true; } + + for(uint64_t k = 0; k < setupCtx.starkInfo.customCommits.size(); ++k) { + uint16_t* customCmUsed = &parserArgs.customCommitsPolsIds[dests[i].params[j].parserParams.customCommitsOffset[k]]; + for(uint64_t l = 0; l < dests[i].params[j].parserParams.nCustomCommitsPolsUsed[k]; ++l) { + customCommitsUsed[k][customCmUsed[l]] = true; + } + } } } } @@ -86,7 +111,7 @@ class ExpressionsAvx : public ExpressionsCtx { uint64_t l = (row + j + nextStrides[o]) % domainSize; bufferT[nrowsPack*o + j] = constPols[l * nColsStages[0] + k]; } - Goldilocks::load_avx(bufferT_[nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*o] + k], &bufferT[nrowsPack*o]); + Goldilocks::load_avx(bufferT_[nColsStagesAcc[ns*o] + k], &bufferT[nrowsPack*o]); } } @@ -101,7 +126,25 @@ class ExpressionsAvx : public ExpressionsCtx { uint64_t l = (row + j + nextStrides[o]) % domainSize; bufferT[nrowsPack*o + j] = params.pols[offsetsStages[stage] + l * nColsStages[stage] + stagePos + d]; } - Goldilocks::load_avx(bufferT_[nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*o + stage] + (stagePos + d)], &bufferT[nrowsPack*o]); + Goldilocks::load_avx(bufferT_[nColsStagesAcc[ns*o + stage] + (stagePos + d)], &bufferT[nrowsPack*o]); + } + } + } + + for(uint64_t i = 0; i < setupCtx.starkInfo.customCommits.size(); ++i) { + for(uint64_t j = 0; j < setupCtx.starkInfo.customCommits[i].stageWidths[0]; ++j) { + if(!customCommitsUsed[i][j]) continue; + PolMap polInfo = setupCtx.starkInfo.customCommitsMap[i][j]; + uint64_t stage = setupCtx.starkInfo.nStages + 2 + i; + uint64_t stagePos = polInfo.stagePos; + for(uint64_t d = 0; d < polInfo.dim; ++d) { + for(uint64_t o = 0; o < nOpenings; ++o) { + for(uint64_t j = 0; j < nrowsPack; ++j) { + uint64_t l = (row + j + nextStrides[o]) % domainSize; + bufferT[nrowsPack*o + j] = params.customCommits[i][offsetsStages[stage] + l * nColsStages[stage] + stagePos + d]; + } + Goldilocks::load_avx(bufferT_[nColsStagesAcc[ns*o + stage] + (stagePos + d)], &bufferT[nrowsPack*o]); + } } } } @@ -110,12 +153,12 @@ class ExpressionsAvx : public ExpressionsCtx { for(uint64_t j = 0; j < nrowsPack; ++j) { bufferT[j] = setupCtx.proverHelpers.x_2ns[row + j]; } - Goldilocks::load_avx(bufferT_[nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings]], &bufferT[0]); + Goldilocks::load_avx(bufferT_[nColsStagesAcc[ns*nOpenings]], &bufferT[0]); for(uint64_t d = 0; d < setupCtx.starkInfo.boundaries.size(); ++d) { for(uint64_t j = 0; j < nrowsPack; ++j) { bufferT[j] = setupCtx.proverHelpers.zi[row + j + d*domainSize]; } - Goldilocks::load_avx(bufferT_[nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings] + 1 + d], &bufferT[0]); + Goldilocks::load_avx(bufferT_[nColsStagesAcc[ns*nOpenings] + 1 + d], &bufferT[0]); } } else if(dests[0].params[0].parserParams.expId == int64_t(setupCtx.starkInfo.friExpId)) { for(uint64_t d = 0; d < setupCtx.starkInfo.openingPoints.size(); ++d) { @@ -123,14 +166,14 @@ class ExpressionsAvx : public ExpressionsCtx { for(uint64_t j = 0; j < nrowsPack; ++j) { bufferT[j] = params.xDivXSub[(row + j + d*domainSize)*FIELD_EXTENSION + k]; } - Goldilocks::load_avx(bufferT_[nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings] + d*FIELD_EXTENSION + k], &bufferT[0]); + Goldilocks::load_avx(bufferT_[nColsStagesAcc[ns*nOpenings] + d*FIELD_EXTENSION + k], &bufferT[0]); } } } else { for(uint64_t j = 0; j < nrowsPack; ++j) { bufferT[j] = setupCtx.proverHelpers.x_n[row + j]; } - Goldilocks::load_avx(bufferT_[nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings]], &bufferT[0]); + Goldilocks::load_avx(bufferT_[nColsStagesAcc[ns*nOpenings]], &bufferT[0]); } } @@ -245,6 +288,7 @@ class ExpressionsAvx : public ExpressionsCtx { void calculateExpressions(StepsParams& params, ParserArgs &parserArgs, std::vector dests, uint64_t domainSize) override { uint64_t nOpenings = setupCtx.starkInfo.openingPoints.size(); + uint64_t ns = 2 + setupCtx.starkInfo.nStages + setupCtx.starkInfo.customCommits.size(); bool domainExtended = domainSize == uint64_t(1 << setupCtx.starkInfo.starkStruct.nBitsExt) ? true : false; uint64_t expId = dests[0].params[0].op == opType::tmp ? dests[0].params[0].parserParams.destDim : 0; @@ -306,7 +350,7 @@ class ExpressionsAvx : public ExpressionsCtx { auto openingPointZero = std::find_if(setupCtx.starkInfo.openingPoints.begin(), setupCtx.starkInfo.openingPoints.end(), [](int p) { return p == 0; }); auto openingPointZeroIndex = std::distance(setupCtx.starkInfo.openingPoints.begin(), openingPointZero); - uint64_t buffPos = (setupCtx.starkInfo.nStages + 2)*openingPointZeroIndex + dests[j].params[k].stage; + uint64_t buffPos = ns*openingPointZeroIndex + dests[j].params[k].stage; uint64_t stagePos = dests[j].params[k].stagePos; copyPolynomial(&destVals[j][k*FIELD_EXTENSION], dests[j].params[k].inverse, dests[j].params[k].dim, &bufferT_[nColsStagesAcc[buffPos] + stagePos]); continue; diff --git a/pil2-stark/src/starkpil/expressions_avx512.hpp b/pil2-stark/src/starkpil/expressions_avx512.hpp index 0e1a15c3..e68c5063 100644 --- a/pil2-stark/src/starkpil/expressions_avx512.hpp +++ b/pil2-stark/src/starkpil/expressions_avx512.hpp @@ -15,33 +15,47 @@ class ExpressionsAvx512 : public ExpressionsCtx { void setBufferTInfo(bool domainExtended, int64_t expId) { uint64_t nOpenings = setupCtx.starkInfo.openingPoints.size(); - offsetsStages.resize((setupCtx.starkInfo.nStages + 2)*nOpenings + 1); - nColsStages.resize((setupCtx.starkInfo.nStages + 2)*nOpenings + 1); - nColsStagesAcc.resize((setupCtx.starkInfo.nStages + 2)*nOpenings + 1); + uint64_t ns = 2 + setupCtx.starkInfo.nStages + setupCtx.starkInfo.customCommits.size(); + offsetsStages.resize(ns*nOpenings + 1); + nColsStages.resize(ns*nOpenings + 1); + nColsStagesAcc.resize(ns*nOpenings + 1); nCols = setupCtx.starkInfo.nConstants; - uint64_t ns = setupCtx.starkInfo.nStages + 2; + for(uint64_t o = 0; o < nOpenings; ++o) { - for(uint64_t stage = 0; stage <= ns; ++stage) { - std::string section = stage == 0 ? "const" : "cm" + to_string(stage); - offsetsStages[(setupCtx.starkInfo.nStages + 2)*o + stage] = setupCtx.starkInfo.mapOffsets[std::make_pair(section, domainExtended)]; - nColsStages[(setupCtx.starkInfo.nStages + 2)*o + stage] = setupCtx.starkInfo.mapSectionsN[section]; - nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*o + stage] = stage == 0 && o == 0 ? 0 : nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*o + stage - 1] + nColsStages[stage - 1]; + for(uint64_t stage = 0; stage < ns; ++stage) { + if(stage == 0) { + offsetsStages[ns*o] = 0; + nColsStages[ns*o] = setupCtx.starkInfo.mapSectionsN["const"]; + nColsStagesAcc[ns*o] = o == 0 ? 0 : nColsStagesAcc[ns*o + stage - 1] + nColsStages[stage - 1]; + } else if(stage < 2 + setupCtx.starkInfo.nStages) { + std::string section = "cm" + to_string(stage); + offsetsStages[ns*o + stage] = setupCtx.starkInfo.mapOffsets[std::make_pair(section, domainExtended)]; + nColsStages[ns*o + stage] = setupCtx.starkInfo.mapSectionsN[section]; + nColsStagesAcc[ns*o + stage] = nColsStagesAcc[ns*o + stage - 1] + nColsStages[stage - 1]; + } else { + uint64_t index = stage - setupCtx.starkInfo.nStages - 2; + std::string section = setupCtx.starkInfo.customCommits[index].name + "0"; + offsetsStages[ns*o + stage] = setupCtx.starkInfo.mapOffsets[std::make_pair(section, domainExtended)]; + nColsStages[ns*o + stage] = setupCtx.starkInfo.mapSectionsN[section]; + nColsStagesAcc[ns*o + stage] = nColsStagesAcc[ns*o + stage - 1] + nColsStages[stage - 1]; + } } } - nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings] = nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings - 1] + nColsStages[(setupCtx.starkInfo.nStages + 2)*nOpenings - 1]; + nColsStagesAcc[ns*nOpenings] = nColsStagesAcc[ns*nOpenings - 1] + nColsStages[ns*nOpenings - 1]; if(expId == int64_t(setupCtx.starkInfo.cExpId)) { - nCols = nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings] + setupCtx.starkInfo.boundaries.size() + 1; + nCols = nColsStagesAcc[ns*nOpenings] + setupCtx.starkInfo.boundaries.size() + 1; } else if(expId == int64_t(setupCtx.starkInfo.friExpId)) { - nCols = nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings] + nOpenings*FIELD_EXTENSION; + nCols = nColsStagesAcc[ns*nOpenings] + nOpenings*FIELD_EXTENSION; } else { - nCols = nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings] + 1; + nCols = nColsStagesAcc[ns*nOpenings] + 1; } } inline void loadPolynomials(StepsParams& params, ParserArgs &parserArgs, std::vector &dests, __m512i *bufferT_, uint64_t row, uint64_t domainSize) { uint64_t nOpenings = setupCtx.starkInfo.openingPoints.size(); + uint64_t ns = 2 + setupCtx.starkInfo.nStages + setupCtx.starkInfo.customCommits.size(); bool domainExtended = domainSize == uint64_t(1 << setupCtx.starkInfo.starkStruct.nBitsExt) ? true : false; uint64_t extendBits = (setupCtx.starkInfo.starkStruct.nBitsExt - setupCtx.starkInfo.starkStruct.nBits); @@ -56,6 +70,10 @@ class ExpressionsAvx512 : public ExpressionsCtx { std::vector constPolsUsed(setupCtx.starkInfo.constPolsMap.size(), false); std::vector cmPolsUsed(setupCtx.starkInfo.cmPolsMap.size(), false); + std::vector> customCommitsUsed(setupCtx.starkInfo.customCommits.size()); + for(uint64_t i = 0; i < setupCtx.starkInfo.customCommits.size(); ++i) { + customCommitsUsed[i] = std::vector(setupCtx.starkInfo.customCommits[i].stageWidths[0], false); + } for(uint64_t i = 0; i < dests.size(); ++i) { for(uint64_t j = 0; j < dests[i].params.size(); ++j) { @@ -74,6 +92,13 @@ class ExpressionsAvx512 : public ExpressionsCtx { for(uint64_t k = 0; k < dests[i].params[j].parserParams.nCmPolsUsed; ++k) { cmPolsUsed[cmUsed[k]] = true; } + + for(uint64_t k = 0; k < setupCtx.starkInfo.customCommits.size(); ++k) { + uint16_t* customCmUsed = &parserArgs.customCommitsPolsIds[dests[i].params[j].parserParams.customCommitsOffset[k]]; + for(uint64_t l = 0; l < dests[i].params[j].parserParams.nCustomCommitsPolsUsed[k]; ++l) { + customCommitsUsed[k][customCmUsed[l]] = true; + } + } } } } @@ -86,7 +111,7 @@ class ExpressionsAvx512 : public ExpressionsCtx { uint64_t l = (row + j + nextStrides[o]) % domainSize; bufferT[nrowsPack*o + j] = constPols[l * nColsStages[0] + k]; } - Goldilocks::load_avx512(bufferT_[nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*o] + k], &bufferT[nrowsPack*o]); + Goldilocks::load_avx512(bufferT_[nColsStagesAcc[ns*o] + k], &bufferT[nrowsPack*o]); } } @@ -101,7 +126,25 @@ class ExpressionsAvx512 : public ExpressionsCtx { uint64_t l = (row + j + nextStrides[o]) % domainSize; bufferT[nrowsPack*o + j] = params.pols[offsetsStages[stage] + l * nColsStages[stage] + stagePos + d]; } - Goldilocks::load_avx512(bufferT_[nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*o + stage] + (stagePos + d)], &bufferT[nrowsPack*o]); + Goldilocks::load_avx512(bufferT_[nColsStagesAcc[ns*o + stage] + (stagePos + d)], &bufferT[nrowsPack*o]); + } + } + } + + for(uint64_t i = 0; i < setupCtx.starkInfo.customCommits.size(); ++i) { + for(uint64_t j = 0; j < setupCtx.starkInfo.customCommits[i].stageWidths[0]; ++j) { + if(!customCommitsUsed[i][j]) continue; + PolMap polInfo = setupCtx.starkInfo.customCommitsMap[i][j]; + uint64_t stage = setupCtx.starkInfo.nStages + 2 + i; + uint64_t stagePos = polInfo.stagePos; + for(uint64_t d = 0; d < polInfo.dim; ++d) { + for(uint64_t o = 0; o < nOpenings; ++o) { + for(uint64_t j = 0; j < nrowsPack; ++j) { + uint64_t l = (row + j + nextStrides[o]) % domainSize; + bufferT[nrowsPack*o + j] = params.customCommits[i][offsetsStages[stage] + l * nColsStages[stage] + stagePos + d]; + } + Goldilocks::load_avx(bufferT_[nColsStagesAcc[ns*o + stage] + (stagePos + d)], &bufferT[nrowsPack*o]); + } } } } @@ -110,12 +153,12 @@ class ExpressionsAvx512 : public ExpressionsCtx { for(uint64_t j = 0; j < nrowsPack; ++j) { bufferT[j] = setupCtx.proverHelpers.x_2ns[row + j]; } - Goldilocks::load_avx512(bufferT_[nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings]], &bufferT[0]); + Goldilocks::load_avx512(bufferT_[nColsStagesAcc[ns*nOpenings]], &bufferT[0]); for(uint64_t d = 0; d < setupCtx.starkInfo.boundaries.size(); ++d) { for(uint64_t j = 0; j < nrowsPack; ++j) { bufferT[j] = setupCtx.proverHelpers.zi[row + j + d*domainSize]; } - Goldilocks::load_avx512(bufferT_[nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings] + 1 + d], &bufferT[0]); + Goldilocks::load_avx512(bufferT_[nColsStagesAcc[ns*nOpenings] + 1 + d], &bufferT[0]); } } else if(dests[0].params[0].parserParams.expId == int64_t(setupCtx.starkInfo.friExpId)) { for(uint64_t d = 0; d < setupCtx.starkInfo.openingPoints.size(); ++d) { @@ -123,14 +166,14 @@ class ExpressionsAvx512 : public ExpressionsCtx { for(uint64_t j = 0; j < nrowsPack; ++j) { bufferT[j] = params.xDivXSub[(row + j + d*domainSize)*FIELD_EXTENSION + k]; } - Goldilocks::load_avx512(bufferT_[nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings] + d*FIELD_EXTENSION + k], &bufferT[0]); + Goldilocks::load_avx512(bufferT_[nColsStagesAcc[ns*nOpenings] + d*FIELD_EXTENSION + k], &bufferT[0]); } } } else { for(uint64_t j = 0; j < nrowsPack; ++j) { bufferT[j] = setupCtx.proverHelpers.x_n[row + j]; } - Goldilocks::load_avx512(bufferT_[nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings]], &bufferT[0]); + Goldilocks::load_avx512(bufferT_[nColsStagesAcc[ns*nOpenings]], &bufferT[0]); } } @@ -245,6 +288,7 @@ class ExpressionsAvx512 : public ExpressionsCtx { void calculateExpressions(StepsParams& params, ParserArgs &parserArgs, std::vector dests, uint64_t domainSize) override { uint64_t nOpenings = setupCtx.starkInfo.openingPoints.size(); + uint64_t ns = 2 + setupCtx.starkInfo.nStages + setupCtx.starkInfo.customCommits.size(); bool domainExtended = domainSize == uint64_t(1 << setupCtx.starkInfo.starkStruct.nBitsExt) ? true : false; uint64_t expId = dests[0].params[0].op == opType::tmp ? dests[0].params[0].parserParams.destDim : 0; @@ -306,7 +350,7 @@ class ExpressionsAvx512 : public ExpressionsCtx { auto openingPointZero = std::find_if(setupCtx.starkInfo.openingPoints.begin(), setupCtx.starkInfo.openingPoints.end(), [](int p) { return p == 0; }); auto openingPointZeroIndex = std::distance(setupCtx.starkInfo.openingPoints.begin(), openingPointZero); - uint64_t buffPos = (setupCtx.starkInfo.nStages + 2)*openingPointZeroIndex + dests[j].params[k].stage; + uint64_t buffPos = ns*openingPointZeroIndex + dests[j].params[k].stage; uint64_t stagePos = dests[j].params[k].stagePos; copyPolynomial(&destVals[j][k*FIELD_EXTENSION], dests[j].params[k].inverse, dests[j].params[k].dim, &bufferT_[nColsStagesAcc[buffPos] + stagePos]); continue; diff --git a/pil2-stark/src/starkpil/expressions_bin.cpp b/pil2-stark/src/starkpil/expressions_bin.cpp index f6b5b21c..8640e3ea 100644 --- a/pil2-stark/src/starkpil/expressions_bin.cpp +++ b/pil2-stark/src/starkpil/expressions_bin.cpp @@ -22,6 +22,7 @@ void ExpressionsBin::loadExpressionsBin(BinFileUtils::BinFile *expressionsBin) { uint32_t nPublicsIdsExpressions = expressionsBin->readU32LE(); uint32_t nAirgroupValuesIdsExpressions = expressionsBin->readU32LE(); uint32_t nAirValuesIdsExpressions = expressionsBin->readU32LE(); + uint64_t nCustomCommitsPolsIdsExpressions = expressionsBin->readU32LE(); expressionsBinArgsExpressions.ops = new uint8_t[nOpsExpressions]; expressionsBinArgsExpressions.args = new uint16_t[nArgsExpressions]; @@ -32,8 +33,10 @@ void ExpressionsBin::loadExpressionsBin(BinFileUtils::BinFile *expressionsBin) { expressionsBinArgsExpressions.publicsIds = new uint16_t[nPublicsIdsExpressions]; expressionsBinArgsExpressions.airgroupValuesIds = new uint16_t[nAirgroupValuesIdsExpressions]; expressionsBinArgsExpressions.airValuesIds = new uint16_t[nAirValuesIdsExpressions]; + expressionsBinArgsExpressions.customCommitsPolsIds = new uint16_t[nCustomCommitsPolsIdsExpressions]; expressionsBinArgsExpressions.nNumbers = nNumbersExpressions; + uint64_t nCustomCommits = expressionsBin->readU32LE(); uint64_t nExpressions = expressionsBin->readU32LE(); for(uint64_t i = 0; i < nExpressions; ++i) { @@ -72,6 +75,15 @@ void ExpressionsBin::loadExpressionsBin(BinFileUtils::BinFile *expressionsBin) { parserParamsExpression.nAirValuesUsed = expressionsBin->readU32LE(); parserParamsExpression.airValuesOffset = expressionsBin->readU32LE(); + + std::vector nCustomCommitsPolsUsed(nCustomCommits); + std::vector customCommitsOffset(nCustomCommits); + for(uint64_t j = 0; j < nCustomCommits; ++j) { + nCustomCommitsPolsUsed[j] = expressionsBin->readU32LE(); + customCommitsOffset[j] = expressionsBin->readU32LE(); + } + parserParamsExpression.nCustomCommitsPolsUsed = nCustomCommitsPolsUsed; + parserParamsExpression.customCommitsOffset = customCommitsOffset; parserParamsExpression.line = expressionsBin->readString(); @@ -112,8 +124,10 @@ void ExpressionsBin::loadExpressionsBin(BinFileUtils::BinFile *expressionsBin) { expressionsBinArgsExpressions.airValuesIds[j] = expressionsBin->readU16LE(); } + for(uint64_t j = 0; j < nCustomCommitsPolsIdsExpressions; ++j) { + expressionsBinArgsExpressions.customCommitsPolsIds[j] = expressionsBin->readU16LE(); + } expressionsBin->endReadSection(); - expressionsBin->startReadSection(BINARY_CONSTRAINTS_SECTION); uint32_t nOpsDebug = expressionsBin->readU32LE(); @@ -125,6 +139,7 @@ void ExpressionsBin::loadExpressionsBin(BinFileUtils::BinFile *expressionsBin) { uint32_t nPublicsIdsDebug = expressionsBin->readU32LE(); uint32_t nAirgroupValuesIdsDebug = expressionsBin->readU32LE(); uint32_t nAirValuesIdsDebug = expressionsBin->readU32LE(); + uint64_t nCustomCommitsPolsIdsDebug = expressionsBin->readU32LE(); expressionsBinArgsConstraints.ops = new uint8_t[nOpsDebug]; expressionsBinArgsConstraints.args = new uint16_t[nArgsDebug]; @@ -135,8 +150,11 @@ void ExpressionsBin::loadExpressionsBin(BinFileUtils::BinFile *expressionsBin) { expressionsBinArgsConstraints.publicsIds = new uint16_t[nPublicsIdsDebug]; expressionsBinArgsConstraints.airgroupValuesIds = new uint16_t[nAirgroupValuesIdsDebug]; expressionsBinArgsConstraints.airValuesIds = new uint16_t[nAirValuesIdsDebug]; + expressionsBinArgsConstraints.customCommitsPolsIds = new uint16_t[nCustomCommitsPolsIdsDebug]; expressionsBinArgsConstraints.nNumbers = nNumbersDebug; + uint64_t nCustomCommitsC = expressionsBin->readU32LE(); + uint32_t nConstraints = expressionsBin->readU32LE(); for(uint64_t i = 0; i < nConstraints; ++i) { @@ -179,6 +197,15 @@ void ExpressionsBin::loadExpressionsBin(BinFileUtils::BinFile *expressionsBin) { parserParamsConstraint.nAirValuesUsed = expressionsBin->readU32LE(); parserParamsConstraint.airValuesOffset = expressionsBin->readU32LE(); + std::vector nCustomCommitsPolsUsedC(nCustomCommitsC); + std::vector customCommitsOffsetC(nCustomCommitsC); + for(uint64_t j = 0; j < nCustomCommitsC; ++j) { + nCustomCommitsPolsUsedC[j] = expressionsBin->readU32LE(); + customCommitsOffsetC[j] = expressionsBin->readU32LE(); + } + parserParamsConstraint.nCustomCommitsPolsUsed = nCustomCommitsPolsUsedC; + parserParamsConstraint.customCommitsOffset = customCommitsOffsetC; + parserParamsConstraint.imPol = bool(expressionsBin->readU32LE()); parserParamsConstraint.line = expressionsBin->readString(); @@ -220,8 +247,10 @@ void ExpressionsBin::loadExpressionsBin(BinFileUtils::BinFile *expressionsBin) { expressionsBinArgsConstraints.airValuesIds[j] = expressionsBin->readU16LE(); } + for(uint64_t j = 0; j < nCustomCommitsPolsIdsDebug; ++j) { + expressionsBinArgsConstraints.customCommitsPolsIds[j] = expressionsBin->readU16LE(); + } expressionsBin->endReadSection(); - expressionsBin->startReadSection(BINARY_HINTS_SECTION); uint32_t nHints = expressionsBin->readU32LE(); diff --git a/pil2-stark/src/starkpil/expressions_bin.hpp b/pil2-stark/src/starkpil/expressions_bin.hpp index d2b2b320..ecec71d2 100644 --- a/pil2-stark/src/starkpil/expressions_bin.hpp +++ b/pil2-stark/src/starkpil/expressions_bin.hpp @@ -72,6 +72,8 @@ struct ParserParams uint32_t airgroupValuesOffset; uint32_t nAirValuesUsed; uint32_t airValuesOffset; + std::vector nCustomCommitsPolsUsed; + std::vector customCommitsOffset; uint32_t firstRow; uint32_t lastRow; uint32_t destDim; @@ -91,6 +93,7 @@ struct ParserArgs uint16_t* publicsIds; uint16_t* airgroupValuesIds; uint16_t* airValuesIds; + uint16_t* customCommitsPolsIds; uint64_t nNumbers; }; @@ -117,6 +120,7 @@ class ExpressionsBin if (expressionsBinArgsExpressions.publicsIds) delete[] expressionsBinArgsExpressions.publicsIds; if (expressionsBinArgsExpressions.airgroupValuesIds) delete[] expressionsBinArgsExpressions.airgroupValuesIds; if (expressionsBinArgsExpressions.airValuesIds) delete[] expressionsBinArgsExpressions.airValuesIds; + if (expressionsBinArgsExpressions.customCommitsPolsIds) delete[] expressionsBinArgsExpressions.customCommitsPolsIds; if (expressionsBinArgsConstraints.ops) delete[] expressionsBinArgsConstraints.ops; if (expressionsBinArgsConstraints.args) delete[] expressionsBinArgsConstraints.args; @@ -127,6 +131,7 @@ class ExpressionsBin if (expressionsBinArgsConstraints.publicsIds) delete[] expressionsBinArgsConstraints.publicsIds; if (expressionsBinArgsConstraints.airgroupValuesIds) delete[] expressionsBinArgsConstraints.airgroupValuesIds; if (expressionsBinArgsConstraints.airValuesIds) delete[] expressionsBinArgsConstraints.airValuesIds; + if (expressionsBinArgsConstraints.customCommitsPolsIds) delete[] expressionsBinArgsConstraints.customCommitsPolsIds; }; /* Constructor */ diff --git a/pil2-stark/src/starkpil/expressions_pack.hpp b/pil2-stark/src/starkpil/expressions_pack.hpp index 4acd8aa3..68980735 100644 --- a/pil2-stark/src/starkpil/expressions_pack.hpp +++ b/pil2-stark/src/starkpil/expressions_pack.hpp @@ -13,33 +13,47 @@ class ExpressionsPack : public ExpressionsCtx { void setBufferTInfo(bool domainExtended, int64_t expId) { uint64_t nOpenings = setupCtx.starkInfo.openingPoints.size(); - offsetsStages.resize((setupCtx.starkInfo.nStages + 2)*nOpenings + 1); - nColsStages.resize((setupCtx.starkInfo.nStages + 2)*nOpenings + 1); - nColsStagesAcc.resize((setupCtx.starkInfo.nStages + 2)*nOpenings + 1); + uint64_t ns = 2 + setupCtx.starkInfo.nStages + setupCtx.starkInfo.customCommits.size(); + offsetsStages.resize(ns*nOpenings + 1); + nColsStages.resize(ns*nOpenings + 1); + nColsStagesAcc.resize(ns*nOpenings + 1); nCols = setupCtx.starkInfo.nConstants; - uint64_t ns = setupCtx.starkInfo.nStages + 2; + for(uint64_t o = 0; o < nOpenings; ++o) { - for(uint64_t stage = 0; stage <= ns; ++stage) { - std::string section = stage == 0 ? "const" : "cm" + to_string(stage); - offsetsStages[(setupCtx.starkInfo.nStages + 2)*o + stage] = setupCtx.starkInfo.mapOffsets[std::make_pair(section, domainExtended)]; - nColsStages[(setupCtx.starkInfo.nStages + 2)*o + stage] = setupCtx.starkInfo.mapSectionsN[section]; - nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*o + stage] = stage == 0 && o == 0 ? 0 : nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*o + stage - 1] + nColsStages[stage - 1]; + for(uint64_t stage = 0; stage < ns; ++stage) { + if(stage == 0) { + offsetsStages[ns*o] = 0; + nColsStages[ns*o] = setupCtx.starkInfo.mapSectionsN["const"]; + nColsStagesAcc[ns*o] = o == 0 ? 0 : nColsStagesAcc[ns*o + stage - 1] + nColsStages[stage - 1]; + } else if(stage < 2 + setupCtx.starkInfo.nStages) { + std::string section = "cm" + to_string(stage); + offsetsStages[ns*o + stage] = setupCtx.starkInfo.mapOffsets[std::make_pair(section, domainExtended)]; + nColsStages[ns*o + stage] = setupCtx.starkInfo.mapSectionsN[section]; + nColsStagesAcc[ns*o + stage] = nColsStagesAcc[ns*o + stage - 1] + nColsStages[stage - 1]; + } else { + uint64_t index = stage - setupCtx.starkInfo.nStages - 2; + std::string section = setupCtx.starkInfo.customCommits[index].name + "0"; + offsetsStages[ns*o + stage] = setupCtx.starkInfo.mapOffsets[std::make_pair(section, domainExtended)]; + nColsStages[ns*o + stage] = setupCtx.starkInfo.mapSectionsN[section]; + nColsStagesAcc[ns*o + stage] = nColsStagesAcc[ns*o + stage - 1] + nColsStages[stage - 1]; + } } } - nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings] = nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings - 1] + nColsStages[(setupCtx.starkInfo.nStages + 2)*nOpenings - 1]; + nColsStagesAcc[ns*nOpenings] = nColsStagesAcc[ns*nOpenings - 1] + nColsStages[ns*nOpenings - 1]; if(expId == int64_t(setupCtx.starkInfo.cExpId)) { - nCols = nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings] + setupCtx.starkInfo.boundaries.size() + 1; + nCols = nColsStagesAcc[ns*nOpenings] + setupCtx.starkInfo.boundaries.size() + 1; } else if(expId == int64_t(setupCtx.starkInfo.friExpId)) { - nCols = nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings] + nOpenings*FIELD_EXTENSION; + nCols = nColsStagesAcc[ns*nOpenings] + nOpenings*FIELD_EXTENSION; } else { - nCols = nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings] + 1; + nCols = nColsStagesAcc[ns*nOpenings] + 1; } } inline void loadPolynomials(StepsParams& params, ParserArgs &parserArgs, std::vector &dests, Goldilocks::Element *bufferT_, uint64_t row, uint64_t domainSize) { uint64_t nOpenings = setupCtx.starkInfo.openingPoints.size(); + uint64_t ns = 2 + setupCtx.starkInfo.nStages + setupCtx.starkInfo.customCommits.size(); bool domainExtended = domainSize == uint64_t(1 << setupCtx.starkInfo.starkStruct.nBitsExt) ? true : false; uint64_t extendBits = (setupCtx.starkInfo.starkStruct.nBitsExt - setupCtx.starkInfo.starkStruct.nBits); @@ -54,6 +68,10 @@ class ExpressionsPack : public ExpressionsCtx { std::vector constPolsUsed(setupCtx.starkInfo.constPolsMap.size(), false); std::vector cmPolsUsed(setupCtx.starkInfo.cmPolsMap.size(), false); + std::vector> customCommitsUsed(setupCtx.starkInfo.customCommits.size()); + for(uint64_t i = 0; i < setupCtx.starkInfo.customCommits.size(); ++i) { + customCommitsUsed[i] = std::vector(setupCtx.starkInfo.customCommits[i].stageWidths[0], false); + } for(uint64_t i = 0; i < dests.size(); ++i) { for(uint64_t j = 0; j < dests[i].params.size(); ++j) { @@ -72,6 +90,13 @@ class ExpressionsPack : public ExpressionsCtx { for(uint64_t k = 0; k < dests[i].params[j].parserParams.nCmPolsUsed; ++k) { cmPolsUsed[cmUsed[k]] = true; } + + for(uint64_t k = 0; k < setupCtx.starkInfo.customCommits.size(); ++k) { + uint16_t* customCmUsed = &parserArgs.customCommitsPolsIds[dests[i].params[j].parserParams.customCommitsOffset[k]]; + for(uint64_t l = 0; l < dests[i].params[j].parserParams.nCustomCommitsPolsUsed[k]; ++l) { + customCommitsUsed[k][customCmUsed[l]] = true; + } + } } } } @@ -80,7 +105,7 @@ class ExpressionsPack : public ExpressionsCtx { for(uint64_t o = 0; o < nOpenings; ++o) { for(uint64_t j = 0; j < nrowsPack; ++j) { uint64_t l = (row + j + nextStrides[o]) % domainSize; - bufferT_[(nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*o] + k)*nrowsPack + j] = constPols[l * nColsStages[0] + k]; + bufferT_[(nColsStagesAcc[ns*o] + k)*nrowsPack + j] = constPols[l * nColsStages[0] + k]; } } } @@ -94,7 +119,24 @@ class ExpressionsPack : public ExpressionsCtx { for(uint64_t o = 0; o < nOpenings; ++o) { for(uint64_t j = 0; j < nrowsPack; ++j) { uint64_t l = (row + j + nextStrides[o]) % domainSize; - bufferT_[(nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*o + stage] + (stagePos + d))*nrowsPack + j] = params.pols[offsetsStages[stage] + l * nColsStages[stage] + stagePos + d]; + bufferT_[(nColsStagesAcc[ns*o + stage] + (stagePos + d))*nrowsPack + j] = params.pols[offsetsStages[stage] + l * nColsStages[stage] + stagePos + d]; + } + } + } + } + + for(uint64_t i = 0; i < setupCtx.starkInfo.customCommits.size(); ++i) { + for(uint64_t j = 0; j < setupCtx.starkInfo.customCommits[i].stageWidths[0]; ++j) { + if(!customCommitsUsed[i][j]) continue; + PolMap polInfo = setupCtx.starkInfo.customCommitsMap[i][j]; + uint64_t stage = setupCtx.starkInfo.nStages + 2 + i; + uint64_t stagePos = polInfo.stagePos; + for(uint64_t d = 0; d < polInfo.dim; ++d) { + for(uint64_t o = 0; o < nOpenings; ++o) { + for(uint64_t j = 0; j < nrowsPack; ++j) { + uint64_t l = (row + j + nextStrides[o]) % domainSize; + bufferT_[(nColsStagesAcc[ns*o + stage] + (stagePos + d))*nrowsPack + j] = params.customCommits[i][offsetsStages[stage] + l * nColsStages[stage] + stagePos + d]; + } } } } @@ -103,23 +145,23 @@ class ExpressionsPack : public ExpressionsCtx { if(dests[0].params[0].parserParams.expId == int64_t(setupCtx.starkInfo.cExpId)) { for(uint64_t d = 0; d < setupCtx.starkInfo.boundaries.size(); ++d) { for(uint64_t j = 0; j < nrowsPack; ++j) { - bufferT_[(nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings] + d + 1)*nrowsPack + j] = setupCtx.proverHelpers.zi[row + j + d*domainSize]; + bufferT_[(nColsStagesAcc[ns*nOpenings] + d + 1)*nrowsPack + j] = setupCtx.proverHelpers.zi[row + j + d*domainSize]; } } for(uint64_t j = 0; j < nrowsPack; ++j) { - bufferT_[(nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings])*nrowsPack + j] = setupCtx.proverHelpers.x_2ns[row + j]; + bufferT_[(nColsStagesAcc[ns*nOpenings])*nrowsPack + j] = setupCtx.proverHelpers.x_2ns[row + j]; } } else if(dests[0].params[0].parserParams.expId == int64_t(setupCtx.starkInfo.friExpId)) { for(uint64_t d = 0; d < setupCtx.starkInfo.openingPoints.size(); ++d) { for(uint64_t k = 0; k < FIELD_EXTENSION; ++k) { for(uint64_t j = 0; j < nrowsPack; ++j) { - bufferT_[(nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings] + d*FIELD_EXTENSION + k)*nrowsPack + j] = params.xDivXSub[(row + j + d*domainSize)*FIELD_EXTENSION + k]; + bufferT_[(nColsStagesAcc[ns*nOpenings] + d*FIELD_EXTENSION + k)*nrowsPack + j] = params.xDivXSub[(row + j + d*domainSize)*FIELD_EXTENSION + k]; } } } } else { for(uint64_t j = 0; j < nrowsPack; ++j) { - bufferT_[(nColsStagesAcc[(setupCtx.starkInfo.nStages + 2)*nOpenings])*nrowsPack + j] = setupCtx.proverHelpers.x[row + j]; + bufferT_[(nColsStagesAcc[ns*nOpenings])*nrowsPack + j] = setupCtx.proverHelpers.x[row + j]; } } } @@ -230,6 +272,7 @@ class ExpressionsPack : public ExpressionsCtx { void calculateExpressions(StepsParams& params, ParserArgs &parserArgs, std::vector dests, uint64_t domainSize) override { uint64_t nOpenings = setupCtx.starkInfo.openingPoints.size(); + uint64_t ns = 2 + setupCtx.starkInfo.nStages + setupCtx.starkInfo.customCommits.size(); bool domainExtended = domainSize == uint64_t(1 << setupCtx.starkInfo.starkStruct.nBitsExt) ? true : false; uint64_t expId = dests[0].params[0].op == opType::tmp ? dests[0].params[0].parserParams.destDim : 0; @@ -302,9 +345,9 @@ class ExpressionsPack : public ExpressionsCtx { auto openingPointZero = std::find_if(setupCtx.starkInfo.openingPoints.begin(), setupCtx.starkInfo.openingPoints.end(), [](int p) { return p == 0; }); auto openingPointZeroIndex = std::distance(setupCtx.starkInfo.openingPoints.begin(), openingPointZero); - uint64_t buffPos = (setupCtx.starkInfo.nStages + 2)*openingPointZeroIndex + dests[j].params[k].stage; + uint64_t buffPos = ns*openingPointZeroIndex + dests[j].params[k].stage; uint64_t stagePos = dests[j].params[k].stagePos; - copyPolynomial(&destVals[j][k*FIELD_EXTENSION], dests[j].params[k].inverse, dests[j].params[k].dim, &bufferT_[nColsStagesAcc[buffPos] + stagePos]); + copyPolynomial(&destVals[j][k*FIELD_EXTENSION*nrowsPack], dests[j].params[k].inverse, dests[j].params[k].dim, &bufferT_[(nColsStagesAcc[buffPos] + stagePos)*nrowsPack]); continue; } else if(dests[j].params[k].op == opType::number) { uint64_t val = dests[j].params[k].inverse ? Goldilocks::inv(Goldilocks::fromU64(dests[j].params[k].value)).fe : dests[j].params[k].value; diff --git a/pil2-stark/src/starkpil/gen_recursive_proof.hpp b/pil2-stark/src/starkpil/gen_recursive_proof.hpp index f2915a54..e7540ccd 100644 --- a/pil2-stark/src/starkpil/gen_recursive_proof.hpp +++ b/pil2-stark/src/starkpil/gen_recursive_proof.hpp @@ -119,7 +119,7 @@ void *genRecursiveProof(SetupCtx& setupCtx, json& globalInfo, uint64_t airgroupI } Polinomial gprodTransposedPol; - setupCtx.starkInfo.getPolynomial(gprodTransposedPol, pAddress, true, gprodField->values[0].id, false); + setupCtx.starkInfo.getPolynomial(gprodTransposedPol, pAddress, "cm", setupCtx.starkInfo.cmPolsMap[gprodField->values[0].id], false); #pragma omp parallel for for(uint64_t j = 0; j < N; ++j) { std::memcpy(gprodTransposedPol[j], &gprod[j*FIELD_EXTENSION], FIELD_EXTENSION * sizeof(Goldilocks::Element)); @@ -268,7 +268,8 @@ void *genRecursiveProof(SetupCtx& setupCtx, json& globalInfo, uint64_t airgroupI starks.addTranscriptGL(transcriptPermutation, challenge, FIELD_EXTENSION); transcriptPermutation.getPermutations(friQueries, setupCtx.starkInfo.starkStruct.nQueries, setupCtx.starkInfo.starkStruct.steps[0].nBits); - FRI::proveQueries(friQueries, setupCtx.starkInfo.starkStruct.nQueries, proof, starks.treesGL, setupCtx.starkInfo.nStages + 2); + uint64_t nTrees = setupCtx.starkInfo.nStages + setupCtx.starkInfo.customCommits.size() + 2; + FRI::proveQueries(friQueries, setupCtx.starkInfo.starkStruct.nQueries, proof, starks.treesGL, nTrees); for(uint64_t step = 1; step < setupCtx.starkInfo.starkStruct.steps.size(); ++step) { FRI::proveFRIQueries(friQueries, setupCtx.starkInfo.starkStruct.nQueries, step, setupCtx.starkInfo.starkStruct.steps[step].nBits, proof, starks.treesFRI[step - 1]); } diff --git a/pil2-stark/src/starkpil/hints.hpp b/pil2-stark/src/starkpil/hints.hpp index 108ab743..55a4eaa3 100644 --- a/pil2-stark/src/starkpil/hints.hpp +++ b/pil2-stark/src/starkpil/hints.hpp @@ -117,7 +117,7 @@ void printColById(SetupCtx& setupCtx, StepsParams ¶ms, bool committed, uint6 PolMap polInfo = committed ? setupCtx.starkInfo.cmPolsMap[polId] : setupCtx.starkInfo.constPolsMap[polId]; Goldilocks::Element *pols = committed ? params.pols : params.pConstPolsAddress; Polinomial p; - setupCtx.starkInfo.getPolynomial(p, pols, committed, polId, false); + setupCtx.starkInfo.getPolynomial(p, pols, committed ? "cm" : "const", polInfo, false); Polinomial pCol; Goldilocks::Element *pBuffCol = new Goldilocks::Element[polInfo.dim * N]; diff --git a/pil2-stark/src/starkpil/proof2zkinStark.cpp b/pil2-stark/src/starkpil/proof2zkinStark.cpp index 385b4875..ae9d778e 100644 --- a/pil2-stark/src/starkpil/proof2zkinStark.cpp +++ b/pil2-stark/src/starkpil/proof2zkinStark.cpp @@ -10,6 +10,7 @@ ordered_json proof2zkinStark(ordered_json &proof, StarkInfo &starkInfo) uint64_t friSteps = starkInfo.starkStruct.steps.size() - 1; uint64_t nQueries = starkInfo.starkStruct.nQueries; uint64_t nStages = starkInfo.nStages; + uint64_t nCustomCommits = starkInfo.customCommits.size(); string valsQ = "s0_vals" + to_string(nStages + 1); string siblingsQ = "s0_siblings" + to_string(nStages + 1); @@ -50,6 +51,13 @@ ordered_json proof2zkinStark(ordered_json &proof, StarkInfo &starkInfo) } } + for(uint64_t i = 0; i < nCustomCommits; ++i) { + if (proof["queries"]["polQueries"][0][i + nStages + 2][0].size()) { + zkinOut["s0_siblings_" + starkInfo.customCommits[i].name + "_0"] = ordered_json::array(); + zkinOut["s0_vals_" + starkInfo.customCommits[i].name + "_0"] = ordered_json::array(); + } + } + for (uint64_t i = 0; i < nQueries; i++) { for (uint64_t j = 0; j < nStages; ++j) { uint64_t stage = j + 1; @@ -59,6 +67,11 @@ ordered_json proof2zkinStark(ordered_json &proof, StarkInfo &starkInfo) } } + for (uint64_t j = 0; j < nCustomCommits; ++j) { + zkinOut["s0_vals_" + starkInfo.customCommits[j].name + "_0"][i] = proof["queries"]["polQueries"][i][j + nStages + 2][0]; + zkinOut["s0_siblings_" + starkInfo.customCommits[j].name + "_0"][i] = proof["queries"]["polQueries"][i][j + nStages + 2][1]; + } + zkinOut[valsQ][i] = proof["queries"]["polQueries"][i][nStages][0]; zkinOut[siblingsQ][i] = proof["queries"]["polQueries"][i][nStages][1]; diff --git a/pil2-stark/src/starkpil/proof_stark.hpp b/pil2-stark/src/starkpil/proof_stark.hpp index a6da057d..3c65b7b2 100644 --- a/pil2-stark/src/starkpil/proof_stark.hpp +++ b/pil2-stark/src/starkpil/proof_stark.hpp @@ -199,6 +199,7 @@ class Proofs { public: uint64_t nStages; + uint64_t nCustomCommits; uint64_t nFieldElements; uint64_t airId; uint64_t airgroupId; @@ -207,25 +208,31 @@ class Proofs std::vector> evals; std::vector> airgroupValues; std::vector> airValues; + std::vector customCommits; Proofs(StarkInfo &starkInfo) : fri(starkInfo), evals(starkInfo.evMap.size(), std::vector(FIELD_EXTENSION, Goldilocks::zero())), airgroupValues(starkInfo.airgroupValuesMap.size(), std::vector(FIELD_EXTENSION, Goldilocks::zero())), - airValues(starkInfo.airValuesMap.size(), std::vector(FIELD_EXTENSION, Goldilocks::zero())) + airValues(starkInfo.airValuesMap.size(), std::vector(FIELD_EXTENSION, Goldilocks::zero())), + customCommits(starkInfo.customCommits.size()) { nStages = starkInfo.nStages + 1; - roots = new ElementType*[nStages]; + nCustomCommits = starkInfo.customCommits.size(); + roots = new ElementType*[nStages + nCustomCommits]; nFieldElements = starkInfo.starkStruct.verificationHashType == "GL" ? HASH_SIZE : 1; airId = starkInfo.airId; airgroupId = starkInfo.airgroupId; - for(uint64_t i = 0; i < nStages; i++) + for(uint64_t i = 0; i < nStages + nCustomCommits; i++) { roots[i] = new ElementType[nFieldElements]; } + for(uint64_t i = 0; i < nCustomCommits; ++i) { + customCommits[i] = starkInfo.customCommits[i].name; + } }; ~Proofs() { - for (uint64_t i = 0; i < nStages; ++i) { + for (uint64_t i = 0; i < nStages + nCustomCommits; ++i) { delete[] roots[i]; } delete[] roots; diff --git a/pil2-stark/src/starkpil/stark_info.cpp b/pil2-stark/src/starkpil/stark_info.cpp index c8227c21..1c026623 100644 --- a/pil2-stark/src/starkpil/stark_info.cpp +++ b/pil2-stark/src/starkpil/stark_info.cpp @@ -60,6 +60,16 @@ void StarkInfo::load(json j) friExpId = j["friExpId"]; cExpId = j["cExpId"]; + + for(uint64_t i = 0; i < j["customCommits"].size(); i++) { + CustomCommits c; + c.name = j["customCommits"][i]["name"]; + for(uint64_t k = 0; k < j["customCommits"][i]["stageWidths"].size(); k++) { + c.stageWidths.push_back(j["customCommits"][i]["stageWidths"][k]); + } + customCommits.push_back(c); + } + for(uint64_t i = 0; i < j["openingPoints"].size(); i++) { openingPoints.push_back(j["openingPoints"][i]); } @@ -88,6 +98,11 @@ void StarkInfo::load(json j) { PolMap map; map.name = j["publicsMap"][i]["name"]; + if(j["publicsMap"][i].contains("lengths")) { + for (uint64_t l = 0; l < j["publicsMap"][i]["lengths"].size(); l++) { + map.lengths.push_back(j["publicsMap"][i]["lengths"][l]); + } + } publicsMap.push_back(map); } @@ -128,6 +143,32 @@ void StarkInfo::load(json j) cmPolsMap.push_back(map); } + for (uint64_t i = 0; i < j["customCommitsMap"].size(); i++) + { + vector custPolsMap(j["customCommitsMap"][i].size()); + for(uint64_t k = 0; k < j["customCommitsMap"][i].size(); ++k) { + PolMap map; + map.stage = j["customCommitsMap"][i][k]["stage"]; + map.name = j["customCommitsMap"][i][k]["name"]; + map.dim = j["customCommitsMap"][i][k]["dim"]; + map.stagePos = j["customCommitsMap"][i][k]["stagePos"]; + map.stageId = j["customCommitsMap"][i][k]["stageId"]; + map.commitId = i; + if(j["customCommitsMap"][i][k].contains("expId")) { + map.expId = j["customCommitsMap"][i][k]["expId"]; + } + if(j["customCommitsMap"][i].contains("lengths")) { + for (uint64_t l = 0; l < j["customCommitsMap"][i][k]["lengths"].size(); l++) { + map.lengths.push_back(j["customCommitsMap"][i][k]["lengths"][l]); + } + } + map.polsMapId = j["customCommitsMap"][i][k]["polsMapId"]; + custPolsMap[k] = map; + } + customCommitsMap.push_back(custPolsMap); + } + + for (uint64_t i = 0; i < j["constPolsMap"].size(); i++) { PolMap map; @@ -150,6 +191,9 @@ void StarkInfo::load(json j) { EvMap map; map.setType(j["evMap"][i]["type"]); + if(j["evMap"][i]["type"] == "custom") { + map.commitId = j["evMap"][i]["commitId"]; + } map.id = j["evMap"][i]["id"]; map.prime = j["evMap"][i]["prime"]; if(j["evMap"][i].contains("openingPos")) { @@ -184,6 +228,13 @@ void StarkInfo::setMapOffsets() { mapOffsets[std::make_pair("const", false)] = 0; mapOffsets[std::make_pair("const", true)] = 0; + // Set offsets for custom pols + for(uint64_t i = 0; i < customCommits.size(); ++i) { + mapOffsets[std::make_pair(customCommits[i].name + "0", false)] = 0; + mapOffsets[std::make_pair(customCommits[i].name + "0", true)] = N * mapSectionsN[customCommits[i].name + "0"]; + mapTotalNcustomCommits[customCommits[i].name] = (N + NExtended) * mapSectionsN[customCommits[i].name + "0"]; + } + mapTotalN = 0; // Set offsets for all stages in the extended field (cm1, cm2, ..., cmN) @@ -229,15 +280,14 @@ void StarkInfo::setMapOffsets() { if(offsetPolsEvals > mapTotalN) mapTotalN = offsetPolsEvals; } -void StarkInfo::getPolynomial(Polinomial &pol, Goldilocks::Element *pAddress, bool committed, uint64_t idPol, bool domainExtended) { - PolMap polInfo = committed ? cmPolsMap[idPol] : constPolsMap[idPol]; +void StarkInfo::getPolynomial(Polinomial &pol, Goldilocks::Element *pAddress, string type, PolMap& polInfo, bool domainExtended) { uint64_t deg = domainExtended ? 1 << starkStruct.nBitsExt : 1 << starkStruct.nBits; uint64_t dim = polInfo.dim; - std::string stage = committed ? "cm" + to_string(polInfo.stage) : "const"; + std::string stage = type == "cm" ? "cm" + to_string(polInfo.stage) : type == "custom" ? customCommits[polInfo.commitId].name + "0" : "const"; uint64_t nCols = mapSectionsN[stage]; uint64_t offset = mapOffsets[std::make_pair(stage, domainExtended)]; offset += polInfo.stagePos; - pol = Polinomial(&pAddress[offset], deg, dim, nCols, std::to_string(idPol)); + pol = Polinomial(&pAddress[offset], deg, dim, nCols); } diff --git a/pil2-stark/src/starkpil/stark_info.hpp b/pil2-stark/src/starkpil/stark_info.hpp index b01d4557..b714dc0d 100644 --- a/pil2-stark/src/starkpil/stark_info.hpp +++ b/pil2-stark/src/starkpil/stark_info.hpp @@ -32,6 +32,13 @@ typedef enum } opType; +class CustomCommits +{ +public: + std::string name; + vector stageWidths; +}; + class Boundary { public: @@ -71,6 +78,7 @@ class PolMap bool imPol; uint64_t stagePos; uint64_t stageId; + uint64_t commitId; uint64_t expId; uint64_t polsMapId; }; @@ -82,17 +90,20 @@ class EvMap { cm = 0, _const = 1, + custom = 2, } eType; eType type; uint64_t id; int64_t prime; + uint64_t commitId; uint64_t openingPos; void setType (string s) { if (s == "cm") type = cm; else if (s == "const") type = _const; + else if (s == "custom") type = custom; else { zklog.error("EvMap::setType() found invalid type: " + s); @@ -115,12 +126,15 @@ class StarkInfo uint64_t nStages; + vector customCommits; + vector cmPolsMap; vector constPolsMap; vector challengesMap; vector airgroupValuesMap; vector airValuesMap; vector publicsMap; + vector> customCommitsMap; vector evMap; @@ -139,6 +153,8 @@ class StarkInfo std::map, uint64_t> mapOffsets; uint64_t mapTotalN; + + std::map mapTotalNcustomCommits; /* Constructor */ StarkInfo(string file); @@ -149,7 +165,7 @@ class StarkInfo void setMapOffsets(); /* Returns a polynomial specified by its ID */ - void getPolynomial(Polinomial &pol, Goldilocks::Element *pAddress, bool committed, uint64_t idPol, bool domainExtended); + void getPolynomial(Polinomial &pol, Goldilocks::Element *pAddress, string type, PolMap& polInfo, bool domainExtended); }; #endif \ No newline at end of file diff --git a/pil2-stark/src/starkpil/starks.cpp b/pil2-stark/src/starkpil/starks.cpp index b039948b..6ac77467 100644 --- a/pil2-stark/src/starkpil/starks.cpp +++ b/pil2-stark/src/starkpil/starks.cpp @@ -3,6 +3,30 @@ #include "zklog.hpp" #include "exit_process.hpp" +template +void Starks::extendAndMerkelizeCustomCommit(uint64_t commitId, uint64_t step, Goldilocks::Element *buffer, FRIProof &proof, Goldilocks::Element *pBuffHelper) +{ + uint64_t N = 1 << setupCtx.starkInfo.starkStruct.nBits; + uint64_t NExtended = 1 << setupCtx.starkInfo.starkStruct.nBitsExt; + + std::string section = setupCtx.starkInfo.customCommits[commitId].name + to_string(step); + uint64_t nCols = setupCtx.starkInfo.mapSectionsN[section]; + Goldilocks::Element *pBuff = &buffer[setupCtx.starkInfo.mapOffsets[make_pair(section, false)]]; + Goldilocks::Element *pBuffExtended = &buffer[setupCtx.starkInfo.mapOffsets[make_pair(section, true)]]; + + NTT_Goldilocks ntt(N); + if(pBuffHelper != nullptr) { + ntt.extendPol(pBuffExtended, pBuff, NExtended, N, nCols, pBuffHelper); + } else { + ntt.extendPol(pBuffExtended, pBuff, NExtended, N, nCols); + } + + uint64_t pos = setupCtx.starkInfo.nStages + 2 + commitId; + treesGL[pos]->setSource(pBuffExtended); + treesGL[pos]->merkelize(); + treesGL[pos]->getRoot(&proof.proof.roots[pos - 1][0]); +} + template void Starks::extendAndMerkelize(uint64_t step, Goldilocks::Element *buffer, FRIProof &proof, Goldilocks::Element *pBuffHelper) { @@ -197,9 +221,10 @@ void Starks::evmap(StepsParams& params, Goldilocks::Element *LEv) for (uint64_t i = 0; i < size_eval; i++) { EvMap ev = setupCtx.starkInfo.evMap[i]; - bool committed = ev.type == EvMap::eType::cm ? true : false; - Goldilocks::Element *pols = committed ? params.pols : ¶ms.pConstPolsExtendedTreeAddress[2]; - setupCtx.starkInfo.getPolynomial(ordPols[i], pols, committed, ev.id, true); + string type = ev.type == EvMap::eType::cm ? "cm" : ev.type == EvMap::eType::custom ? "custom" : "fixed"; + Goldilocks::Element *pols = type == "cm" ? params.pols : type == "custom" ? params.customCommits[ev.commitId] : ¶ms.pConstPolsExtendedTreeAddress[2]; + PolMap polInfo = type == "cm" ? setupCtx.starkInfo.cmPolsMap[ev.id] : type == "custom" ? setupCtx.starkInfo.customCommitsMap[ev.commitId][ev.id] : setupCtx.starkInfo.constPolsMap[ev.id]; + setupCtx.starkInfo.getPolynomial(ordPols[i], pols, type, polInfo, true); } #pragma omp parallel @@ -274,6 +299,12 @@ void Starks::ffi_treesGL_get_root(uint64_t index, ElementType *dst) treesGL[index]->getRoot(dst); } +template +void Starks::ffi_treesGL_set_root(uint64_t index, FRIProof &proof) +{ + treesGL[index]->getRoot(&proof.proof.roots[index][0]); +} + template void Starks::calculateImPolsExpressions(uint64_t step, StepsParams ¶ms) { std::vector dests; diff --git a/pil2-stark/src/starkpil/starks.hpp b/pil2-stark/src/starkpil/starks.hpp index 6680a694..54232c88 100644 --- a/pil2-stark/src/starkpil/starks.hpp +++ b/pil2-stark/src/starkpil/starks.hpp @@ -34,7 +34,7 @@ class Starks public: Starks(SetupCtx& setupCtx_, Goldilocks::Element *pConstPolsExtendedTreeAddress) : setupCtx(setupCtx_) { - treesGL = new MerkleTreeType*[setupCtx.starkInfo.nStages + 2]; + treesGL = new MerkleTreeType*[setupCtx.starkInfo.nStages + setupCtx.starkInfo.customCommits.size() + 2]; treesGL[setupCtx.starkInfo.nStages + 1] = new MerkleTreeType(setupCtx.starkInfo.starkStruct.merkleTreeArity, setupCtx.starkInfo.starkStruct.merkleTreeCustom, pConstPolsExtendedTreeAddress); for (uint64_t i = 0; i < setupCtx.starkInfo.nStages + 1; i++) { @@ -43,6 +43,13 @@ class Starks treesGL[i] = new MerkleTreeType(setupCtx.starkInfo.starkStruct.merkleTreeArity, setupCtx.starkInfo.starkStruct.merkleTreeCustom, 1 << setupCtx.starkInfo.starkStruct.nBitsExt, nCols, NULL, false); } + + + for(uint64_t i = 0; i < setupCtx.starkInfo.customCommits.size(); i++) { + uint64_t nCols = setupCtx.starkInfo.mapSectionsN[setupCtx.starkInfo.customCommits[i].name + "0"]; + treesGL[setupCtx.starkInfo.nStages + 2 + i] = new MerkleTreeType(setupCtx.starkInfo.starkStruct.merkleTreeArity, setupCtx.starkInfo.starkStruct.merkleTreeCustom, 1 << setupCtx.starkInfo.starkStruct.nBitsExt, nCols, NULL, false); + } + treesFRI = new MerkleTreeType*[setupCtx.starkInfo.starkStruct.steps.size() - 1]; for(uint64_t step = 0; step < setupCtx.starkInfo.starkStruct.steps.size() - 1; ++step) { uint64_t nGroups = 1 << setupCtx.starkInfo.starkStruct.steps[step + 1].nBits; @@ -53,7 +60,7 @@ class Starks }; ~Starks() { - for (uint i = 0; i < setupCtx.starkInfo.nStages + 2; i++) + for (uint i = 0; i < setupCtx.starkInfo.nStages + setupCtx.starkInfo.customCommits.size() + 2; i++) { delete treesGL[i]; } @@ -64,9 +71,9 @@ class Starks delete treesFRI[i]; } delete[] treesFRI; - }; + void extendAndMerkelizeCustomCommit(uint64_t commitId, uint64_t step, Goldilocks::Element *buffer, FRIProof &proof, Goldilocks::Element *pBuffHelper); void extendAndMerkelize(uint64_t step, Goldilocks::Element *buffer, FRIProof &proof, Goldilocks::Element* pBuffHelper = nullptr); void commitStage(uint64_t step, Goldilocks::Element *buffer, FRIProof &proof, Goldilocks::Element* pBuffHelper = nullptr); @@ -89,6 +96,7 @@ class Starks // Following function are created to be used by the ffi interface void ffi_treesGL_get_root(uint64_t index, ElementType *dst); + void ffi_treesGL_set_root(uint64_t index, FRIProof &proof); void evmap(StepsParams& params, Goldilocks::Element *LEv); }; diff --git a/pil2-stark/src/starkpil/steps.hpp b/pil2-stark/src/starkpil/steps.hpp index c91a0756..4721ba98 100644 --- a/pil2-stark/src/starkpil/steps.hpp +++ b/pil2-stark/src/starkpil/steps.hpp @@ -14,6 +14,7 @@ struct StepsParams Goldilocks::Element *xDivXSub; Goldilocks::Element *pConstPolsAddress; Goldilocks::Element *pConstPolsExtendedTreeAddress; + Goldilocks::Element *customCommits[10]; }; #endif \ No newline at end of file diff --git a/pilout/src/pilout.proto b/pilout/src/pilout.proto index 19cb86b9..042b86a2 100644 --- a/pilout/src/pilout.proto +++ b/pilout/src/pilout.proto @@ -87,8 +87,7 @@ message GlobalOperand { message AirGroupValue { uint32 airGroupId = 1; - uint32 stage = 2; - uint32 idx = 3; + uint32 idx = 2; } message PublicValue { @@ -128,8 +127,19 @@ message Air { repeated uint32 stageWidths = 5; // stage widths excluding stage 0 (fixed columns) repeated Expression expressions = 6; repeated Constraint constraints = 7; - repeated uint32 stageAirValues = 8; // stage airvalues + repeated AirValue airValues = 8; // stage airvalues bool aggregable = 9; + repeated CustomCommit customCommits = 10; +} + +message CustomCommit { + optional string name = 1; + repeated GlobalOperand.PublicValue publicValues = 2; + repeated uint32 stageWidths = 3; // stage widths including stage 0 !! +} + +message AirValue { + uint32 stage = 1; } message PeriodicCol { @@ -186,13 +196,11 @@ message Operand { } message AirGroupValue { - uint32 stage = 1; - uint32 idx = 2; + uint32 idx = 1; } message AirValue { - uint32 stage = 1; - uint32 idx = 2; + uint32 idx = 1; } message PublicValue { @@ -215,6 +223,13 @@ message Operand { sint32 rowOffset = 3; } + message CustomCol { + uint32 commitId = 1; + uint32 stage = 2; + uint32 colIdx = 3; // absolute idx relative to the stage + sint32 rowOffset = 4; + } + message Expression { uint32 idx = 1; } @@ -230,6 +245,7 @@ message Operand { WitnessCol witnessCol = 8; Expression expression = 9; AirValue airValue = 10; + CustomCol customCol = 11; } } @@ -269,11 +285,12 @@ enum SymbolType { PERIODIC_COL = 2; WITNESS_COL = 3; PROOF_VALUE = 4; - AIR_GROUP_VALUE = 5; + AIR_GROUP_VALUE = 5; PUBLIC_VALUE = 6; PUBLIC_TABLE = 7; CHALLENGE = 8; - AIR_VALUE = 9; + AIR_VALUE = 9; + CUSTOM_COL = 10; } message Symbol { @@ -285,7 +302,8 @@ message Symbol { optional uint32 stage = 6; uint32 dim = 7; repeated uint32 lengths = 8; - optional string debugLine = 9; + optional uint32 commitId = 9; + optional string debugLine = 10; } // ================ Hints ================ diff --git a/proofman/src/proofman.rs b/proofman/src/proofman.rs index ef4b09b8..56f3b9d7 100644 --- a/proofman/src/proofman.rs +++ b/proofman/src/proofman.rs @@ -81,7 +81,7 @@ impl ProofMan { Self::print_summary(pctx.clone()); } - Self::initialize_setup(setups.clone(), pctx.clone(), ectx.clone(), options.aggregation); + Self::initialize_fixed_pols(setups.clone(), pctx.clone(), ectx.clone(), options.aggregation); let mut provers: Vec>> = Vec::new(); Self::initialize_provers(sctx.clone(), &mut provers, pctx.clone(), ectx.clone()); @@ -92,20 +92,37 @@ impl ProofMan { let mut transcript: FFITranscript = provers[0].new_transcript(); + Self::check_stage(0, &mut provers, pctx.clone()); + for prover in provers.iter_mut() { + prover.commit_stage(0, pctx.clone()); + } + // Commit stages let num_commit_stages = pctx.global_info.n_challenges.len() as u32; for stage in 1..=num_commit_stages { Self::get_challenges(stage, &mut provers, pctx.clone(), &transcript); if stage != 1 { + timer_start_debug!(CALCULATING_WITNESS); + info!("{}: Calculating witness stage {}", Self::MY_NAME, stage); witness_lib.calculate_witness(stage, pctx.clone(), ectx.clone(), sctx.clone()); + timer_stop_and_log_debug!(CALCULATING_WITNESS); } Self::calculate_stage(stage, &mut provers, sctx.clone(), pctx.clone()); + Self::check_stage(stage, &mut provers, pctx.clone()); if !options.verify_constraints { Self::commit_stage(stage, &mut provers, pctx.clone()); } + let publics_set = pctx.public_inputs.inputs_set.read().unwrap(); + for i in 0..pctx.global_info.n_publics { + let public = pctx.global_info.publics_map.as_ref().expect("REASON").get(i).unwrap(); + if !publics_set[i] { + panic!("Not all publics are set: Public {} is not calculated", public.name); + } + } + if !options.verify_constraints || stage < num_commit_stages { Self::calculate_challenges( stage, @@ -308,7 +325,7 @@ impl ProofMan { timer_stop_and_log_debug!(INITIALIZE_PROVERS); } - fn initialize_setup( + fn initialize_fixed_pols( setups: Arc>, pctx: Arc>, _ectx: Arc>, @@ -333,14 +350,17 @@ impl ProofMan { timer_stop_and_log_debug!(INITIALIZE_CONST_POLS); if aggregation { - timer_start_debug!(INITIALIZE_CONST_POLS_COMPRESSOR); - let mut const_pols_calculated_compressor: HashMap<(usize, usize), bool> = HashMap::new(); + info!("{}: Initializing setup fixed pols aggregation", Self::MY_NAME); let sctx_compressor = setups.sctx_compressor.as_ref().unwrap().clone(); let sctx_recursive1 = setups.sctx_recursive1.as_ref().unwrap().clone(); let sctx_recursive2 = setups.sctx_recursive2.as_ref().unwrap().clone(); let sctx_final = setups.sctx_final.as_ref().unwrap().clone(); + timer_start_debug!(INITIALIZE_CONST_POLS_COMPRESSOR); + info!("{}: ··· Initializing setup fixed pols compressor", Self::MY_NAME); + let mut const_pols_calculated_compressor: HashMap<(usize, usize), bool> = HashMap::new(); + for air_instance in pctx.air_instance_repo.air_instances.read().unwrap().iter() { let (airgroup_id, air_id) = (air_instance.airgroup_id, air_instance.air_id); if pctx.global_info.get_air_has_compressor(airgroup_id, air_id) @@ -355,6 +375,7 @@ impl ProofMan { timer_stop_and_log_debug!(INITIALIZE_CONST_POLS_COMPRESSOR); timer_start_debug!(INITIALIZE_CONST_POLS_RECURSIVE1); + info!("{}: ··· Initializing setup fixed pols recursive1", Self::MY_NAME); let mut const_pols_calculated_recursive1: HashMap<(usize, usize), bool> = HashMap::new(); for air_instance in pctx.air_instance_repo.air_instances.read().unwrap().iter() { let (airgroup_id, air_id) = (air_instance.airgroup_id, air_instance.air_id); @@ -368,6 +389,7 @@ impl ProofMan { timer_stop_and_log_debug!(INITIALIZE_CONST_POLS_RECURSIVE1); timer_start_debug!(INITIALIZE_CONST_POLS_RECURSIVE2); + info!("{}: ··· Initializing setup fixed pols recursive2", Self::MY_NAME); let n_airgroups = pctx.global_info.air_groups.len(); for airgroup in 0..n_airgroups { let setup = sctx_recursive2.get_setup(airgroup, 0); @@ -377,6 +399,7 @@ impl ProofMan { timer_stop_and_log_debug!(INITIALIZE_CONST_POLS_RECURSIVE2); timer_start_debug!(INITIALIZE_CONST_POLS_FINAL); + info!("{}: ··· Initializing setup fixed pols final", Self::MY_NAME); let setup = sctx_final.get_setup(0, 0); setup.load_const_pols(&pctx.global_info, &ProofType::Final); setup.load_const_pols_tree(&pctx.global_info, &ProofType::Final, false); @@ -408,6 +431,13 @@ impl ProofMan { } } + pub fn check_stage(stage: u32, provers: &mut [Box>], proof_ctx: Arc>) { + log::debug!("{}: Checking stage can be calculated", Self::MY_NAME); + for prover in provers.iter_mut() { + prover.check_stage(stage, proof_ctx.clone()); + } + } + pub fn commit_stage(stage: u32, provers: &mut [Box>], proof_ctx: Arc>) { if stage as usize == proof_ctx.global_info.n_challenges.len() + 1 { info!("{}: Committing stage Q", Self::MY_NAME); diff --git a/proofman/src/recursion.rs b/proofman/src/recursion.rs index 747f933a..82c81076 100644 --- a/proofman/src/recursion.rs +++ b/proofman/src/recursion.rs @@ -110,8 +110,8 @@ pub fn generate_recursion_proof( false => String::from(""), }; - let const_pols_ptr = (*setup.const_pols.const_pols.read().unwrap()).as_ptr() as *mut c_void; - let const_tree_ptr = (*setup.const_tree.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let const_pols_ptr = (*setup.const_pols.values.read().unwrap()).as_ptr() as *mut c_void; + let const_tree_ptr = (*setup.const_tree.values.read().unwrap()).as_ptr() as *mut c_void; let p_prove = gen_recursive_proof_c( p_setup, @@ -244,10 +244,8 @@ pub fn generate_recursion_proof( MY_NAME, format!("··· Generating recursive2 proof for instances of {}", air_instance_name) ); - let const_pols_ptr = - (*setup.const_pols.const_pols.read().unwrap()).as_ptr() as *mut c_void; - let const_tree_ptr = - (*setup.const_tree.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let const_pols_ptr = (*setup.const_pols.values.read().unwrap()).as_ptr() as *mut c_void; + let const_tree_ptr = (*setup.const_tree.values.read().unwrap()).as_ptr() as *mut c_void; let zkin = gen_recursive_proof_c( p_setup, @@ -329,8 +327,8 @@ pub fn generate_recursion_proof( log::info!("{}: ··· Generating final proof", MY_NAME); timer_start_trace!(GENERATE_PROOF); // prove - let const_pols_ptr = (*setup.const_pols.const_pols.read().unwrap()).as_ptr() as *mut c_void; - let const_tree_ptr = (*setup.const_tree.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let const_pols_ptr = (*setup.const_pols.values.read().unwrap()).as_ptr() as *mut c_void; + let const_tree_ptr = (*setup.const_tree.values.read().unwrap()).as_ptr() as *mut c_void; let _p_prove = gen_recursive_proof_c( p_setup, p_address, diff --git a/provers/stark/src/stark_info.rs b/provers/stark/src/stark_info.rs index fb120061..0665a67b 100644 --- a/provers/stark/src/stark_info.rs +++ b/provers/stark/src/stark_info.rs @@ -91,6 +91,23 @@ pub struct PolMap { pub stage_pos: u64, #[serde(default, rename = "stageId")] pub stage_id: u64, + #[serde(default)] + pub lengths: Vec, +} + +#[allow(dead_code)] +#[derive(Deserialize, Clone, Copy)] +pub struct PublicValues { + pub idx: u64, +} + +#[derive(Deserialize)] +pub struct CustomCommits { + pub name: String, + #[serde(default, rename = "stageWidths")] + pub stage_widths: Vec, + #[serde(rename = "publicValues")] + pub public_values: Vec, } #[allow(dead_code)] @@ -100,8 +117,8 @@ enum EvMapEType { Cm, #[serde(rename = "const")] Const, - #[serde(rename = "q")] - Q, + #[serde(rename = "custom")] + Custom, } fn deserialize_bool_from_int<'de, D>(deserializer: D) -> Result @@ -142,6 +159,10 @@ pub struct StarkInfo { #[serde(rename = "cmPolsMap")] pub cm_pols_map: Option>, + #[serde(rename = "publicsMap")] + pub publics_map: Option>, + #[serde(rename = "customCommitsMap")] + pub custom_commits_map: Vec>>, #[serde(rename = "challengesMap")] pub challenges_map: Option>, #[serde(rename = "airgroupValuesMap")] @@ -151,6 +172,9 @@ pub struct StarkInfo { #[serde(rename = "evMap")] pub ev_map: Vec, + #[serde(rename = "customCommits")] + pub custom_commits: Vec, + #[serde(default = "default_opening_points", rename = "openingPoints")] pub opening_points: Vec, diff --git a/provers/stark/src/stark_prover.rs b/provers/stark/src/stark_prover.rs index 8626448d..31813b12 100644 --- a/provers/stark/src/stark_prover.rs +++ b/provers/stark/src/stark_prover.rs @@ -65,7 +65,7 @@ impl StarkProver { let setup = sctx.get_setup(airgroup_id, air_id); - let const_tree_ptr = (*setup.const_tree.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let const_tree_ptr = (*setup.const_tree.values.read().unwrap()).as_ptr() as *mut c_void; let p_stark = starks_new_c((&setup.p_setup).into(), const_tree_ptr); @@ -117,6 +117,16 @@ impl Prover for StarkProver { air_instance.set_commit_calculated(i); } + for commit_id in 0..self.stark_info.custom_commits.len() { + if !air_instance.custom_commits[commit_id].is_empty() { + for idx in 0..self.stark_info.custom_commits_map[commit_id].as_ref().unwrap().len() { + if self.stark_info.custom_commits_map[commit_id].as_ref().unwrap()[idx].stage <= 1 { + air_instance.set_custom_commit_calculated(commit_id, idx); + } + } + } + } + self.initialized = true; } @@ -145,8 +155,8 @@ impl Prover for StarkProver { let public_inputs_guard = proof_ctx.public_inputs.inputs.read().unwrap(); let challenges_guard = proof_ctx.challenges.challenges.read().unwrap(); - let const_pols_ptr = (*setup.const_pols.const_pols.read().unwrap()).as_ptr() as *mut c_void; - let const_tree_ptr = (*setup.const_tree.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let const_pols_ptr = (*setup.const_pols.values.read().unwrap()).as_ptr() as *mut c_void; + let const_tree_ptr = (*setup.const_tree.values.read().unwrap()).as_ptr() as *mut c_void; let steps_params = StepsParams { buffer: air_instance.get_buffer_ptr() as *mut c_void, @@ -158,6 +168,7 @@ impl Prover for StarkProver { xdivxsub: std::ptr::null_mut(), p_const_pols: const_pols_ptr, p_const_tree: const_tree_ptr, + custom_commits: air_instance.get_custom_commits_ptr(), }; let raw_ptr = verify_constraints_c((&setup.p_setup).into(), (&steps_params).into()); @@ -179,8 +190,8 @@ impl Prover for StarkProver { let public_inputs_guard = proof_ctx.public_inputs.inputs.read().unwrap(); let challenges_guard = proof_ctx.challenges.challenges.read().unwrap(); - let const_pols_ptr = (*setup.const_pols.const_pols.read().unwrap()).as_ptr() as *mut c_void; - let const_tree_ptr = (*setup.const_tree.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let const_pols_ptr = (*setup.const_pols.values.read().unwrap()).as_ptr() as *mut c_void; + let const_tree_ptr = (*setup.const_tree.values.read().unwrap()).as_ptr() as *mut c_void; let steps_params = StepsParams { buffer: air_instance.get_buffer_ptr() as *mut c_void, @@ -192,30 +203,40 @@ impl Prover for StarkProver { xdivxsub: std::ptr::null_mut(), p_const_pols: const_pols_ptr, p_const_tree: const_tree_ptr, + custom_commits: air_instance.get_custom_commits_ptr(), }; if stage_id as usize <= proof_ctx.global_info.n_challenges.len() { - let air_name = &proof_ctx.global_info.airs[self.airgroup_id][self.air_id].name; - debug!( - "{}: ··· Computing intermediate polynomials of instance {} of {}", - Self::MY_NAME, - self.instance_id, - air_name - ); - for i in 0..n_commits { - let cm_pol = self.stark_info.cm_pols_map.as_ref().expect("REASON").get(i).unwrap(); - if (cm_pol.stage < stage_id as u64 || cm_pol.stage == stage_id as u64 && !cm_pol.im_pol) - && !air_instance.commits_calculated.contains_key(&i) - { - panic!("Intermediate polynomials for stage {} cannot be calculated: Witness column {} is not calculated", stage_id, cm_pol.name); + if self + .stark_info + .cm_pols_map + .as_ref() + .expect("REASON") + .iter() + .any(|cm_pol| cm_pol.stage == stage_id as u64 && cm_pol.im_pol) + { + let air_name = &proof_ctx.global_info.airs[self.airgroup_id][self.air_id].name; + debug!( + "{}: ··· Computing intermediate polynomials of instance {} of {}", + Self::MY_NAME, + self.instance_id, + air_name + ); + for i in 0..n_commits { + let cm_pol = self.stark_info.cm_pols_map.as_ref().expect("REASON").get(i).unwrap(); + if (cm_pol.stage < stage_id as u64 || cm_pol.stage == stage_id as u64 && !cm_pol.im_pol) + && !air_instance.commits_calculated.contains_key(&i) + { + panic!("Intermediate polynomials for stage {} cannot be calculated: Witness column {} is not calculated", stage_id, cm_pol.name); + } } - } - calculate_impols_expressions_c(self.p_stark, stage_id as u64, (&steps_params).into()); - for i in 0..n_commits { - let cm_pol = self.stark_info.cm_pols_map.as_ref().expect("REASON").get(i).unwrap(); - if cm_pol.stage == stage_id as u64 && cm_pol.im_pol { - air_instance.set_commit_calculated(i); + calculate_impols_expressions_c(self.p_stark, stage_id as u64, (&steps_params).into()); + for i in 0..n_commits { + let cm_pol = self.stark_info.cm_pols_map.as_ref().expect("REASON").get(i).unwrap(); + if cm_pol.stage == stage_id as u64 && cm_pol.im_pol { + air_instance.set_commit_calculated(i); + } } } @@ -241,6 +262,10 @@ impl Prover for StarkProver { } } } + } + + fn check_stage(&self, stage_id: u32, proof_ctx: Arc>) { + let air_instance = &mut proof_ctx.air_instance_repo.air_instances.write().unwrap()[self.prover_idx]; let n_commits = self.stark_info.cm_pols_map.as_ref().expect("REASON").len(); for i in 0..n_commits { @@ -269,32 +294,88 @@ impl Prover for StarkProver { panic!("Stage {} cannot be committed: Airvalue {} is not calculated", stage_id, air_value.name); } } + + let n_custom_commits = self.stark_info.custom_commits_map.len(); + for i in 0..n_custom_commits { + let n_custom_commits = self.stark_info.custom_commits_map[i].as_ref().expect("REASON").len(); + for j in 0..n_custom_commits { + let custom_pol = self.stark_info.custom_commits_map[i].as_ref().expect("REASON").get(j).unwrap(); + if stage_id as u64 == custom_pol.stage && !air_instance.custom_commits_calculated[i].contains_key(&j) { + panic!( + "Stage {} cannot be committed: Custom commit of {} that is {} is not calculated", + stage_id, self.stark_info.custom_commits[i].name, custom_pol.name + ); + } + } + } } fn commit_stage(&mut self, stage_id: u32, proof_ctx: Arc>) -> ProverStatus { let air_instance = &mut proof_ctx.air_instance_repo.air_instances.write().unwrap()[self.prover_idx]; - let buffer = air_instance.get_buffer_ptr() as *mut c_void; - let p_stark: *mut std::ffi::c_void = self.p_stark; - - let air_name = &proof_ctx.global_info.airs[self.airgroup_id][self.air_id].name; - debug!( - "{}: ··· Committing prover {}: instance {} of {}", - Self::MY_NAME, - self.prover_idx, - self.instance_id, - air_name - ); - - timer_start_trace!(STARK_COMMIT_STAGE_, stage_id); + let p_stark = self.p_stark; let p_proof = self.p_proof; - let element_type = if type_name::() == type_name::() { 1 } else { 0 }; let buff_helper_guard = proof_ctx.buff_helper.buff_helper.read().unwrap(); let buff_helper = (*buff_helper_guard).as_ptr() as *mut c_void; - commit_stage_c(p_stark, element_type, stage_id as u64, buffer, p_proof, buff_helper); - timer_stop_and_log_trace!(STARK_COMMIT_STAGE_, stage_id); + let air_name = &proof_ctx.global_info.airs[self.airgroup_id][self.air_id].name; + if stage_id >= 1 { + debug!( + "{}: ··· Committing prover {}: instance {} of {}", + Self::MY_NAME, + self.prover_idx, + self.instance_id, + air_name + ); + + timer_start_trace!(STARK_COMMIT_STAGE_, stage_id); + + let buffer = air_instance.get_buffer_ptr() as *mut c_void; + let element_type = if type_name::() == type_name::() { 1 } else { 0 }; + + commit_stage_c(p_stark, element_type, stage_id as u64, buffer, p_proof, buff_helper); + timer_stop_and_log_trace!(STARK_COMMIT_STAGE_, stage_id); + } else { + let n_custom_commits = self.stark_info.custom_commits.len(); + for commit_id in 0..n_custom_commits { + let custom_commits_stage = self.stark_info.custom_commits_map[commit_id] + .as_ref() + .expect("REASON") + .iter() + .any(|custom_commit| custom_commit.stage == stage_id as u64); + + if custom_commits_stage { + extend_and_merkelize_custom_commit_c( + p_stark, + commit_id as u64, + stage_id as u64, + air_instance.custom_commits[commit_id].as_ptr() as *mut c_void, + p_proof, + buff_helper, + ); + } + + let mut value = vec![Goldilocks::zero(); self.n_field_elements]; + treesGL_get_root_c( + p_stark, + (self.stark_info.n_stages + 2 + commit_id as u32) as u64, + value.as_mut_ptr() as *mut c_void, + ); + if !self.stark_info.custom_commits[commit_id].public_values.is_empty() { + assert!( + self.n_field_elements == self.stark_info.custom_commits[commit_id].public_values.len(), + "Invalid public values size" + ); + for (idx, val) in value.iter().enumerate() { + proof_ctx.set_public_value( + val.as_canonical_u64(), + self.stark_info.custom_commits[commit_id].public_values[idx].idx, + ); + } + } + } + } if stage_id <= self.num_stages() + 1 { ProverStatus::CommitStage @@ -665,7 +746,7 @@ impl StarkProver { let buff_helper_guard = proof_ctx.buff_helper.buff_helper.read().unwrap(); let buff_helper = (*buff_helper_guard).as_ptr() as *mut c_void; - let const_tree_ptr = (*setup.const_tree.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let const_tree_ptr = (*setup.const_tree.values.read().unwrap()).as_ptr() as *mut c_void; let steps_params = StepsParams { buffer, @@ -677,6 +758,7 @@ impl StarkProver { xdivxsub: std::ptr::null_mut(), p_const_pols: std::ptr::null_mut(), p_const_tree: const_tree_ptr, + custom_commits: air_instance.get_custom_commits_ptr(), }; compute_evals_c(p_stark, (&steps_params).into(), buff_helper, p_proof); @@ -693,8 +775,8 @@ impl StarkProver { let challenges_guard = proof_ctx.challenges.challenges.read().unwrap(); let buff_helper_guard = proof_ctx.buff_helper.buff_helper.read().unwrap(); - let const_pols_ptr = (*setup.const_pols.const_pols.read().unwrap()).as_ptr() as *mut c_void; - let const_tree_ptr = (*setup.const_tree.const_pols.read().unwrap()).as_ptr() as *mut c_void; + let const_pols_ptr = (*setup.const_pols.values.read().unwrap()).as_ptr() as *mut c_void; + let const_tree_ptr = (*setup.const_tree.values.read().unwrap()).as_ptr() as *mut c_void; let p_stark = self.p_stark; @@ -708,6 +790,7 @@ impl StarkProver { xdivxsub: (*buff_helper_guard).as_ptr() as *mut c_void, p_const_pols: const_pols_ptr, p_const_tree: const_tree_ptr, + custom_commits: air_instance.get_custom_commits_ptr(), }; calculate_fri_polynomial_c(p_stark, (&steps_params).into()); @@ -796,7 +879,8 @@ impl StarkProver { let fri_pol = get_fri_pol_c(self.p_stark_info, buffer); - compute_queries_c(p_stark, p_proof, fri_queries.as_mut_ptr(), n_queries, (self.num_stages() + 2) as u64); + let n_trees = self.num_stages() + 2 + self.stark_info.custom_commits.len() as u32; + compute_queries_c(p_stark, p_proof, fri_queries.as_mut_ptr(), n_queries, n_trees as u64); for (step, _) in steps.iter().enumerate().take(self.stark_info.stark_struct.steps.len()).skip(1) { compute_fri_queries_c( self.p_stark, @@ -838,4 +922,18 @@ impl BufferAllocator for StarkBufferAllocator { let p_stark_info = ps.p_setup.p_stark_info; Ok((get_map_totaln_c(p_stark_info), vec![get_map_offsets_c(p_stark_info, "cm1", false)])) } + + fn get_buffer_info_custom_commit( + &self, + sctx: &SetupCtx, + airgroup_id: usize, + air_id: usize, + name: &str, + ) -> Result<(u64, Vec, u64), Box> { + let ps = sctx.get_setup(airgroup_id, air_id); + + let p_stark_info = ps.p_setup.p_stark_info; + let commit_id = get_custom_commit_id_c(p_stark_info, name); + Ok((get_map_totaln_custom_commits_c(p_stark_info, commit_id), vec![0], commit_id)) + } } diff --git a/provers/starks-lib-c/bindings_starks.rs b/provers/starks-lib-c/bindings_starks.rs index 4aa01c86..a47c7412 100644 --- a/provers/starks-lib-c/bindings_starks.rs +++ b/provers/starks-lib-c/bindings_starks.rs @@ -91,6 +91,20 @@ extern "C" { #[link_name = "\u{1}_Z15get_map_total_nPv"] pub fn get_map_total_n(pStarkInfo: *mut ::std::os::raw::c_void) -> u64; } +extern "C" { + #[link_name = "\u{1}_Z20get_custom_commit_idPvPc"] + pub fn get_custom_commit_id( + pStarkInfo: *mut ::std::os::raw::c_void, + name: *mut ::std::os::raw::c_char, + ) -> u64; +} +extern "C" { + #[link_name = "\u{1}_Z30get_map_total_n_custom_commitsPvm"] + pub fn get_map_total_n_custom_commits( + pStarkInfo: *mut ::std::os::raw::c_void, + commit_id: u64, + ) -> u64; +} extern "C" { #[link_name = "\u{1}_Z15get_map_offsetsPvPcb"] pub fn get_map_offsets( @@ -111,6 +125,10 @@ extern "C" { #[link_name = "\u{1}_Z11get_n_evalsPv"] pub fn get_n_evals(pStarkInfo: *mut ::std::os::raw::c_void) -> u64; } +extern "C" { + #[link_name = "\u{1}_Z20get_n_custom_commitsPv"] + pub fn get_n_custom_commits(pStarkInfo: *mut ::std::os::raw::c_void) -> u64; +} extern "C" { #[link_name = "\u{1}_Z23get_airvalue_id_by_namePvPc"] pub fn get_airvalue_id_by_name( @@ -260,6 +278,14 @@ extern "C" { root: *mut ::std::os::raw::c_void, ); } +extern "C" { + #[link_name = "\u{1}_Z16treesGL_set_rootPvmS_"] + pub fn treesGL_set_root( + pStarks: *mut ::std::os::raw::c_void, + index: u64, + pProof: *mut ::std::os::raw::c_void, + ); +} extern "C" { #[link_name = "\u{1}_Z18calculate_xdivxsubPvS_S_"] pub fn calculate_xdivxsub( @@ -297,6 +323,17 @@ extern "C" { stepsParams: *mut ::std::os::raw::c_void, ); } +extern "C" { + #[link_name = "\u{1}_Z34extend_and_merkelize_custom_commitPvmmS_S_S_"] + pub fn extend_and_merkelize_custom_commit( + pStarks: *mut ::std::os::raw::c_void, + commitId: u64, + step: u64, + buffer: *mut ::std::os::raw::c_void, + pProof: *mut ::std::os::raw::c_void, + pBuffHelper: *mut ::std::os::raw::c_void, + ); +} extern "C" { #[link_name = "\u{1}_Z12commit_stagePvjmS_S_S_"] pub fn commit_stage( diff --git a/provers/starks-lib-c/src/ffi_starks.rs b/provers/starks-lib-c/src/ffi_starks.rs index 0300be2f..242e7a4b 100644 --- a/provers/starks-lib-c/src/ffi_starks.rs +++ b/provers/starks-lib-c/src/ffi_starks.rs @@ -131,6 +131,11 @@ pub fn get_map_totaln_c(p_stark_info: *mut c_void) -> u64 { unsafe { get_map_total_n(p_stark_info) } } +#[cfg(not(feature = "no_lib_link"))] +pub fn get_map_totaln_custom_commits_c(p_stark_info: *mut c_void, commit_id: u64) -> u64 { + unsafe { get_map_total_n_custom_commits(p_stark_info, commit_id) } +} + #[cfg(not(feature = "no_lib_link"))] pub fn get_n_airvals_c(p_stark_info: *mut c_void) -> u64 { unsafe { get_n_airvals(p_stark_info) } @@ -146,6 +151,18 @@ pub fn get_n_evals_c(p_stark_info: *mut c_void) -> u64 { unsafe { get_n_evals(p_stark_info) } } +#[cfg(not(feature = "no_lib_link"))] +pub fn get_n_custom_commits_c(p_stark_info: *mut c_void) -> u64 { + unsafe { get_n_custom_commits(p_stark_info) } +} + +#[cfg(not(feature = "no_lib_link"))] +pub fn get_custom_commit_id_c(p_stark_info: *mut c_void, name: &str) -> u64 { + let name = CString::new(name).unwrap(); + + unsafe { get_custom_commit_id(p_stark_info, name.as_ptr() as *mut std::os::raw::c_char) } +} + #[cfg(not(feature = "no_lib_link"))] pub fn get_airgroupval_id_by_name_c(p_stark_info: *mut c_void, name: &str) -> i64 { let airgroupval_name = CString::new(name).unwrap(); @@ -405,6 +422,13 @@ pub fn treesGL_get_root_c(pStark: *mut c_void, index: u64, root: *mut c_void) { } } +#[cfg(not(feature = "no_lib_link"))] +pub fn treesGL_set_root_c(pStark: *mut c_void, index: u64, pProof: *mut c_void) { + unsafe { + treesGL_set_root(pStark, index, pProof); + } +} + #[cfg(not(feature = "no_lib_link"))] pub fn calculate_xdivxsub_c(p_stark: *mut c_void, xi_challenge: *mut c_void, xdivxsub: *mut c_void) { unsafe { @@ -438,6 +462,20 @@ pub fn calculate_impols_expressions_c(p_starks: *mut c_void, step: u64, p_steps_ } } +#[cfg(not(feature = "no_lib_link"))] +pub fn extend_and_merkelize_custom_commit_c( + p_starks: *mut c_void, + commit_id: u64, + step: u64, + buffer: *mut c_void, + p_proof: *mut c_void, + p_buff_helper: *mut c_void, +) { + unsafe { + extend_and_merkelize_custom_commit(p_starks, commit_id, step, buffer, p_proof, p_buff_helper); + } +} + #[cfg(not(feature = "no_lib_link"))] pub fn commit_stage_c( p_starks: *mut c_void, @@ -828,17 +866,17 @@ pub fn set_log_level_c(level: u64) { // ------------------------ #[cfg(feature = "no_lib_link")] pub fn save_challenges_c(_p_challenges: *mut std::os::raw::c_void, _global_info_file: &str, _output_dir: &str) { - trace!("{}: ··· {}", "ffi ", "save_challenges_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "save_challenges: This is a mock call because there is no linked library"); } #[cfg(feature = "no_lib_link")] pub fn save_publics_c(_n_publics: u64, _public_inputs: *mut c_void, _output_dir: &str) { - trace!("{}: ··· {}", "ffi ", "save_publics_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "save_publics: This is a mock call because there is no linked library"); } #[cfg(feature = "no_lib_link")] pub fn save_proof_values_c(_n_proof_values: u64, _proof_values: *mut c_void, _output_dir: &str) { - trace!("{}: ··· {}", "ffi ", "save_proof_values_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "save_proof_values: This is a mock call because there is no linked library"); } #[cfg(feature = "no_lib_link")] @@ -858,7 +896,7 @@ pub fn fri_proof_set_airgroup_values_c(_p_fri_proof: *mut c_void, _p_params: *mu trace!( "{}: ··· {}", "ffi ", - "fri_proof_set_airgroup_values_c: This is a mock call because there is no linked library" + "fri_proof_set_airgroup_values: This is a mock call because there is no linked library" ); } @@ -867,7 +905,7 @@ pub fn fri_proof_set_air_values_c(_p_fri_proof: *mut c_void, _p_params: *mut c_v trace!( "{}: ··· {}", "ffi ", - "fri_proof_set_air_values_c: This is a mock call because there is no linked library" + "fri_proof_set_air_values: This is a mock call because there is no linked library" ); } @@ -907,37 +945,59 @@ pub fn stark_info_new_c(_filename: &str) -> *mut c_void { #[cfg(feature = "no_lib_link")] pub fn get_stark_info_n_publics_c(_p_stark_info: *mut c_void) -> u64 { - trace!("{}: ··· {}", "ffi ", "get_stark_info_n_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "get_stark_info_n: This is a mock call because there is no linked library"); 100000000 } #[cfg(feature = "no_lib_link")] pub fn get_stark_info_n_c(_p_stark_info: *mut c_void) -> u64 { - trace!("{}: ··· {}", "ffi ", "get_stark_info_n_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "get_stark_info_n: This is a mock call because there is no linked library"); 100000000 } #[cfg(feature = "no_lib_link")] pub fn get_map_totaln_c(_p_stark_info: *mut c_void) -> u64 { - trace!("{}: ··· {}", "ffi ", "get_map_totaln_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "get_map_totaln: This is a mock call because there is no linked library"); + 100000000 +} + +#[cfg(feature = "no_lib_link")] +pub fn get_map_totaln_custom_commits_c(_p_stark_info: *mut c_void, _commit_id: u64) -> u64 { + trace!( + "{}: ··· {}", + "ffi ", + "get_map_totaln_custom_commits: This is a mock call because there is no linked library" + ); 100000000 } #[cfg(feature = "no_lib_link")] pub fn get_n_airvals_c(_p_stark_info: *mut c_void) -> u64 { - trace!("{}: ··· {}", "ffi ", "get_n_airvals_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "get_n_airvals: This is a mock call because there is no linked library"); 100000000 } #[cfg(feature = "no_lib_link")] pub fn get_n_airgroupvals_c(_p_stark_info: *mut c_void) -> u64 { - trace!("{}: ··· {}", "ffi ", "get_n_airgroupvals_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "get_n_airgroupvals: This is a mock call because there is no linked library"); 100000000 } #[cfg(feature = "no_lib_link")] pub fn get_n_evals_c(_p_stark_info: *mut c_void) -> u64 { - trace!("{}: ··· {}", "ffi ", "get_n_evals_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "get_n_evals: This is a mock call because there is no linked library"); + 100000000 +} + +#[cfg(feature = "no_lib_link")] +pub fn get_n_custom_commits_c(_p_stark_info: *mut c_void) -> u64 { + trace!("{}: ··· {}", "ffi ", "get_n_custom_commits: This is a mock call because there is no linked library"); + 100000000 +} + +#[cfg(feature = "no_lib_link")] +pub fn get_custom_commit_id_c(_p_stark_info: *mut c_void, _name: &str) -> u64 { + trace!("{}: ··· {}", "ffi ", "get_custom_commit_id: This is a mock call because there is no linked library"); 100000000 } @@ -946,14 +1006,14 @@ pub fn get_airgroupval_id_by_name_c(_p_stark_info: *mut c_void, _name: &str) -> trace!( "{}: ··· {}", "ffi ", - "get_airgroupval_id_by_name_c: This is a mock call because there is no linked library" + "get_airgroupval_id_by_name: This is a mock call because there is no linked library" ); 100000000 } #[cfg(feature = "no_lib_link")] pub fn get_airval_id_by_name_c(_p_stark_info: *mut c_void, _name: &str) -> i64 { - trace!("{}: ··· {}", "ffi ", "get_airval_id_by_name_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "get_airval_id_by_name: This is a mock call because there is no linked library"); 100000000 } @@ -970,7 +1030,7 @@ pub fn stark_info_free_c(_p_stark_info: *mut c_void) { #[cfg(feature = "no_lib_link")] pub fn prover_helpers_new_c(_p_stark_info: *mut c_void) -> *mut c_void { - trace!("{}: ··· {}", "ffi ", "prover_helpers_new_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "prover_helpers_new: This is a mock call because there is no linked library"); std::ptr::null_mut() } @@ -979,24 +1039,24 @@ pub fn prover_helpers_free_c(_p_prover_helpers: *mut c_void) {} #[cfg(feature = "no_lib_link")] pub fn load_const_pols_c(_pConstPolsAddress: *mut c_void, _const_filename: &str, _const_size: u64) { - trace!("{}: ··· {}", "ffi ", "load_const_pols_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "load_const_pols: This is a mock call because there is no linked library"); } #[cfg(feature = "no_lib_link")] pub fn get_const_tree_size_c(_pStarkInfo: *mut c_void) -> u64 { - trace!("{}: ··· {}", "ffi ", "get_const_tree_size_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "get_const_tree_size: This is a mock call because there is no linked library"); 1000000 } #[cfg(feature = "no_lib_link")] pub fn get_const_size_c(_pStarkInfo: *mut c_void) -> u64 { - trace!("{}: ··· {}", "ffi ", "get_const_size_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "get_const_size: This is a mock call because there is no linked library"); 1000000 } #[cfg(feature = "no_lib_link")] pub fn load_const_tree_c(_pConstPolsTreeAddress: *mut c_void, _tree_filename: &str, _const_tree_size: u64) { - trace!("{}: ··· {}", "ffi ", "load_const_tree_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "load_const_tree: This is a mock call because there is no linked library"); } #[cfg(feature = "no_lib_link")] @@ -1006,7 +1066,7 @@ pub fn calculate_const_tree_c( _pConstPolsTreeAddress: *mut c_void, _tree_filename: &str, ) { - trace!("{}: ··· {}", "ffi ", "calculate_const_tree_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "calculate_const_tree: This is a mock call because there is no linked library"); } #[cfg(feature = "no_lib_link")] @@ -1089,7 +1149,7 @@ pub fn set_hint_field_c( _hint_id: u64, _hint_field_name: &str, ) -> u64 { - trace!("{}: ··· {}", "ffi ", "set_hint_field_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "set_hint_field: This is a mock call because there is no linked library"); 0 } @@ -1120,6 +1180,11 @@ pub fn treesGL_get_root_c(_pStark: *mut c_void, _index: u64, _root: *mut c_void) trace!("{}: ··· {}", "ffi ", "treesGL_get_root: This is a mock call because there is no linked library"); } +#[cfg(feature = "no_lib_link")] +pub fn treesGL_set_root_c(_pStark: *mut c_void, _index: u64, _pProof: *mut c_void) { + trace!("{}: ··· {}", "ffi ", "treesGL_set_root: This is a mock call because there is no linked library"); +} + #[cfg(feature = "no_lib_link")] pub fn calculate_fri_polynomial_c(_p_starks: *mut c_void, _p_steps_params: *mut c_void) { trace!("mckzkevm: ··· {}", "calculate_fri_polynomial: This is a mock call because there is no linked library"); @@ -1139,6 +1204,22 @@ pub fn calculate_impols_expressions_c(_p_starks: *mut c_void, _step: u64, _p_ste ); } +#[cfg(feature = "no_lib_link")] +pub fn extend_and_merkelize_custom_commit_c( + _p_starks: *mut c_void, + _commit_id: u64, + _step: u64, + _buffer: *mut c_void, + _p_proof: *mut c_void, + _p_buff_helper: *mut c_void, +) { + trace!( + "{}: ··· {}", + "ffi ", + "extend_and_merkelize_custom_commit: This is a mock call because there is no linked library" + ); +} + #[cfg(feature = "no_lib_link")] pub fn commit_stage_c( _p_starks: *mut c_void, @@ -1153,7 +1234,7 @@ pub fn commit_stage_c( #[cfg(feature = "no_lib_link")] pub fn compute_lev_c(_p_stark: *mut c_void, _xi_challenge: *mut c_void, _lev: *mut c_void) { - trace!("{}: ··· {}", "ffi ", "compute_lev_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "compute_lev: This is a mock call because there is no linked library"); } #[cfg(feature = "no_lib_link")] @@ -1163,7 +1244,7 @@ pub fn compute_evals_c(_p_stark: *mut c_void, _params: *mut c_void, _lev: *mut c #[cfg(feature = "no_lib_link")] pub fn calculate_xdivxsub_c(_p_stark: *mut c_void, _xi_challenge: *mut c_void, _buffer: *mut c_void) { - trace!("{}: ··· {}", "ffi ", "calculate_xdivxsub_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "calculate_xdivxsub: This is a mock call because there is no linked library"); } #[cfg(feature = "no_lib_link")] @@ -1266,7 +1347,7 @@ pub fn get_permutations_c(_p_transcript: *mut c_void, _res: *mut u64, _n: u64, _ #[cfg(feature = "no_lib_link")] pub fn verify_constraints_c(_p_setup: *mut c_void, _p_steps_params: *mut c_void) -> *mut c_void { - trace!("{}: ··· {}", "ffi ", "verify_constraints_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "verify_constraints: This is a mock call because there is no linked library"); std::ptr::null_mut() } @@ -1281,7 +1362,7 @@ pub fn verify_global_constraints_c( trace!( "{}: ··· {}", "ffi ", - "verify_global_constraints_c: This is a mock call because there is no linked library" + "verify_global_constraints: This is a mock call because there is no linked library" ); true } @@ -1301,7 +1382,7 @@ pub fn get_hint_field_global_constraints_c( trace!( "{}: ··· {}", "ffi ", - "get_hint_field_global_constraints_c: This is a mock call because there is no linked library" + "get_hint_field_global_constraints: This is a mock call because there is no linked library" ); std::ptr::null_mut() } @@ -1317,7 +1398,7 @@ pub fn set_hint_field_global_constraints_c( trace!( "{}: ··· {}", "ffi ", - "set_hint_field_global_constraints_c: This is a mock call because there is no linked library" + "set_hint_field_global_constraints: This is a mock call because there is no linked library" ); 100000 } @@ -1332,7 +1413,7 @@ pub fn print_by_name_c( _last_print_value: u64, _return_values: bool, ) -> *mut c_void { - trace!("{}: ··· {}", "ffi ", "print_by_name_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "print_by_name: This is a mock call because there is no linked library"); std::ptr::null_mut() } @@ -1344,12 +1425,12 @@ pub fn print_expression_c( _first_print_value: u64, _last_print_value: u64, ) { - trace!("{}: ··· {}", "ffi ", "print_expression_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "print_expression: This is a mock call because there is no linked library"); } #[cfg(feature = "no_lib_link")] pub fn print_row_c(_p_setup_ctx: *mut c_void, _buffer: *mut c_void, _stage: u64, _row: u64) { - trace!("{}: ··· {}", "ffi ", "print_row_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "print_row: This is a mock call because there is no linked library"); } #[cfg(feature = "no_lib_link")] @@ -1364,19 +1445,19 @@ pub fn gen_recursive_proof_c( _global_info_file: &str, _airgroup_id: u64, ) -> *mut c_void { - trace!("{}: ··· {}", "ffi ", "gen_recursive_proof_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "gen_recursive_proof: This is a mock call because there is no linked library"); std::ptr::null_mut() } #[cfg(feature = "no_lib_link")] pub fn get_zkin_ptr_c(_zkin_file: &str) -> *mut c_void { - trace!("{}: ··· {}", "ffi ", "get_zkin_ptr_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "get_zkin_ptr: This is a mock call because there is no linked library"); std::ptr::null_mut() } #[cfg(feature = "no_lib_link")] pub fn add_recursive2_verkey_c(_p_zkin: *mut c_void, _recursive2_verkey: &str) -> *mut c_void { - trace!("{}: ··· {}", "ffi ", "add_recursive2_verkey_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "add_recursive2_verkey: This is a mock call because there is no linked library"); std::ptr::null_mut() } @@ -1390,7 +1471,7 @@ pub fn join_zkin_recursive2_c( _zkin2: *mut c_void, _stark_info_recursive2: *mut c_void, ) -> *mut c_void { - trace!("{}: ··· {}", "ffi ", "join_zkin_recursive2_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "join_zkin_recursive2: This is a mock call because there is no linked library"); std::ptr::null_mut() } @@ -1409,37 +1490,33 @@ pub fn join_zkin_final_c( #[cfg(feature = "no_lib_link")] pub fn get_serialized_proof_c(_zkin: *mut c_void, _size: *mut u64) -> *mut std::os::raw::c_char { - trace!("{}: ··· {}", "ffi ", "get_serialized_proof_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "get_serialized_proof: This is a mock call because there is no linked library"); std::ptr::null_mut() } #[cfg(feature = "no_lib_link")] pub fn deserialize_zkin_proof_c(_zkin_cstr: *mut std::os::raw::c_char) -> *mut c_void { - trace!( - "{}: ··· {}", - "ffi ", - "deserialize_zkin_proof_c: This is a mock call because there is no linked library" - ); + trace!("{}: ··· {}", "ffi ", "deserialize_zkin_proof: This is a mock call because there is no linked library"); std::ptr::null_mut() } #[cfg(feature = "no_lib_link")] pub fn get_zkin_proof_c(_zkin_file: *mut std::os::raw::c_char) -> *mut c_void { - trace!("{}: ··· {}", "ffi ", "get_zkin_proof_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "get_zkin_proof: This is a mock call because there is no linked library"); std::ptr::null_mut() } #[cfg(feature = "no_lib_link")] pub fn zkin_proof_free_c(_p_zkin_proof: *mut c_void) { - trace!("{}: ··· {}", "ffi ", "zkin_proof_free_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "zkin_proof_free: This is a mock call because there is no linked library"); } #[cfg(feature = "no_lib_link")] pub fn serialized_proof_free_c(_zkin_cstr: *mut std::os::raw::c_char) { - trace!("{}: ··· {}", "ffi ", "serialized_proof_free_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "serialized_proof_free: This is a mock call because there is no linked library"); } #[cfg(feature = "no_lib_link")] pub fn set_log_level_c(_level: u64) { - trace!("{}: ··· {}", "ffi ", "set_log_level_c: This is a mock call because there is no linked library"); + trace!("{}: ··· {}", "ffi ", "set_log_level: This is a mock call because there is no linked library"); }