diff --git a/.gitignore b/.gitignore index 8356b80..2ce79d2 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ outputs/temp/ *.pdf scripts/__pycache__/ .DS_Store +.idea diff --git a/src/whir/fs_utils.rs b/src/whir/fs_utils.rs new file mode 100644 index 0000000..245cbb5 --- /dev/null +++ b/src/whir/fs_utils.rs @@ -0,0 +1,30 @@ +use crate::domain::Domain; +use crate::utils::dedup; +use crate::whir::parameters::{RoundConfig, WhirConfig}; +use ark_crypto_primitives::merkle_tree; +use ark_ff::FftField; +use nimue::{ByteChallenges, ProofResult}; + +pub fn get_challenge_stir_queries( + domain_size: usize, + folding_factor: usize, + num_queries: usize, + transcript: &mut T, +) -> ProofResult> +where + T: ByteChallenges, +{ + let folded_domain_size = domain_size / (1 << folding_factor); + let domain_size_bytes = ((folded_domain_size * 2 - 1).ilog2() as usize + 7) / 8; + let mut queries = vec![0u8; num_queries * domain_size_bytes]; + transcript.fill_challenge_bytes(&mut queries)?; + let indices = queries.chunks_exact(domain_size_bytes).map(|chunk| { + let mut result = 0; + for byte in chunk { + result <<= 8; + result |= *byte as usize; + } + result % folded_domain_size + }); + Ok(dedup(indices)) +} diff --git a/src/whir/iopattern.rs b/src/whir/iopattern.rs index d0e2485..74ac8ed 100644 --- a/src/whir/iopattern.rs +++ b/src/whir/iopattern.rs @@ -47,18 +47,24 @@ where .challenge_scalars(1, "initial_combination_randomness") .add_sumcheck(params.folding_factor, params.starting_folding_pow_bits); + let mut folded_domain_size = params.starting_domain.folded_size(params.folding_factor); + for r in ¶ms.round_parameters { + let domain_size_bytes = ((folded_domain_size * 2 - 1).ilog2() as usize + 7) / 8; self = self .add_bytes(32, "merkle_digest") .add_ood(r.ood_samples) - .challenge_bytes(32, "stir_queries_seed") + .challenge_bytes(r.num_queries * domain_size_bytes, "stir_queries") .pow(r.pow_bits) .challenge_scalars(1, "combination_randomness") .add_sumcheck(params.folding_factor, r.folding_pow_bits); + folded_domain_size /= 2; } + let domain_size_bytes = ((folded_domain_size * 2 - 1).ilog2() as usize + 7) / 8; + self.add_scalars(1 << params.final_sumcheck_rounds, "final_coeffs") - .challenge_bytes(32, "final_queries_seed") + .challenge_bytes(domain_size_bytes * params.final_queries, "final_queries") .pow(params.final_pow_bits) .add_sumcheck(params.final_sumcheck_rounds, params.final_folding_pow_bits) } diff --git a/src/whir/mod.rs b/src/whir/mod.rs index 8243508..39f14cd 100644 --- a/src/whir/mod.rs +++ b/src/whir/mod.rs @@ -8,6 +8,7 @@ pub mod iopattern; pub mod parameters; pub mod prover; pub mod verifier; +mod fs_utils; #[derive(Debug, Clone)] pub struct Statement { diff --git a/src/whir/prover.rs b/src/whir/prover.rs index e4b4591..2505050 100644 --- a/src/whir/prover.rs +++ b/src/whir/prover.rs @@ -21,6 +21,7 @@ use nimue::{ use nimue_pow::{self, PoWChallenge}; use rand::{Rng, SeedableRng}; +use crate::whir::fs_utils::get_challenge_stir_queries; #[cfg(feature = "parallel")] use rayon::prelude::*; @@ -131,13 +132,13 @@ where merlin.add_scalars(folded_coefficients.coeffs())?; // Final verifier queries and answers - let mut queries_seed = [0u8; 32]; - merlin.fill_challenge_bytes(&mut queries_seed)?; - let mut final_gen = rand_chacha::ChaCha20Rng::from_seed(queries_seed); - let final_challenge_indexes = utils::dedup((0..self.0.final_queries).map(|_| { - final_gen.gen_range(0..round_state.domain.folded_size(self.0.folding_factor)) - })); - + let final_challenge_indexes = get_challenge_stir_queries( + round_state.domain.size(), + self.0.folding_factor, + self.0.final_queries, + merlin, + )?; + let merkle_proof = round_state .prev_merkle .generate_multi_proof(final_challenge_indexes.clone()) @@ -214,13 +215,12 @@ where } // STIR queries - let mut stir_queries_seed = [0u8; 32]; - merlin.fill_challenge_bytes(&mut stir_queries_seed)?; - let mut stir_gen = rand_chacha::ChaCha20Rng::from_seed(stir_queries_seed); - let stir_challenges_indexes = - utils::dedup((0..round_params.num_queries).map(|_| { - stir_gen.gen_range(0..round_state.domain.folded_size(self.0.folding_factor)) - })); + let stir_challenges_indexes = get_challenge_stir_queries( + round_state.domain.size(), + self.0.folding_factor, + round_params.num_queries, + merlin, + )?; let domain_scaled_gen = round_state .domain .backing_domain diff --git a/src/whir/verifier.rs b/src/whir/verifier.rs index 2fb131b..5f72a50 100644 --- a/src/whir/verifier.rs +++ b/src/whir/verifier.rs @@ -10,6 +10,8 @@ use nimue::{ use nimue_pow::{self, PoWChallenge}; use rand::{Rng, SeedableRng}; +use super::{parameters::WhirConfig, Statement, WhirProof}; +use crate::whir::fs_utils::get_challenge_stir_queries; use crate::{ parameters::FoldType, poly_utils::{coeffs::CoefficientList, eq_poly_outside, fold::compute_fold, MultilinearPoint}, @@ -17,8 +19,6 @@ use crate::{ utils::{self, expand_randomness}, }; -use super::{parameters::WhirConfig, Statement, WhirProof}; - pub struct Verifier where F: FftField, @@ -147,13 +147,13 @@ where arthur.fill_next_scalars(&mut ood_answers)?; } - let mut stir_queries_seed = [0u8; 32]; - arthur.fill_challenge_bytes(&mut stir_queries_seed)?; - let mut stir_gen = rand_chacha::ChaCha20Rng::from_seed(stir_queries_seed); - let folded_domain_size = domain_size / (1 << self.params.folding_factor); - let stir_challenges_indexes = utils::dedup( - (0..round_params.num_queries).map(|_| stir_gen.gen_range(0..folded_domain_size)), - ); + let stir_challenges_indexes = get_challenge_stir_queries( + domain_size, + self.params.folding_factor, + round_params.num_queries, + arthur, + )?; + let stir_challenges_points = stir_challenges_indexes .iter() .map(|index| exp_domain_gen.pow([*index as u64])) @@ -222,13 +222,12 @@ where let final_coefficients = CoefficientList::new(final_coefficients); // Final queries verify - let mut queries_seed = [0u8; 32]; - arthur.fill_challenge_bytes(&mut queries_seed)?; - let mut final_gen = rand_chacha::ChaCha20Rng::from_seed(queries_seed); - let folded_domain_size = domain_size / (1 << self.params.folding_factor); - let final_randomness_indexes = utils::dedup( - (0..self.params.final_queries).map(|_| final_gen.gen_range(0..folded_domain_size)), - ); + let final_randomness_indexes = get_challenge_stir_queries( + domain_size, + self.params.folding_factor, + self.params.final_queries, + arthur, + )?; let final_randomness_points = final_randomness_indexes .iter() .map(|index| exp_domain_gen.pow([*index as u64]))