Skip to content

Commit

Permalink
Merge pull request #231 from a16z/feat/mul-vec-uniform-acceleration
Browse files Browse the repository at this point in the history
feat: Accelerate mul_vec_uniform
  • Loading branch information
sragss authored Mar 28, 2024
2 parents f24810b + 9b96b8e commit 0bd01d4
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 23 deletions.
29 changes: 11 additions & 18 deletions jolt-core/src/r1cs/r1cs_shape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,10 @@ impl<F: PrimeField> R1CSShape<F> {
&self,
full_witness_vector: &P,
num_steps: usize,
) -> Result<(Vec<F>, Vec<F>, Vec<F>), SpartanError> {
A: &mut Vec<F>,
B: &mut Vec<F>,
C: &mut Vec<F>
) -> Result<(), SpartanError> {
if full_witness_vector.len() != (self.num_io + self.num_vars) * num_steps {
return Err(SpartanError::InvalidWitnessLength);
}
Expand Down Expand Up @@ -226,45 +229,35 @@ impl<F: PrimeField> R1CSShape<F> {

// computes a product between a sparse uniform matrix represented by `M` and a vector `z`
let sparse_matrix_vec_product_uniform =
|M: &Vec<(usize, usize, F)>, num_rows: usize| -> Vec<F> {
|M: &Vec<(usize, usize, F)>, result: &mut Vec<F>, num_rows: usize| {
let row_pointers = get_row_pointers(M);

let mut result: Vec<F> = vec![F::zero(); num_steps * num_rows];

let span = tracing::span!(
tracing::Level::TRACE,
"sparse_matrix_vec_product_uniform::multiply_row_vecs"
);
let _enter = span.enter();
result
.par_chunks_mut(num_steps)
.take(num_rows) // Inputs have been padded to a power of 2 -- only have num_steps * num_rows total non-zero fields
.enumerate()
.for_each(|(row_index, row_output)| {
let row = &M[row_pointers[row_index]..row_pointers[row_index + 1]];
multiply_row_vec_uniform(row, row_output, num_steps);
});

result
};

let (mut Az, (mut Bz, mut Cz)) = rayon::join(
|| sparse_matrix_vec_product_uniform(&self.A, self.num_cons),
rayon::join(
|| sparse_matrix_vec_product_uniform(&self.A, A, self.num_cons),
|| {
rayon::join(
|| sparse_matrix_vec_product_uniform(&self.B, self.num_cons),
|| sparse_matrix_vec_product_uniform(&self.C, self.num_cons),
|| sparse_matrix_vec_product_uniform(&self.B, B, self.num_cons),
|| sparse_matrix_vec_product_uniform(&self.C, C, self.num_cons),
)
},
);

// pad each Az, Bz, Cz to the next power of 2
let m = max(Az.len(), max(Bz.len(), Cz.len())).next_power_of_two();
rayon::join(
|| Az.resize(m, F::zero()),
|| rayon::join(|| Bz.resize(m, F::zero()), || Cz.resize(m, F::zero())),
);

Ok((Az, Bz, Cz))
Ok(())
}

/// Pads the R1CSShape so that the number of variables is a power of two
Expand Down
28 changes: 23 additions & 5 deletions jolt-core/src/r1cs/spartan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,14 +205,32 @@ impl<F: PrimeField, G: CurveGroup<ScalarField = F>> UniformSpartanProof<F, G> {
.map(|_i| <Transcript as ProofTranscript<G>>::challenge_scalar(transcript, b"t"))
.collect::<Vec<F>>();

let (Az, Bz, Cz) = key.shape_single_step.multiply_vec_uniform(
let combined_witness_size = (key.num_steps * key.shape_single_step.num_cons).next_power_of_two();
let A_z = allocate_vec_in_background(F::zero(), combined_witness_size);
let B_z = allocate_vec_in_background(F::zero(), combined_witness_size);
let C_z = allocate_vec_in_background(F::zero(), combined_witness_size);

let mut poly_tau = DensePolynomial::new(EqPolynomial::new(tau).evals());

let span = tracing::span!(tracing::Level::TRACE, "wait_join");
let _enter = span.enter();
let mut A_z = A_z.join().unwrap();
let mut B_z = B_z.join().unwrap();
let mut C_z = C_z.join().unwrap();
drop(_enter);

key.shape_single_step.multiply_vec_uniform(
&segmented_padded_witness,
key.num_steps,
&mut A_z,
&mut B_z,
&mut C_z
)?;
let mut poly_Az = DensePolynomial::new(Az);
let mut poly_Bz = DensePolynomial::new(Bz);
let mut poly_Cz = DensePolynomial::new(Cz);
let mut poly_tau = DensePolynomial::new(EqPolynomial::new(tau).evals());
let mut poly_Az = DensePolynomial::new(A_z);
let mut poly_Bz = DensePolynomial::new(B_z);
let mut poly_Cz = DensePolynomial::new(C_z);



let comb_func_outer = |A: &F, B: &F, C: &F, D: &F| -> F {
// Below is an optimized form of: *A * (*B * *C - *D)
Expand Down

0 comments on commit 0bd01d4

Please sign in to comment.