From fc385b636c53439a694c10869d40d2065958e430 Mon Sep 17 00:00:00 2001 From: Kunming Jiang Date: Mon, 30 Dec 2024 14:49:54 -0500 Subject: [PATCH] Allow num_inputs to be different per witness_sec --- spartan_parallel/src/custom_dense_mlpoly.rs | 297 +++++++++---------- spartan_parallel/src/instance.rs | 123 ++++---- spartan_parallel/src/lib.rs | 84 +++--- spartan_parallel/src/r1csinstance.rs | 124 ++------ spartan_parallel/src/r1csproof.rs | 113 +++---- spartan_parallel/src/sparse_mlpoly.rs | 3 +- spartan_parallel/src/sumcheck.rs | 33 +-- spartan_parallel/writeups/proofs_overview.md | 6 +- 8 files changed, 329 insertions(+), 454 deletions(-) diff --git a/spartan_parallel/src/custom_dense_mlpoly.rs b/spartan_parallel/src/custom_dense_mlpoly.rs index 677803e7..3f658f55 100644 --- a/spartan_parallel/src/custom_dense_mlpoly.rs +++ b/spartan_parallel/src/custom_dense_mlpoly.rs @@ -14,23 +14,23 @@ const MODE_X: usize = 4; // Customized Dense ML Polynomials for Data-Parallelism // These Dense ML Polys are aimed for space-efficiency by removing the 0s for invalid (p, q, w, x) quadruple -// Dense polynomial with variable order: p, q_rev, w, x_rev +// Dense polynomial with variable order: p, q, w, x // Used by Z_poly in r1csproof -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Hash)] pub struct DensePolynomialPqx { - num_instances: usize, // num_instances is a power of 2 and num_instances / 2 < Z.len() <= num_instances - num_proofs: Vec, - max_num_proofs: usize, - pub num_witness_secs: usize, // num_witness_secs is a power of 2 and num_witness_secs / 2 < Z[.][.].len() <= num_witness_secs - num_inputs: Vec, - max_num_inputs: usize, - pub Z: Vec>>>, // Evaluations of the polynomial in all the 2^num_vars Boolean inputs of order (p, q_rev, w, x_rev) - // Let Q_max = max_num_proofs, assume that for a given P, num_proofs[P] = Q_i, then let STEP = Q_max / Q_i, - // Z(P, y, .) is only non-zero if y is a multiple of STEP, so Z[P][j][.] actually stores Z(P, j*STEP, .) - // The same applies to X + // All metadata might not be a power of 2 + pub num_instances: usize, + pub num_proofs: Vec, // P + pub num_witness_secs: usize, + pub num_inputs: Vec>, // P x W + pub num_vars_p: usize, // log(P.next_power_of_two()) + pub num_vars_q: usize, + pub num_vars_w: usize, + pub num_vars_x: usize, + pub Z: Vec>>>, // Evaluations of the polynomial in all the 2^num_vars Boolean inputs of order (p, q, w, x) } -fn fold_rq(proofs: &mut [Vec>], r_q: &[S], step: usize, mut q: usize, w: usize, x: usize) { +fn fold_rq(proofs: &mut [Vec>], r_q: &[S], step: usize, mut q: usize, w: usize, x: &Vec) { for r in r_q { let r1 = S::field_one() - r.clone(); let r2 = r.clone(); @@ -38,69 +38,55 @@ fn fold_rq(proofs: &mut [Vec>], r_q: &[S], step q = q.div_ceil(2); (0..q).for_each(|q| { (0..w).for_each(|w| { - (0..x).for_each(|x| { + (0..x[w]).for_each(|x| { proofs[q * step][w][x] = r1 * proofs[2 * q * step][w][x] + r2 * proofs[(2 * q + 1) * step][w][x]; }); }); }); } - - /* - if lvl > final_lvl { - fold_rq(proofs, r_q, 2 * idx, step, lvl - 1, final_lvl, w, x); - fold_rq(proofs, r_q, 2 * idx + step, step, lvl - 1, final_lvl, w, x); - - let r1 = S::field_one() - r_q[lvl - 1]; - let r2 = r_q[lvl - 1]; - - (0..w).for_each(|w| { - (0..x).for_each(|x| { - proofs[idx][w][x] = r1 * proofs[idx * 2][w][x] + r2 * proofs[idx * 2 + step][w][x]; - }); - }); - } else { - // base level. do nothing - } - */ } impl DensePolynomialPqx { - // Assume z_mat is of form (p, q_rev, x_rev), construct DensePoly + // Assume z_mat is of form (p, q_rev, x), construct DensePoly pub fn new( - z_mat: Vec>>>, - num_proofs: Vec, - max_num_proofs: usize, - num_inputs: Vec, - max_num_inputs: usize, + z_mat: Vec>>>, ) -> Self { - let num_instances = z_mat.len().next_power_of_two(); - let num_witness_secs = z_mat[0][0].len().next_power_of_two(); + let num_instances = z_mat.len(); + let num_proofs: Vec = (0..num_instances).map(|p| z_mat[p].len()).collect(); + let num_witness_secs = z_mat[0][0].len(); + let num_inputs: Vec> = (0..num_instances).map(|p| + (0..num_witness_secs).map(|w| z_mat[p][0][w].len()).collect() + ).collect(); + // Sortedness check: num_proofs and num_inputs[p] are sorted in decreasing order + assert!((0..num_instances - 1).fold(true, |b, i| b && num_proofs[i] >= num_proofs[i + 1])); + for w in &num_inputs { + assert!((0..num_witness_secs - 1).fold(true, |b, i| b && w[i] >= w[i + 1])); + } + + let num_vars_p = num_instances.next_power_of_two().log_2(); + let num_vars_q = num_proofs.iter().max().unwrap().next_power_of_two().log_2(); + let num_vars_w = num_witness_secs.next_power_of_two().log_2(); + let num_vars_x = num_inputs.iter().map(|i| i.iter().max().unwrap()).max().unwrap().next_power_of_two().log_2(); DensePolynomialPqx { num_instances, num_proofs, - max_num_proofs, num_witness_secs, num_inputs, - max_num_inputs, + num_vars_p, + num_vars_q, + num_vars_w, + num_vars_x, Z: z_mat, } } pub fn len(&self) -> usize { - return self.num_instances * self.max_num_proofs * self.max_num_inputs; + return self.num_vars_p.pow2() * self.num_vars_q.pow2() * self.num_vars_w.pow2() * self.num_vars_x.pow2(); } - // Given (p, q_rev, x_rev) return Z[p][q_rev][x_rev] - pub fn index(&self, p: usize, q_rev: usize, w: usize, x_rev: usize) -> S { - if p < self.Z.len() - && q_rev < self.Z[p].len() - && w < self.Z[p][q_rev].len() - && x_rev < self.Z[p][q_rev][w].len() - { - return self.Z[p][q_rev][w][x_rev]; - } else { - return S::field_zero(); - } + // Given (p, q, w, x) return Z[p][q][w][x], DO NOT CHECK FOR OUT OF BOUND + pub fn index(&self, p: usize, q: usize, w: usize, x: usize) -> S { + return self.Z[p][q][w][x]; } // Given (p, q, w, x) and a mode, return Z[p*][q*][w*][x*] @@ -108,26 +94,25 @@ impl DensePolynomialPqx { // Mode = 2 ==> q* = 2q for low, 2q + 1 // Mode = 3 ==> w* = 2w for low, 2w + 1 // Mode = 4 ==> x* = 2x for low, 2x + 1 - // Assume p*, q*, w*, x*, within bound + // Assume p*, q*, w*, x* are within bound pub fn index_low(&self, p: usize, q: usize, w: usize, x: usize, mode: usize) -> S { let ZERO = S::field_zero(); match mode { - MODE_P => { if 2 * p >= self.Z.len() { ZERO } else { self.Z[2 * p][q][w][x] } } - MODE_Q => self.Z[p][2 * q][w][x], - MODE_W => { if 2 * w >= self.Z[p][q].len() { ZERO } else { self.Z[p][q][2 * w][x] } } - MODE_X => self.Z[p][q][w][2 * x], - _ => unreachable!() + MODE_P => { if self.num_instances == 1 { self.Z[0][q][w][x] } else if 2 * p >= self.num_instances { ZERO } else { self.Z[2 * p][q][w][x] } } + MODE_Q => { if 2 * q >= self.num_proofs[p] { ZERO } else { self.Z[p][2 * q][w][x] } }, + MODE_W => { if 2 * w >= self.num_witness_secs { ZERO } else { self.Z[p][q][2 * w][x] } } + MODE_X => { if 2 * x >= self.num_inputs[p][w] { ZERO } else { self.Z[p][q][w][2 * x] } }, + _ => unreachable!() } } - pub fn index_high(&self, p: usize, q: usize, w: usize, x: usize, mode: usize) -> S { let ZERO = S::field_zero(); match mode { - MODE_P => { if self.num_instances == 1 { self.Z[0][q][w][x] } else if 2 * p + 1 >= self.Z.len() { ZERO } else { self.Z[2 * p + 1][q][w][x] } } - MODE_Q => { if self.num_proofs[p] == 1 { ZERO } else { self.Z[p][2 * q + 1][w][x] } } - MODE_W => { if 2 * w + 1 >= self.Z[p][q].len() { ZERO } else { self.Z[p][q][2 * w + 1][x] } } - MODE_X => { if self.num_inputs[p] == 1 { ZERO } else { self.Z[p][q][w][2 * x + 1] } } - _ => unreachable!() + MODE_P => { if self.num_instances == 1 { self.Z[0][q][w][x] } else if 2 * p + 1 >= self.num_instances { ZERO } else { self.Z[2 * p + 1][q][w][x] } } + MODE_Q => { if 2 * q + 1 >= self.num_proofs[p] { ZERO } else { self.Z[p][2 * q + 1][w][x] } } + MODE_W => { if 2 * w + 1 >= self.num_witness_secs { ZERO } else { self.Z[p][q][2 * w + 1][x] } } + MODE_X => { if 2 * x + 1 >= self.num_inputs[p][w] { ZERO } else { self.Z[p][q][w][2 * x + 1] } } + _ => unreachable!() } } @@ -137,108 +122,99 @@ impl DensePolynomialPqx { // Mode = 3 ==> Bound last variable of "w" section to r // Mode = 4 ==> Bound last variable of "x" section to r pub fn bound_poly(&mut self, r: &S, mode: usize) { - match mode { - MODE_P => { self.bound_poly_p(r); } - MODE_Q => { self.bound_poly_q(r); } - MODE_W => { self.bound_poly_w(r); } - MODE_X => { self.bound_poly_x(r); } - _ => { panic!("DensePolynomialPqx bound failed: unrecognized mode {}!", mode); } - } + match mode { + MODE_P => { self.bound_poly_p(r); } + MODE_Q => { self.bound_poly_q(r); } + MODE_W => { self.bound_poly_w(r); } + MODE_X => { self.bound_poly_x(r); } + _ => { panic!("DensePolynomialPqx bound failed: unrecognized mode {}!", mode); } + } } // Bound the last variable of "p" section to r // We are only allowed to bound "p" if we have bounded the entire q and x section pub fn bound_poly_p(&mut self, r: &S) { - let ZERO = S::field_zero(); - assert_eq!(self.max_num_proofs, 1); - assert_eq!(self.max_num_inputs, 1); - self.num_instances /= 2; - for p in 0..self.num_instances { - for w in 0..min(self.num_witness_secs, self.Z[p][0].len()) { - let Z_low = if 2 * p < self.Z.len() { self.Z[2 * p][0][w][0] } else { ZERO }; - let Z_high = if 2 * p + 1 < self.Z.len() { self.Z[2 * p + 1][0][w][0] } else { ZERO }; - self.Z[p][0][w][0] = Z_low + r.clone() * (Z_high - Z_low); - } + assert!(self.num_vars_p >= 1); + assert_eq!(self.num_vars_q, 0); + assert_eq!(self.num_vars_x, 0); + let new_num_instances = self.num_instances.div_ceil(2); + for p in 0..new_num_instances { + for w in 0..self.num_witness_secs { + let Z_low = self.index_low(p, 0, w, 0, MODE_P); + let Z_high = self.index_high(p, 0, w, 0, MODE_P); + self.Z[p][0][w][0] = Z_low + r.clone() * (Z_high - Z_low); } + } + self.num_instances = new_num_instances; + self.num_vars_p -= 1; } // Bound the last variable of "q" section to r pub fn bound_poly_q(&mut self, r: &S) { - let ONE = S::field_one(); - self.max_num_proofs /= 2; - - for p in 0..min(self.num_instances, self.Z.len()) { - if self.num_proofs[p] == 1 { - for w in 0..min(self.num_witness_secs, self.Z[p][0].len()) { - for x in 0..self.num_inputs[p] { - self.Z[p][0][w][x] *= ONE - r.clone(); - } - } - } else { - self.num_proofs[p] /= 2; - for q in 0..self.num_proofs[p] { - for w in 0..min(self.num_witness_secs, self.Z[p][q].len()) { - for x in 0..self.num_inputs[p] { - self.Z[p][q][w][x] = self.Z[p][2 * q][w][x] + r.clone() * (self.Z[p][2 * q + 1][w][x] - self.Z[p][2 * q][w][x]); - } + assert!(self.num_vars_q >= 1); + for p in 0..self.num_instances { + let new_num_proofs = self.num_proofs[p].div_ceil(2); + for q in 0..new_num_proofs { + for w in 0..self.num_witness_secs { + for x in 0..self.num_inputs[p][w] { + let Z_low = self.index_low(p, q, w, x, MODE_Q); + let Z_high = self.index_high(p, q, w, x, MODE_Q); + self.Z[p][q][w][x] = Z_low + r.clone() * (Z_high - Z_low); } } } + self.num_proofs[p] = new_num_proofs; } + self.num_vars_q -= 1; } // Bound the last variable of "w" section to r + // We are only allowed to bound "w" if we have bounded the entire x section pub fn bound_poly_w(&mut self, r: &S) { - let ZERO = S::field_zero(); - self.num_witness_secs /= 2; - - for p in 0..min(self.num_instances, self.Z.len()) { + assert!(self.num_vars_w >= 1); + assert_eq!(self.num_vars_x, 0); + let new_num_witness_secs = self.num_witness_secs.div_ceil(2); + for p in 0..self.num_instances { for q in 0..self.num_proofs[p] { - for w in 0..self.num_witness_secs { - for x in 0..self.num_inputs[p] { - let Z_low = if 2 * w < self.Z[p][q].len() { self.Z[p][q][2 * w][x] } else { ZERO }; - let Z_high = if 2 * w + 1 < self.Z[p][q].len() { self.Z[p][q][2 * w + 1][x] } else { ZERO }; - self.Z[p][q][w][x] = Z_low + r.clone() * (Z_high - Z_low); - } + for w in 0..new_num_witness_secs { + let Z_low = self.index_low(p, q, w, 0, MODE_W); + let Z_high = self.index_high(p, q, w, 0, MODE_W); + self.Z[p][q][w][0] = Z_low + r.clone() * (Z_high - Z_low); } } } + self.num_witness_secs = new_num_witness_secs; + self.num_vars_w -= 1; } // Bound the last variable of "x" section to r pub fn bound_poly_x(&mut self, r: &S) { - let ONE = S::field_one(); - self.max_num_inputs /= 2; - - for p in 0..min(self.num_instances, self.Z.len()) { - if self.num_inputs[p] == 1 { - for q in 0..self.num_proofs[p] { - for w in 0..min(self.num_witness_secs, self.Z[p][q].len()) { - self.Z[p][q][w][0] *= ONE - r.clone(); - } - } - } else { - self.num_inputs[p] /= 2; - for q in 0..self.num_proofs[p] { - for w in 0..min(self.num_witness_secs, self.Z[p][q].len()) { - for x in 0..self.num_inputs[p] { - self.Z[p][q][w][x] = self.Z[p][q][w][2 * x] + r.clone() * (self.Z[p][q][w][2 * x + 1] - self.Z[p][q][w][2 * x]); - } - } + // assert!(self.num_vars_x >= 1); + for p in 0..self.num_instances { + for w in 0..self.num_witness_secs { + let new_num_inputs = self.num_inputs[p][w].div_ceil(2); + for q in 0..self.num_proofs[p] { + for x in 0..new_num_inputs { + let Z_low = self.index_low(p, q, w, x, MODE_X); + let Z_high = self.index_high(p, q, w, x, MODE_X); + self.Z[p][q][w][x] = Z_low + r.clone() * (Z_high - Z_low); } } + self.num_inputs[p][w] = new_num_inputs; } + } + if self.num_vars_x >= 1 { + self.num_vars_x -= 1; + } } // Bound the entire "p" section to r_p in reverse // Must occur after r_q's are bounded - pub fn bound_poly_vars_rp(&mut self, - r_p: &[S], - ) { - for r in r_p { - self.bound_poly_p(r); - } + pub fn bound_poly_vars_rp(&mut self, r_p: &[S]) { + for r in r_p { + self.bound_poly_p(r); } + } // Bound the entire "q" section to r_q in reverse pub fn bound_poly_vars_rq_parallel( @@ -268,7 +244,7 @@ impl DensePolynomialPqx { // single proof matrix dimension W x X let num_witness_secs = min(self.num_witness_secs, inst[0].len()); - let num_inputs = self.num_inputs[p]; + let num_inputs = &self.num_inputs[p]; // Divide rq into sub, final, and left_over let sub_rq = &r_q[0..sub_levels]; @@ -294,7 +270,7 @@ impl DensePolynomialPqx { // the series of random challenges exceeds the total number of variables let c = left_over_rq.into_iter().fold(S::field_one(), |acc, n| acc * (S::field_one() - *n)); for w in 0..inst[0].len() { - for x in 0..inst[0][0].len() { + for x in 0..inst[0][w].len() { inst[0][w][x] *= c; } } @@ -303,33 +279,26 @@ impl DensePolynomialPqx { inst }).collect::>>>>(); - self.max_num_proofs = 1; + self.num_vars_q = 0; self.num_proofs = vec![1; self.num_instances]; } // Bound the entire "q" section to r_q in reverse - // Must occur after r_q's are bounded - pub fn bound_poly_vars_rq(&mut self, - r_q: &[S], - ) { + pub fn bound_poly_vars_rq(&mut self, r_q: &[S]) { for r in r_q { self.bound_poly_q(r); } } // Bound the entire "w" section to r_w in reverse - pub fn bound_poly_vars_rw(&mut self, - r_w: &[S], - ) { + pub fn bound_poly_vars_rw(&mut self, r_w: &[S]) { for r in r_w { self.bound_poly_w(r); } } // Bound the entire "x_rev" section to r_x - pub fn bound_poly_vars_rx(&mut self, - r_x: &[S], - ) { + pub fn bound_poly_vars_rx(&mut self, r_x: &[S]) { for r in r_x { self.bound_poly_x(r); } @@ -351,22 +320,28 @@ impl DensePolynomialPqx { // Convert to a (p, q_rev, x_rev) regular dense poly of form (p, q, x) pub fn to_dense_poly(&self) -> DensePolynomial { - let ZERO = S::field_zero(); - let mut Z_poly = vec![ZERO; self.num_instances * self.max_num_proofs * self.num_witness_secs * self.max_num_inputs]; - for p in 0..min(self.num_instances, self.Z.len()) { - for q in 0..self.num_proofs[p] { - for w in 0..min(self.num_witness_secs, self.Z[p][q].len()) { - for x in 0..self.num_inputs[p] { - Z_poly[ - p * self.max_num_proofs * self.num_witness_secs * self.max_num_inputs - + q * self.num_witness_secs * self.max_num_inputs - + w * self.max_num_inputs - + x - ] = self.Z[p][q][w][x]; - } + let ZERO = S::field_zero(); + + let p_space = self.num_vars_p.pow2(); + let q_space = self.num_vars_q.pow2(); + let w_space = self.num_vars_w.pow2(); + let x_space = self.num_vars_x.pow2(); + + let mut Z_poly = vec![ZERO; p_space * q_space * w_space * x_space]; + for p in 0..self.num_instances { + for q in 0..self.num_proofs[p] { + for w in 0..self.num_witness_secs { + for x in 0..self.num_inputs[p][w] { + Z_poly[ + p * q_space * w_space * x_space + + q * w_space * x_space + + w * x_space + + x + ] = self.Z[p][q][w][x]; } } } - DensePolynomial::new(Z_poly) + } + DensePolynomial::new(Z_poly) } } \ No newline at end of file diff --git a/spartan_parallel/src/instance.rs b/spartan_parallel/src/instance.rs index 879d79e0..d046bd47 100644 --- a/spartan_parallel/src/instance.rs +++ b/spartan_parallel/src/instance.rs @@ -26,7 +26,7 @@ pub struct Instance { pub digest: Vec, } -impl Instance { +impl Instance { /// Constructs a new `Instance` and an associated satisfying assignment pub fn new( num_instances: usize, @@ -38,6 +38,8 @@ impl Instance { B: &Vec>, C: &Vec>, ) -> Result, R1CSError> { + let ZERO = S::field_zero(); + let (max_num_vars_padded, num_vars_padded, max_num_cons_padded, num_cons_padded) = { let max_num_vars_padded = { let mut max_num_vars_padded = max_num_vars; @@ -82,12 +84,7 @@ impl Instance { } } - ( - max_num_vars_padded, - num_vars_padded, - max_num_cons_padded, - num_cons_padded, - ) + (max_num_vars_padded, num_vars_padded, max_num_cons_padded, num_cons_padded) }; let bytes_to_scalar = @@ -124,7 +121,7 @@ impl Instance { // we do not need to pad otherwise because the dummy constraints are implicit in the sum-check protocol if num_cons[b] == 0 || num_cons[b] == 1 { for i in tups.len()..num_cons_padded[b] { - mat.push((i, num_vars[b], S::field_zero())); + mat.push((i, num_vars[b], ZERO)); } } @@ -245,10 +242,10 @@ impl Instance { /// Verify the correctness of each block execution, as well as extracting all memory operations /// /// Input composition: (if every segment exists) - /// INPUT + VAR Challenges BLOCK_W2 BLOCK_W3 BLOCK_W3_SHIFTED - /// 0 1 2 IOW +1 +2 +3 +4 +5 | 0 1 2 3 | 0 1 2 3 4 NIU 1 2 3 2NP +1 +2 +3 +4 | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 - /// v i0 ... PA0 PD0 ... VA0 VD0 ... | tau r r^2 ... | _ _ ZO r*i1 ... MR MC MR ... MR1 MR2 MR3 MC MR1 ... | v x pi D pi D pi D | v x pi D pi D pi D - /// INPUT PHY VIR INPUT PHY VIR INPUT PHY VIR + /// INPUT + VAR BLOCK_W2 Challenges BLOCK_W3 BLOCK_W3_SHIFTED + /// 0 1 2 IOW +1 +2 +3 +4 +5 | 0 1 2 3 4 NIU 1 2 3 2NP +1 +2 +3 +4 | 0 1 2 3 | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 + /// v i0 ... PA0 PD0 ... VA0 VD0 ... | _ _ ZO r*i1 ... MR MC MR ... MR1 MR2 MR3 MC MR1 ... | tau r r^2 ... | v x pi D pi D pi D | v x pi D pi D pi D + /// INPUT PHY VIR INPUT PHY VIR INPUT PHY VIR /// /// VAR: /// We assume that the witnesses are of the following format: @@ -271,7 +268,7 @@ impl Instance { /// - VMR3 = r^3 * VT /// - VMC = (1 or VMC[i-1]) * (tau - VA - VMR1 - VMR2 - VMR3) /// The final product is stored in X = MC[NV - 1] - /// + /// /// If in COMMIT_MODE, commit instance by num_vars_per_block, rounded to the nearest power of four pub fn gen_block_inst( num_instances: usize, @@ -306,20 +303,12 @@ impl Instance { max_size_per_group.insert(next_group_size(*num_vars), num_vars.next_power_of_two()); } } - num_vars_per_block - .iter() - .map(|i| { - max_size_per_group - .get(&next_group_size(*i)) - .unwrap() - .clone() - }) - .collect() + num_vars_per_block.iter().map(|i| max_size_per_group.get(&next_group_size(*i)).unwrap().clone()).collect() } else { vec![num_vars; num_instances] }; - if PRINT_SIZE { + if PRINT_SIZE && !COMMIT_MODE { println!("\n\n--\nBLOCK INSTS"); println!( "{:10} {:>4} {:>4} {:>4} {:>4}", @@ -348,37 +337,30 @@ impl Instance { let V_VD = |b: usize, i: usize| io_width + 2 * num_phy_ops[b] + 4 * i + 1; let V_VL = |b: usize, i: usize| io_width + 2 * num_phy_ops[b] + 4 * i + 2; let V_VT = |b: usize, i: usize| io_width + 2 * num_phy_ops[b] + 4 * i + 3; - // in CHALLENGES, not used if !has_mem_op - let V_tau = |b: usize| num_vars_padded_per_block[b]; - let V_r = |b: usize, i: usize| num_vars_padded_per_block[b] + i; // in BLOCK_W2 / INPUT_W2 let V_input_dot_prod = |b: usize, i: usize| { if i == 0 { V_input(0) } else { - 2 * num_vars_padded_per_block[b] + 2 + i + num_vars_padded_per_block[b] + 2 + i } }; - let V_output_dot_prod = - |b: usize, i: usize| 2 * num_vars_padded_per_block[b] + 2 + (num_inputs_unpadded - 1) + i; + let V_output_dot_prod = |b: usize, i: usize| num_vars_padded_per_block[b] + 2 + (num_inputs_unpadded - 1) + i; // in BLOCK_W2 / PHY_W2 - let V_PMR = - |b: usize, i: usize| 2 * num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * i; - let V_PMC = - |b: usize, i: usize| 2 * num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * i + 1; + let V_PMR = |b: usize, i: usize| num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * i; + let V_PMC = |b: usize, i: usize| num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * i + 1; // in BLOCK_W2 / VIR_W2 - let V_VMR1 = |b: usize, i: usize| { - 2 * num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * num_phy_ops[b] + 4 * i - }; - let V_VMR2 = |b: usize, i: usize| { - 2 * num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * num_phy_ops[b] + 4 * i + 1 - }; - let V_VMR3 = |b: usize, i: usize| { - 2 * num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * num_phy_ops[b] + 4 * i + 2 - }; - let V_VMC = |b: usize, i: usize| { - 2 * num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * num_phy_ops[b] + 4 * i + 3 - }; + let V_VMR1 = + |b: usize, i: usize| num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * num_phy_ops[b] + 4 * i; + let V_VMR2 = + |b: usize, i: usize| num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * num_phy_ops[b] + 4 * i + 1; + let V_VMR3 = + |b: usize, i: usize| num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * num_phy_ops[b] + 4 * i + 2; + let V_VMC = + |b: usize, i: usize| num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * num_phy_ops[b] + 4 * i + 3; + // in CHALLENGES, not used if !has_mem_op + let V_tau = |b: usize| 2 * num_vars_padded_per_block[b]; + let V_r = |b: usize, i: usize| 2 * num_vars_padded_per_block[b] + i; // in BLOCK_W3 let V_v = |b: usize| 3 * num_vars_padded_per_block[b]; let V_x = |b: usize| 3 * num_vars_padded_per_block[b] + 1; @@ -703,7 +685,7 @@ impl Instance { B_list.push(B); C_list.push(C); - if PRINT_SIZE { + if PRINT_SIZE && !COMMIT_MODE { let max_nnz = max(tmp_nnz_A, max(tmp_nnz_B, tmp_nnz_C)); let total_var = num_vars_per_block[b] + 2 * num_inputs_unpadded.next_power_of_two() @@ -724,7 +706,7 @@ impl Instance { } } - if PRINT_SIZE { + if PRINT_SIZE && !COMMIT_MODE { println!("Total Num of Blocks: {}", num_instances); println!("Total Inst Commit Size: {}", total_inst_commit_size); println!("Total Var Commit Size: {}", total_var_commit_size); @@ -744,10 +726,7 @@ impl Instance { max_cons_per_group.insert(num_vars_padded_per_block[i], block_num_cons[i]); } } - num_vars_padded_per_block - .iter() - .map(|i| max_cons_per_group.get(i).unwrap().clone()) - .collect() + num_vars_padded_per_block.iter().map(|i| max_cons_per_group.get(i).unwrap().clone()).collect() } else { block_num_cons } @@ -759,10 +738,7 @@ impl Instance { block_max_num_cons, num_cons_padded_per_block, block_num_vars, - num_vars_padded_per_block - .into_iter() - .map(|i| 8 * i) - .collect(), + num_vars_padded_per_block.into_iter().map(|i| 8 * i).collect(), &A_list, &B_list, &C_list, @@ -816,9 +792,14 @@ impl Instance { /// D2 = D1 * (ls[i+1] - STORE) /// Where STORE = 0 /// Input composition: - /// Op[k] Op[k + 1] D2 & bits of ts[k + 1] - ts[k] - /// 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 - /// v D1 a d ls ts _ _ | v D1 a d ls ts _ _ | D2 EQ B0 B1 ... + /// bits of ts[k + 1] - ts[k] Op[k] Op[k + 1] + /// 0 1 2 3 4 | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 + /// D2 EQ B0 B1 ... | v D1 a d ls ts _ _ | v D1 a d ls ts _ _ + /// + /// If ADDR_NONCONSEC, address comparison of VIR uses <= instead of +1, with the following expression + /// ts | addr + /// 0 1 2 3 4 | 0 1 2 3 4 5 + /// D2 EQ B0 B1 ... | D4 INV EQ B0 B1 ... pub fn gen_pairwise_check_inst( max_ts_width: usize, mem_addr_ts_bits_size: usize, @@ -834,6 +815,7 @@ impl Instance { "", "con", "var", "nnz", "exec" ); } + // Variable used by printing let mut total_inst_commit_size = 0; let mut total_var_commit_size = 0; @@ -841,7 +823,6 @@ impl Instance { let pairwise_check_num_vars = max(8, mem_addr_ts_bits_size); let pairwise_check_max_num_cons = 8 + max_ts_width; - let pairwise_check_num_cons = vec![2, 4, 8 + max_ts_width]; let pairwise_check_num_non_zero_entries: usize = max(13 + max_ts_width, 5 + 2 * max_ts_width); let pairwise_check_inst = { @@ -972,23 +953,24 @@ impl Instance { let (A, B, C) = { let width = pairwise_check_num_vars; - let V_valid = 0; + // TS_BITS + let V_D2 = 0; + let V_EQ = 1; + let V_B = |i| 2 + i; + // OP[K], OP[K + 1] + let V_valid = width; let V_cnst = V_valid; - let V_D1 = 1; - let V_addr = 2; - let V_data = 3; - let V_ls = 4; - let V_ts = 5; - let V_D2 = 2 * width; - let V_EQ = 2 * width + 1; - let V_B = |i| 2 * width + 2 + i; + let V_D1 = width + 1; + let V_addr = width + 2; + let V_data = width + 3; + let V_ls = width + 4; + let V_ts = width + 5; let mut A: Vec<(usize, usize, [u8; 32])> = Vec::new(); let mut B: Vec<(usize, usize, [u8; 32])> = Vec::new(); let mut C: Vec<(usize, usize, [u8; 32])> = Vec::new(); let mut num_cons = 0; - // Sortedness // (v[k] - 1) * v[k + 1] = 0 (A, B, C) = Instance::::gen_constr( A, @@ -1000,6 +982,7 @@ impl Instance { vec![], ); num_cons += 1; + // Sortedness // D1[k] = v[k + 1] * (1 - addr[k + 1] + addr[k]) (A, B, C) = Instance::::gen_constr( A, @@ -1403,4 +1386,4 @@ impl Instance { perm_root_inst, ) } -} +} \ No newline at end of file diff --git a/spartan_parallel/src/lib.rs b/spartan_parallel/src/lib.rs index 04ab7040..b23320c9 100644 --- a/spartan_parallel/src/lib.rs +++ b/spartan_parallel/src/lib.rs @@ -15,8 +15,6 @@ extern crate digest; extern crate merlin; extern crate rand; extern crate sha3; - -#[cfg(feature = "multicore")] extern crate rayon; mod custom_dense_mlpoly; @@ -433,6 +431,7 @@ impl ProverWitnessSecInfo { } } + // Empty ProverWitnessSecInfo fn dummy() -> ProverWitnessSecInfo { ProverWitnessSecInfo { num_inputs: Vec::new(), @@ -441,6 +440,16 @@ impl ProverWitnessSecInfo { } } + // Zero ProverWitnessSecInfo + fn pad() -> ProverWitnessSecInfo { + let ZERO = S::field_zero(); + ProverWitnessSecInfo { + num_inputs: vec![1], + w_mat: vec![vec![vec![ZERO]]], + poly_w: vec![DensePolynomial::new(vec![ZERO])], + } + } + // Concatenate the components in the given order to a new prover witness sec fn concat(components: Vec<&ProverWitnessSecInfo>) -> ProverWitnessSecInfo { let mut num_inputs = Vec::new(); @@ -531,6 +540,13 @@ impl VerifierWitnessSecInfo { } } + fn pad() -> VerifierWitnessSecInfo { + VerifierWitnessSecInfo { + num_inputs: vec![1], + num_proofs: vec![1], + } + } + // Concatenate the components in the given order to a new verifier witness sec fn concat(components: Vec<&VerifierWitnessSecInfo>) -> VerifierWitnessSecInfo { let mut num_inputs = Vec::new(); @@ -1099,9 +1115,6 @@ impl SNARK { let index: Vec = inst_sorter.iter().map(|i| i.index).collect(); let block_inst_unsorted = block_inst.clone(); block_inst.sort(block_num_instances, &index); - let block_num_vars: Vec = (0..block_num_instances) - .map(|i| block_num_vars[index[i]]) - .collect(); let block_num_phy_ops: Vec = (0..block_num_instances) .map(|i| block_num_phy_ops[index[i]]) .collect(); @@ -2021,8 +2034,8 @@ impl SNARK { let timer_proof = Timer::new("Block Correctness Extract"); let block_wit_secs = vec![ &block_vars_prover, - &perm_w0_prover, &block_w2_prover, + &perm_w0_prover, &block_w3_prover, &block_w3_shifted_prover, ]; @@ -2034,7 +2047,6 @@ impl SNARK { block_max_num_proofs, block_num_proofs, num_vars, - &block_num_vars, block_wit_secs, &block_inst.inst, transcript, @@ -2118,27 +2130,28 @@ impl SNARK { .max() .unwrap() .clone(); - let (pairwise_prover, inst_map) = ProverWitnessSecInfo::merge(vec![ + let (pairwise_w0_prover, inst_map) = ProverWitnessSecInfo::merge(vec![ &perm_exec_w3_prover, &addr_phy_mems_prover, - &addr_vir_mems_prover, + &addr_ts_bits_prover, ]); - let (pairwise_shifted_prover, _) = ProverWitnessSecInfo::merge(vec![ + let (pairwise_w1_prover, _) = ProverWitnessSecInfo::merge(vec![ &perm_exec_w3_shifted_prover, &addr_phy_mems_shifted_prover, - &addr_vir_mems_shifted_prover, + &addr_vir_mems_prover, ]); - let addr_ts_bits_prover = { - let mut components = vec![&perm_w0_prover; inst_map.len()]; + let dummy_w2 = ProverWitnessSecInfo::pad(); + let pairwise_w2_prover = { + let mut components = vec![&dummy_w2; inst_map.len()]; for i in 0..inst_map.len() { if inst_map[i] == 2 { - components[i] = &addr_ts_bits_prover; + components[i] = &addr_vir_mems_shifted_prover; } } ProverWitnessSecInfo::concat(components) }; - let pairwise_num_instances = pairwise_prover.w_mat.len(); - let pairwise_num_proofs: Vec = pairwise_prover.w_mat.iter().map(|i| i.len()).collect(); + let pairwise_num_instances = pairwise_w0_prover.w_mat.len(); + let pairwise_num_proofs: Vec = pairwise_w0_prover.w_mat.iter().map(|i| i.len()).collect(); let (pairwise_check_r1cs_sat_proof, pairwise_check_challenges) = { let (proof, pairwise_check_challenges) = { R1CSProof::prove( @@ -2146,11 +2159,10 @@ impl SNARK { pairwise_size, &pairwise_num_proofs, max(8, mem_addr_ts_bits_size), - &vec![max(8, mem_addr_ts_bits_size); pairwise_num_instances], vec![ - &pairwise_prover, - &pairwise_shifted_prover, - &addr_ts_bits_prover, + &pairwise_w0_prover, + &pairwise_w1_prover, + &pairwise_w2_prover, ], &pairwise_check_inst.inst, transcript, @@ -2272,7 +2284,6 @@ impl SNARK { perm_size, &perm_root_num_proofs, num_ios, - &vec![num_ios; perm_root_num_instances], vec![ &perm_w0_prover, &perm_root_w1_prover, @@ -2727,9 +2738,7 @@ impl SNARK { let mut block_num_proofs: Vec = inst_sorter.iter().map(|i| i.num_exec).collect(); // index[i] = j => the original jth entry should now be at the ith position let block_index: Vec = inst_sorter.iter().map(|i| i.index).collect(); - let block_num_vars: Vec = (0..block_num_instances) - .map(|i| block_num_vars[block_index[i]]) - .collect(); + let block_num_vars: Vec = (0..block_num_instances).map(|i| block_num_vars[block_index[i]]).collect(); let block_num_phy_ops: Vec = (0..block_num_instances) .map(|i| block_num_phy_ops[block_index[i]]) .collect(); @@ -3038,8 +3047,8 @@ impl SNARK { let timer_sat_proof = Timer::new("Block Correctness Extract Sat"); let block_wit_secs = vec![ &block_vars_verifier, - &perm_w0_verifier, &block_w2_verifier, + &perm_w0_verifier, &block_w3_verifier, &block_w3_shifted_verifier, ]; @@ -3113,27 +3122,28 @@ impl SNARK { .max() .unwrap() .clone(); - let (pairwise_verifier, inst_map) = VerifierWitnessSecInfo::merge(vec![ + let (pairwise_w0_verifier, inst_map) = VerifierWitnessSecInfo::merge(vec![ &perm_exec_w3_verifier, &addr_phy_mems_verifier, - &addr_vir_mems_verifier, + &addr_ts_bits_verifier ]); - let (pairwise_shifted_verifier, _) = VerifierWitnessSecInfo::merge(vec![ + let (pairwise_w1_verifier, _) = VerifierWitnessSecInfo::merge(vec![ &perm_exec_w3_shifted_verifier, &addr_phy_mems_shifted_verifier, - &addr_vir_mems_shifted_verifier, + &addr_vir_mems_verifier, ]); - let addr_ts_bits_verifier = { - let mut components = vec![&perm_w0_verifier; inst_map.len()]; + let dummy_w2 = VerifierWitnessSecInfo::pad(); + let pairwise_w2_verifier = { + let mut components = vec![&dummy_w2; inst_map.len()]; for i in 0..inst_map.len() { if inst_map[i] == 2 { - components[i] = &addr_ts_bits_verifier; + components[i] = &addr_vir_mems_shifted_verifier; } } VerifierWitnessSecInfo::concat(components) }; - let pairwise_num_instances = pairwise_verifier.num_proofs.len(); - let pairwise_num_proofs: Vec = pairwise_verifier.num_proofs.clone(); + let pairwise_num_instances = pairwise_w0_verifier.num_proofs.len(); + let pairwise_num_proofs: Vec = pairwise_w0_verifier.num_proofs.clone(); let pairwise_check_challenges = self.pairwise_check_r1cs_sat_proof.verify( pairwise_num_instances, @@ -3141,9 +3151,9 @@ impl SNARK { &pairwise_num_proofs, max(8, mem_addr_ts_bits_size), vec![ - &pairwise_verifier, - &pairwise_shifted_verifier, - &addr_ts_bits_verifier, + &pairwise_w0_verifier, + &pairwise_w1_verifier, + &pairwise_w2_verifier, ], pairwise_check_num_cons, &self.pairwise_check_inst_evals_bound_rp, diff --git a/spartan_parallel/src/r1csinstance.rs b/spartan_parallel/src/r1csinstance.rs index 9ea77495..8329969f 100644 --- a/spartan_parallel/src/r1csinstance.rs +++ b/spartan_parallel/src/r1csinstance.rs @@ -220,7 +220,6 @@ impl R1CSInstance { num_instances: usize, num_proofs: Vec, max_num_proofs: usize, - num_inputs: Vec, max_num_inputs: usize, max_num_cons: usize, num_cons: Vec, @@ -252,7 +251,6 @@ impl R1CSInstance { vec![self.A_list[p_inst].multiply_vec_disjoint_rounds( num_cons[p_inst].clone(), max_num_inputs, - num_inputs[p], &z_list[q], )] }) @@ -263,7 +261,6 @@ impl R1CSInstance { vec![self.B_list[p_inst].multiply_vec_disjoint_rounds( num_cons[p_inst].clone(), max_num_inputs, - num_inputs[p], &z_list[q], )] }) @@ -274,7 +271,6 @@ impl R1CSInstance { vec![self.C_list[p_inst].multiply_vec_disjoint_rounds( num_cons[p_inst].clone(), max_num_inputs, - num_inputs[p], &z_list[q], )] }) @@ -282,27 +278,9 @@ impl R1CSInstance { } ( - DensePolynomialPqx::new( - Az, - num_proofs.clone(), - max_num_proofs, - num_cons.clone(), - max_num_cons, - ), - DensePolynomialPqx::new( - Bz, - num_proofs.clone(), - max_num_proofs, - num_cons.clone(), - max_num_cons, - ), - DensePolynomialPqx::new( - Cz, - num_proofs, - max_num_proofs, - num_cons.clone(), - max_num_cons, - ), + DensePolynomialPqx::new(Az), + DensePolynomialPqx::new(Bz), + DensePolynomialPqx::new(Cz), ) } @@ -357,7 +335,7 @@ impl R1CSInstance { num_rows: &Vec, num_segs: usize, max_num_cols: usize, - num_cols: &Vec, + num_cols: &Vec>, evals: &[S], // Output in p, q, w, i format, where q section has length 1 ) -> ( @@ -372,88 +350,20 @@ impl R1CSInstance { self.max_num_vars ); - ( - (0..self.num_instances) - .into_par_iter() - .map(|p| { - let evals_A = self.A_list[p].compute_eval_table_sparse_disjoint_rounds( - evals, - num_rows[p], - num_segs, - max_num_cols, - num_cols[p], - ); - vec![evals_A] - }) - .collect(), - (0..self.num_instances) - .into_par_iter() - .map(|p| { - let evals_B = self.B_list[p].compute_eval_table_sparse_disjoint_rounds( - evals, - num_rows[p], - num_segs, - max_num_cols, - num_cols[p], - ); - vec![evals_B] - }) - .collect(), - (0..self.num_instances) - .into_par_iter() - .map(|p| { - let evals_C = self.C_list[p].compute_eval_table_sparse_disjoint_rounds( - evals, - num_rows[p], - num_segs, - max_num_cols, - num_cols[p], - ); - vec![evals_C] - }) - .collect(), - ) + let mut evals_A_list = Vec::new(); + let mut evals_B_list = Vec::new(); + let mut evals_C_list = Vec::new(); + for p in 0..self.num_instances { + let num_cols = *num_cols[p].iter().max().unwrap(); + let evals_A = self.A_list[p].compute_eval_table_sparse_disjoint_rounds(evals, num_rows[p], num_segs, max_num_cols, num_cols); + let evals_B = self.B_list[p].compute_eval_table_sparse_disjoint_rounds(evals, num_rows[p], num_segs, max_num_cols, num_cols); + let evals_C = self.C_list[p].compute_eval_table_sparse_disjoint_rounds(evals, num_rows[p], num_segs, max_num_cols, num_cols); + evals_A_list.push(vec![evals_A]); + evals_B_list.push(vec![evals_B]); + evals_C_list.push(vec![evals_C]); + } - // let evals_A_list = (0..self.num_instances) - // .into_par_iter() - // .map(|p| { - // let evals_A = self.A_list[p].compute_eval_table_sparse_disjoint_rounds( - // evals, - // num_rows[p], - // num_segs, - // max_num_cols, - // num_cols[p], - // ); - // vec![evals_A] - // }).collect(); - - // let evals_B_list = (0..self.num_instances) - // .into_par_iter() - // .map(|p| { - // let evals_B = self.B_list[p].compute_eval_table_sparse_disjoint_rounds( - // evals, - // num_rows[p], - // num_segs, - // max_num_cols, - // num_cols[p], - // ); - // vec![evals_B] - // }).collect(); - - // let evals_C_list = (0..self.num_instances) - // .into_par_iter() - // .map(|p| { - // let evals_C = self.C_list[p].compute_eval_table_sparse_disjoint_rounds( - // evals, - // num_rows[p], - // num_segs, - // max_num_cols, - // num_cols[p], - // ); - // vec![evals_C] - // }).collect(); - - // (evals_A_list, evals_B_list, evals_C_list) + (evals_A_list, evals_B_list, evals_C_list) } // If IS_BLOCK, ry is truncated starting at the third entry diff --git a/spartan_parallel/src/r1csproof.rs b/spartan_parallel/src/r1csproof.rs index 175666a6..cc9ea8d3 100644 --- a/spartan_parallel/src/r1csproof.rs +++ b/spartan_parallel/src/r1csproof.rs @@ -11,10 +11,9 @@ use crate::scalar::SpartanExtensionField; use crate::{ProverWitnessSecInfo, VerifierWitnessSecInfo}; use merlin::Transcript; use serde::{Deserialize, Serialize}; -use std::cmp::max; +use std::cmp::min; use std::iter::zip; use rayon::prelude::*; -use std::sync::{Arc, Mutex}; #[derive(Serialize, Deserialize, Debug)] pub struct R1CSProof { @@ -77,7 +76,7 @@ impl R1CSProof { num_rounds_p: usize, single_inst: bool, num_witness_secs: usize, - num_inputs: Vec, + num_inputs: Vec>, claim: &S, evals_eq: &mut DensePolynomial, evals_ABC: &mut DensePolynomialPqx, @@ -117,7 +116,6 @@ impl R1CSProof { num_proofs: &Vec, // Number of inputs of the combined Z matrix max_num_inputs: usize, - num_inputs: &Vec, // WITNESS_SECS // How many sections does each Z vector have? // num_witness_secs can be between 1 - 8 @@ -150,9 +148,19 @@ impl R1CSProof { assert_eq!(*p, p.next_power_of_two()); assert!(*p <= max_num_proofs); } - for i in num_inputs { - assert_eq!(*i, i.next_power_of_two()); - assert!(*i <= max_num_inputs); + // Construct num_inputs as P x W + // Note: w.num_inputs[p_w] might exceed max_num_inputs, but only the first max_num_inputs entries are used + let mut num_inputs: Vec> = (0..num_instances).map(|p| witness_secs.iter().map(|w| { + let p_w = if w.num_inputs.len() == 1 { 0 } else { p }; + min(w.num_inputs[p_w], max_num_inputs) + }).collect()).collect(); + // Number of inputs must be in decreasing order between witness segments + for p in 0..num_instances { + for w in (1..witness_secs.len()).rev() { + if num_inputs[p][w - 1] < num_inputs[p][w] { + num_inputs[p][w - 1] = num_inputs[p][w] + } + } } // Number of instances is either one or matches num_instances assert!(inst.get_num_instances() == 1 || inst.get_num_instances() == num_instances); @@ -191,13 +199,13 @@ impl R1CSProof { let p_w = if ws.w_mat.len() == 1 { 0 } else { p }; let q_w = if ws.w_mat[p_w].len() == 1 { 0 } else { q }; - let r_w = if ws.num_inputs[p_w] < num_inputs[p] { - let padding = std::iter::repeat(S::field_zero()).take(num_inputs[p] - ws.num_inputs[p_w]).collect::>(); + let r_w = if ws.num_inputs[p_w] < num_inputs[p][w] { + let padding = std::iter::repeat(S::field_zero()).take(num_inputs[p][w] - ws.num_inputs[p_w]).collect::>(); let mut r = ws.w_mat[p_w][q_w].clone(); r.extend(padding); r } else { - ws.w_mat[p_w][q_w].iter().take(num_inputs[p]).cloned().collect::>() + ws.w_mat[p_w][q_w].iter().take(num_inputs[p][w]).cloned().collect::>() }; r_w }).collect::>>() @@ -226,7 +234,6 @@ impl R1CSProof { num_instances, num_proofs.clone(), max_num_proofs, - num_inputs.clone(), max_num_inputs, num_cons, block_num_cons.clone(), @@ -308,7 +315,11 @@ impl R1CSProof { evals_ABC.push(vec![Vec::new()]); for w in 0..num_witness_secs { evals_ABC[p][0].push(Vec::new()); - for i in 0..num_inputs[p] { + // If single instance, need to find the maximum num_inputs + let num_inputs = if inst.get_num_instances() == 1 { + num_inputs.iter().map(|n| n[w]).max().unwrap() + } else { num_inputs[p][w] }; + for i in 0..num_inputs { evals_ABC[p][0][w].push( r_A * evals_A[p][0][w][i] + r_B * evals_B[p][0][w][i] + r_C * evals_C[p][0][w][i], ); @@ -317,24 +328,12 @@ impl R1CSProof { } evals_ABC }; - let mut ABC_poly = DensePolynomialPqx::new( - evals_ABC, - vec![1; num_instances], - 1, - num_inputs.clone(), - max_num_inputs, - ); + let mut ABC_poly = DensePolynomialPqx::new(evals_ABC); timer_tmp.stop(); let timer_tmp = Timer::new("prove_z_gen"); // Construct a p * q * len(z) matrix Z and bound it to r_q - let mut Z_poly = DensePolynomialPqx::new( - z_mat, - num_proofs.clone(), - max_num_proofs, - num_inputs.clone(), - max_num_inputs, - ); + let mut Z_poly = DensePolynomialPqx::new(z_mat); timer_tmp.stop(); let timer_tmp = Timer::new("prove_z_bind"); Z_poly.bound_poly_vars_rq_parallel(&rq_rev); @@ -398,34 +397,37 @@ impl R1CSProof { eval_vars_at_ry_list.push(Vec::new()); for p in 0..wit_sec_num_instance { - poly_list.push(&w.poly_w[p]); - num_proofs_list.push(w.w_mat[p].len()); - num_inputs_list.push(w.num_inputs[p]); - // Depending on w.num_inputs[p], ry_short can be two different values - let ry_short = { - // if w.num_inputs[p] >= num_inputs, need to pad 0's to the front of ry + if w.num_inputs[p] > 1 { + poly_list.push(&w.poly_w[p]); + num_proofs_list.push(w.w_mat[p].len()); + num_inputs_list.push(w.num_inputs[p]); + // Depending on w.num_inputs[p], ry_short can be two different values + let ry_short = { + // if w.num_inputs[p] >= num_inputs, need to pad 0's to the front of ry + if w.num_inputs[p] >= max_num_inputs { + let ry_pad = vec![ZERO; w.num_inputs[p].log_2() - max_num_inputs.log_2()]; + [ry_pad, ry.clone()].concat() + } + // Else ry_short is the last w.num_inputs[p].log_2() entries of ry + // thus, to obtain the actual ry, need to multiply by (1 - ry0)(1 - ry1)..., which is ry_factors[num_rounds_y - w.num_inputs[p]] + else { + ry[num_rounds_y - w.num_inputs[p].log_2()..].to_vec() + } + }; + let rq_short = rq[num_rounds_q - num_proofs_list[num_proofs_list.len() - 1].log_2()..].to_vec(); + let r = &[rq_short, ry_short.clone()].concat(); + let eval_vars_at_ry = poly_list[poly_list.len() - 1].evaluate(r); + Zr_list.push(eval_vars_at_ry); if w.num_inputs[p] >= max_num_inputs { - let ry_pad = vec![ZERO; w.num_inputs[p].log_2() - max_num_inputs.log_2()]; - [ry_pad, ry.clone()].concat() - } - // Else ry_short is the last w.num_inputs[p].log_2() entries of ry - // thus, to obtain the actual ry, need to multiply by (1 - ry0)(1 - ry1)..., which is ry_factors[num_rounds_y - w.num_inputs[p]] - else { - ry[num_rounds_y - w.num_inputs[p].log_2()..].to_vec() + eval_vars_at_ry_list[i].push(eval_vars_at_ry); + } else { + eval_vars_at_ry_list[i].push(eval_vars_at_ry * ry_factors[num_rounds_y - w.num_inputs[p].log_2()]); } - }; - let rq_short = - rq[num_rounds_q - num_proofs_list[num_proofs_list.len() - 1].log_2()..].to_vec(); - let r = &[rq_short, ry_short.clone()].concat(); - let eval_vars_at_ry = poly_list[poly_list.len() - 1].evaluate(r); - Zr_list.push(eval_vars_at_ry); - if w.num_inputs[p] >= max_num_inputs { - eval_vars_at_ry_list[i].push(eval_vars_at_ry); + raw_eval_vars_at_ry_list[i].push(eval_vars_at_ry); } else { - eval_vars_at_ry_list[i] - .push(eval_vars_at_ry * ry_factors[num_rounds_y - w.num_inputs[p].log_2()]); + eval_vars_at_ry_list[i].push(ZERO); + raw_eval_vars_at_ry_list[i].push(ZERO); } - raw_eval_vars_at_ry_list[i].push(eval_vars_at_ry); } } @@ -644,9 +646,14 @@ impl R1CSProof { let w = witness_secs[i]; let wit_sec_num_instance = w.num_proofs.len(); for p in 0..wit_sec_num_instance { - num_proofs_list.push(w.num_proofs[p]); - num_inputs_list.push(w.num_inputs[p]); - eval_Zr_list.push(self.eval_vars_at_ry_list[i][p]); + if w.num_inputs[p] > 1 { + // comm_list.push(&w.comm_w[p]); + num_proofs_list.push(w.num_proofs[p]); + num_inputs_list.push(w.num_inputs[p]); + eval_Zr_list.push(self.eval_vars_at_ry_list[i][p]); + } else { + assert_eq!(self.eval_vars_at_ry_list[i][p], ZERO); + } } } diff --git a/spartan_parallel/src/sparse_mlpoly.rs b/spartan_parallel/src/sparse_mlpoly.rs index b03297c8..ef766b1e 100644 --- a/spartan_parallel/src/sparse_mlpoly.rs +++ b/spartan_parallel/src/sparse_mlpoly.rs @@ -406,7 +406,6 @@ impl SparseMatPolynomial { &self, num_rows: usize, max_num_cols: usize, - _num_cols: usize, z: &Vec>, ) -> Vec { (0..self.M.len()) @@ -416,7 +415,7 @@ impl SparseMatPolynomial { let val = self.M[i].val.clone(); let w = col / max_num_cols; let y = col % max_num_cols; - (row, val * z[w][y]) + (row, if w < z.len() && y < z[w].len() { val * z[w][y] } else { S::field_zero() }) }) .fold(vec![S::field_zero(); num_rows], |mut Mz, (r, v)| { Mz[r] += v; diff --git a/spartan_parallel/src/sumcheck.rs b/spartan_parallel/src/sumcheck.rs index d096c01d..aa2f2974 100644 --- a/spartan_parallel/src/sumcheck.rs +++ b/spartan_parallel/src/sumcheck.rs @@ -325,7 +325,7 @@ impl SumcheckInstanceProof { num_rounds_p: usize, single_inst: bool, // indicates whether poly_B only has one instance num_witness_secs: usize, - mut num_inputs: Vec, + mut num_inputs: Vec>, poly_A: &mut DensePolynomial, poly_B: &mut DensePolynomialPqx, poly_C: &mut DensePolynomialPqx, @@ -362,8 +362,8 @@ impl SumcheckInstanceProof { for p in 0..min(instance_len, num_inputs.len()) { let p_inst = if single_inst { 0 } else { p }; for w in 0..min(witness_secs_len, num_witness_secs) { - for y_rev in 0..inputs_len { - let val = poly_A[p] * poly_B.index(p_inst, 0, w, y_rev) * poly_C.index(p, 0, w, y_rev); + for y in 0..min(num_inputs[p_inst][w], num_inputs[p][w]) { + let val = poly_A[p] * poly_B.index(p_inst, 0, w, y) * poly_C.index(p, 0, w, y); expected += val; } } @@ -402,11 +402,11 @@ impl SumcheckInstanceProof { // So min(instance_len, num_proofs.len()) suffices for p in 0..min(instance_len, num_inputs.len()) { let p_inst = if single_inst { 0 } else { p }; - if mode == MODE_X && num_inputs[p] > 1 { - num_inputs[p] /= 2; - } for w in 0..min(witness_secs_len, num_witness_secs) { - for y in 0..num_inputs[p] { + if mode == MODE_X && num_inputs[p][w] > 1 { + num_inputs[p][w] /= 2; + } + for y in 0..num_inputs[p][w] { // evaluate A, B, C on p, w, y let (poly_A_low, poly_A_high) = match mode { MODE_X => (poly_A[p], poly_A[p]), @@ -558,21 +558,16 @@ impl SumcheckInstanceProof { // Mode = 2 ==> q // Mode = 4 ==> x let mode = if j < num_rounds_x_max { + cons_len = cons_len.div_ceil(2); MODE_X } else if j < num_rounds_x_max + num_rounds_q_max { + proof_len = proof_len.div_ceil(2); MODE_Q } else { + instance_len = instance_len.div_ceil(2); MODE_P }; - if cons_len > 1 { - cons_len /= 2 - } else if proof_len > 1 { - proof_len /= 2 - } else { - instance_len /= 2 - }; - let poly = { let mut eval_point_0 = ZERO; let mut eval_point_2 = ZERO; @@ -581,13 +576,9 @@ impl SumcheckInstanceProof { // We are guaranteed initially instance_len < num_proofs.len() < instance_len x 2 // So min(instance_len, num_proofs.len()) suffices for p in 0..min(instance_len, num_proofs.len()) { - if mode == MODE_X && num_cons[p] > 1 { - num_cons[p] /= 2; - } + if mode == MODE_X { num_cons[p] = num_cons[p].div_ceil(2); } // If q > num_proofs[p], the polynomials always evaluate to 0 - if mode == MODE_Q && num_proofs[p] > 1 { - num_proofs[p] /= 2; - } + if mode == MODE_Q { num_proofs[p] = num_proofs[p].div_ceil(2); } for q in 0..num_proofs[p] { for x in 0..num_cons[p] { // evaluate A, B, C, D on p, q, x diff --git a/spartan_parallel/writeups/proofs_overview.md b/spartan_parallel/writeups/proofs_overview.md index 87cd9951..a8df21c3 100644 --- a/spartan_parallel/writeups/proofs_overview.md +++ b/spartan_parallel/writeups/proofs_overview.md @@ -1,7 +1,7 @@ | Proofs | W0 | W1 | w2 | W3 | W4 | W5 | |--------|----|----|----|----|----|----| -| BLOCK_CORRECTNESS | block_vars | perm_w0 | block_input_w2 | block_w2 | block_w3 | block_shifted_w3 | +| BLOCK_CORRECTNESS | block_vars | block_w2 | perm_w0 | block_w3 | block_shifted_w3 | | CONSIS_CHECK | perm_exec_w3 | perm_exec_w3_shifted | | PHY_MEM_COHERE | addr_phy_mems | addr_phy_mems_shifted | -| VIR_MEM_COHERE | addr_vir_mems | addr_vir_mems_shifted | addr_ts_bits | -| PERM_ROOT | perm_w0 | perm_root_w2 | perm_root_w3 | perm_root_shifted_w3 | +| VIR_MEM_COHERE | addr_ts_bits | addr_vir_mems | addr_vir_mems_shifted | +| PERM_ROOT | perm_w0 | perm_root_w2 | perm_root_w3 | perm_root_shifted_w3 | \ No newline at end of file