From a3248480632c47233d0ebe81bf087dd7d1e50024 Mon Sep 17 00:00:00 2001 From: Hang Su Date: Wed, 13 Nov 2024 15:39:06 -0500 Subject: [PATCH] use simd inner product? --- pcs/benches/orion.rs | 77 ++++++++++++++++++--------- pcs/src/orion.rs | 119 ++++++++++++++++++++++++++++++++++-------- pcs/src/orion_test.rs | 62 ++++++++++++++-------- 3 files changed, 188 insertions(+), 70 deletions(-) diff --git a/pcs/benches/orion.rs b/pcs/benches/orion.rs index bb64c610..ab392a87 100644 --- a/pcs/benches/orion.rs +++ b/pcs/benches/orion.rs @@ -3,15 +3,15 @@ use std::{hint::black_box, ops::Mul, time::Duration}; use arith::{Field, FieldSerde, SimdField}; use ark_std::test_rng; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; -use gf2::{GF2x128, GF2x64, GF2}; -use gf2_128::GF2_128; -use mersenne31::{M31Ext3, M31x16, M31}; +use gf2::{GF2x128, GF2x64, GF2x8, GF2}; +use gf2_128::{GF2_128x8, GF2_128}; +use mersenne31::{M31Ext3, M31Ext3x16, M31x16, M31}; use pcs::{OrionPCS, OrionPCSSetup, PolynomialCommitmentScheme, ORION_CODE_PARAMETER_INSTANCE}; use polynomials::MultiLinearPoly; use transcript::{BytesHashTranscript, Keccak256hasher, Transcript}; use tynm::type_name; -fn committing_benchmark_helper( +fn committing_benchmark_helper( c: &mut Criterion, lowest_num_vars: usize, highest_num_vars: usize, @@ -19,6 +19,8 @@ fn committing_benchmark_helper( F: Field + FieldSerde, PackF: SimdField, EvalF: Field + FieldSerde + From + Mul, + IPPackF: SimdField, + IPPackEvalF: SimdField + Mul, T: Transcript, { let mut group = c.benchmark_group(format!( @@ -32,7 +34,7 @@ fn committing_benchmark_helper( for num_vars in lowest_num_vars..=highest_num_vars { let random_poly = MultiLinearPoly::::random(num_vars, &mut rng); - let orion_pp = OrionPCS::::gen_srs_for_testing( + let orion_pp = OrionPCS::::gen_srs_for_testing( &mut rng, &OrionPCSSetup { num_vars, @@ -45,10 +47,12 @@ fn committing_benchmark_helper( BenchmarkId::new(format!("{num_vars} variables"), num_vars), |b| { b.iter(|| { - _ = black_box(OrionPCS::::commit( - &orion_pp, - &random_poly, - )) + _ = black_box( + OrionPCS::::commit( + &orion_pp, + &random_poly, + ), + ) }) }, ) @@ -58,15 +62,25 @@ fn committing_benchmark_helper( } fn orion_committing_benchmark(c: &mut Criterion) { - committing_benchmark_helper::>( - c, 19, 25, - ); - committing_benchmark_helper::>( - c, 15, 20, - ); + committing_benchmark_helper::< + GF2, + GF2x128, + GF2_128, + GF2x8, + GF2_128x8, + BytesHashTranscript<_, Keccak256hasher>, + >(c, 19, 25); + committing_benchmark_helper::< + M31, + M31x16, + M31Ext3, + M31x16, + M31Ext3x16, + BytesHashTranscript<_, Keccak256hasher>, + >(c, 15, 20); } -fn opening_benchmark_helper( +fn opening_benchmark_helper( c: &mut Criterion, lowest_num_vars: usize, highest_num_vars: usize, @@ -74,6 +88,8 @@ fn opening_benchmark_helper( F: Field + FieldSerde, PackF: SimdField, EvalF: Field + FieldSerde + From + Mul, + IPPackF: SimdField, + IPPackEvalF: SimdField + Mul, T: Transcript, { let mut group = c.benchmark_group(format!( @@ -92,7 +108,7 @@ fn opening_benchmark_helper( .map(|_| EvalF::random_unsafe(&mut rng)) .collect(); - let orion_pp = OrionPCS::::gen_srs_for_testing( + let orion_pp = OrionPCS::::gen_srs_for_testing( &mut rng, &OrionPCSSetup { num_vars, @@ -100,14 +116,15 @@ fn opening_benchmark_helper( }, ); - let commitment_with_data = OrionPCS::::commit(&orion_pp, &random_poly); + let commitment_with_data = + OrionPCS::::commit(&orion_pp, &random_poly); group .bench_function( BenchmarkId::new(format!("{num_vars} variables"), num_vars), |b| { b.iter(|| { - _ = black_box(OrionPCS::::open( + _ = black_box(OrionPCS::::open( &orion_pp, &random_poly, &random_point, @@ -123,12 +140,22 @@ fn opening_benchmark_helper( } fn orion_opening_benchmark(c: &mut Criterion) { - opening_benchmark_helper::>( - c, 19, 25, - ); - opening_benchmark_helper::>( - c, 15, 20, - ); + opening_benchmark_helper::< + GF2, + GF2x64, + GF2_128, + GF2x8, + GF2_128x8, + BytesHashTranscript<_, Keccak256hasher>, + >(c, 19, 25); + opening_benchmark_helper::< + M31, + M31x16, + M31Ext3, + M31x16, + M31Ext3x16, + BytesHashTranscript<_, Keccak256hasher>, + >(c, 15, 20); } criterion_group!(bench, orion_committing_benchmark, orion_opening_benchmark); diff --git a/pcs/src/orion.rs b/pcs/src/orion.rs index 3c608a4a..cd7072b1 100644 --- a/pcs/src/orion.rs +++ b/pcs/src/orion.rs @@ -344,11 +344,38 @@ pub(crate) fn transpose_in_place(mat: &mut [F], scratch: &mut [F], row } #[inline] -pub(crate) fn inner_product(l: &[F0], r: &[F1], mul: impl Fn(&F0, &F1) -> F1) -> F1 +pub(crate) fn simd_inner_prod( + l: &[F0], + r: &[F1], + scratch_pl: &mut [IPPackF0], + scratch_pr: &mut [IPPackF1], +) -> F1 where - F1: std::iter::Sum, + F0: Field, + F1: Field + From + Mul, + IPPackF0: SimdField, + IPPackF1: SimdField + Mul, { - l.iter().zip(r.iter()).map(|(li, ri)| mul(li, ri)).sum() + assert_eq!(l.len() % IPPackF0::PACK_SIZE, 0); + assert_eq!(r.len() % IPPackF1::PACK_SIZE, 0); + + scratch_pl + .iter_mut() + .zip(l.chunks(IPPackF0::PACK_SIZE)) + .for_each(|(pl, ls)| *pl = IPPackF0::pack(ls)); + + scratch_pr + .iter_mut() + .zip(r.chunks(IPPackF1::PACK_SIZE)) + .for_each(|(pr, rs)| *pr = IPPackF1::pack(rs)); + + let simd_sum: IPPackF1 = scratch_pl + .iter() + .zip(scratch_pr.iter()) + .map(|(pl, pr)| *pr * *pl) + .sum(); + + simd_sum.unpack().iter().sum() } /********************************************************** @@ -403,8 +430,6 @@ impl OrionPublicParams { let elems_for_smallest_tree = tree::leaf_adic::() * 2; - // NOTE(Hang): rounding up here in halving the poly variable num - // up to discussion if we want to half by round down let row_num: usize = elems_for_smallest_tree; let msg_size: usize = (1 << poly_variables) / row_num; @@ -470,9 +495,6 @@ impl OrionPublicParams { { let (row_num, msg_size) = Self::row_col_from_variables::(poly.get_num_vars()); - // NOTE: defense programming - rows of alphabet should fit into at least smallest tree - assert!(row_num / PackF::PACK_SIZE * PackF::SIZE >= tree::LEAF_BYTES * 2); - // NOTE: pre transpose evaluations let mut transposed_evaluations = poly.coeffs.clone(); let mut scratch = vec![F::ZERO; 1 << poly.get_num_vars()]; @@ -527,7 +549,7 @@ impl OrionPublicParams { }) } - pub fn open( + pub fn open( &self, poly: &MultiLinearPoly, commitment_with_data: &OrionCommitmentWithData, @@ -538,8 +560,12 @@ impl OrionPublicParams { F: Field + FieldSerde, PackF: SimdField, EvalF: Field + FieldSerde + From + Mul, + IPPackF: SimdField, + IPPackEvalF: SimdField + Mul, T: Transcript, { + assert_eq!(IPPackEvalF::PACK_SIZE, IPPackF::PACK_SIZE); + let (row_num, msg_size) = Self::row_col_from_variables::(poly.get_num_vars()); let num_of_vars_in_codeword = log2(msg_size) as usize; @@ -549,6 +575,10 @@ impl OrionPublicParams { transpose_in_place(&mut transposed_evaluations, &mut scratch, row_num); drop(scratch); + // NOTE: prepare scratch space for both evals and proximity test + let mut scratch_pf = vec![IPPackF::ZERO; row_num / IPPackF::PACK_SIZE]; + let mut scratch_pef = vec![IPPackEvalF::ZERO; row_num / IPPackEvalF::PACK_SIZE]; + // NOTE: working on evaluation response of tensor code IOP based PCS let eq_linear_comb = EqPolynomial::build_eq_x_r(&point[num_of_vars_in_codeword..]); let mut eval_row = vec![EvalF::ZERO; msg_size]; @@ -556,11 +586,21 @@ impl OrionPublicParams { .chunks(row_num) .zip(eval_row.iter_mut()) .for_each(|(col_i, res_i)| { - *res_i = inner_product(col_i, &eq_linear_comb, |i, ei| *ei * *i); + *res_i = simd_inner_prod(col_i, &eq_linear_comb, &mut scratch_pf, &mut scratch_pef); }); + // NOTE: working on evaluation on top of evaluation response let eq_linear_comb = EqPolynomial::build_eq_x_r(&point[..num_of_vars_in_codeword]); - let eval = inner_product(&eval_row, &eq_linear_comb, |i, ei| *ei * *i); + let mut scratch_msg_sized_0 = vec![IPPackEvalF::ZERO; msg_size / IPPackEvalF::PACK_SIZE]; + let mut scratch_msg_sized_1 = vec![IPPackEvalF::ZERO; msg_size / IPPackEvalF::PACK_SIZE]; + let eval = simd_inner_prod( + &eval_row, + &eq_linear_comb, + &mut scratch_msg_sized_0, + &mut scratch_msg_sized_1, + ); + drop(scratch_msg_sized_0); + drop(scratch_msg_sized_1); // NOTE: draw random linear combination out // and compose proximity response(s) of tensor code IOP based PCS @@ -569,16 +609,21 @@ impl OrionPublicParams { let mut proximity_rows = vec![vec![EvalF::ZERO; msg_size]; proximity_repetitions]; (0..proximity_repetitions).for_each(|rep_i| { - let random_linear_combination = transcript.generate_challenge_field_elements(row_num); + let random_coeffs = transcript.generate_challenge_field_elements(row_num); transposed_evaluations .chunks(row_num) .zip(proximity_rows[rep_i].iter_mut()) .for_each(|(col_i, res_i)| { - *res_i = inner_product(col_i, &random_linear_combination, |i, ei| *ei * *i); + *res_i = + simd_inner_prod(col_i, &random_coeffs, &mut scratch_pf, &mut scratch_pef); }); }); + // NOTE: scratch space for evals and proximity test life cycle finish + drop(scratch_pf); + drop(scratch_pef); + // NOTE: MT opening for point queries let leaf_range = row_num * F::FIELD_SIZE / (tree::LEAF_BYTES * 8); let query_num = self.query_complexity(ORION_PCS_SOUNDNESS_BITS); @@ -606,7 +651,7 @@ impl OrionPublicParams { ) } - pub fn verify( + pub fn verify( &self, commitment: &OrionCommitment, point: &[EvalF], @@ -618,14 +663,25 @@ impl OrionPublicParams { F: Field + FieldSerde, PackF: SimdField, EvalF: Field + FieldSerde + From + Mul, + IPPackF: SimdField, + IPPackEvalF: SimdField + Mul, T: Transcript, { let (row_num, msg_size) = Self::row_col_from_variables::(point.len()); let num_of_vars_in_codeword = log2(msg_size) as usize; // NOTE: working on evaluation response, evaluate the rest of the response - let poly_half_evaled = MultiLinearPoly::new(proof.eval_row.clone()); - let final_eval = poly_half_evaled.evaluate_jolt(&point[..num_of_vars_in_codeword]); + let eq_x_r = EqPolynomial::build_eq_x_r(&point[..num_of_vars_in_codeword]); + let mut scratch_msg_sized_0 = vec![IPPackEvalF::ZERO; msg_size / IPPackEvalF::PACK_SIZE]; + let mut scratch_msg_sized_1 = vec![IPPackEvalF::ZERO; msg_size / IPPackEvalF::PACK_SIZE]; + let final_eval = simd_inner_prod( + &proof.eval_row, + &eq_x_r, + &mut scratch_msg_sized_0, + &mut scratch_msg_sized_1, + ); + drop(scratch_msg_sized_0); + drop(scratch_msg_sized_1); if final_eval != evaluation { return false; } @@ -659,8 +715,10 @@ impl OrionPublicParams { // NOTE: encode the proximity/evaluation responses, // check againts all challenged indices by check alphabets against // linear combined interleaved alphabet + let mut scratch_pf = vec![IPPackF::ZERO; row_num / IPPackF::PACK_SIZE]; + let mut scratch_pef = vec![IPPackEvalF::ZERO; row_num / IPPackEvalF::PACK_SIZE]; + let eq_linear_combination = EqPolynomial::build_eq_x_r(&point[num_of_vars_in_codeword..]); - let mul_ext_f = |i: &F, ei: &EvalF| *ei * *i; random_linear_combinations .iter() .zip(proof.proximity_rows.iter()) @@ -672,7 +730,12 @@ impl OrionPublicParams { .zip(proof.query_openings.iter()) .all(|(&qi, range_path)| { let interleaved_alphabet = range_path.unpack_field_elems::(); - let alphabet = inner_product(&interleaved_alphabet, rl, mul_ext_f); + let alphabet = simd_inner_prod( + &interleaved_alphabet, + rl, + &mut scratch_pf, + &mut scratch_pef, + ); alphabet == codeword[qi] }) } @@ -685,16 +748,20 @@ impl OrionPublicParams { * POLYNOMIAL COMMITMENT TRAIT ALIGNMENT FOR ORION * ***************************************************/ -pub struct OrionPCS +pub struct OrionPCS where F: Field + FieldSerde, PackF: SimdField, EvalF: Field + FieldSerde + From + Mul, + IPPackF: SimdField, + IPPackEvalF: SimdField + Mul, T: Transcript, { _marker_f: PhantomData, _marker_pack_f: PhantomData, _marker_eval_f: PhantomData, + _marker_pack_f0: PhantomData, + _marker_pack_eval_f: PhantomData, _marker_t: PhantomData, } @@ -704,11 +771,14 @@ pub struct OrionPCSSetup { pub code_parameter: OrionCodeParameter, } -impl PolynomialCommitmentScheme for OrionPCS +impl PolynomialCommitmentScheme + for OrionPCS where F: Field + FieldSerde, PackF: SimdField, EvalF: Field + FieldSerde + From + Mul, + IPPackF: SimdField, + IPPackEvalF: SimdField + Mul, T: Transcript, { type PublicParams = OrionPCSSetup; @@ -743,7 +813,12 @@ where commitment_with_data: &Self::CommitmentWithData, transcript: &mut Self::FiatShamirTranscript, ) -> (Self::Eval, Self::OpeningProof) { - proving_key.open(poly, commitment_with_data, opening_point, transcript) + proving_key.open::( + poly, + commitment_with_data, + opening_point, + transcript, + ) } fn verify( @@ -754,7 +829,7 @@ where opening_proof: &Self::OpeningProof, transcript: &mut Self::FiatShamirTranscript, ) -> bool { - verifying_key.verify::( + verifying_key.verify::( commitment, opening_point, evaluation, diff --git a/pcs/src/orion_test.rs b/pcs/src/orion_test.rs index 72e60aa7..1cf4cb57 100644 --- a/pcs/src/orion_test.rs +++ b/pcs/src/orion_test.rs @@ -3,21 +3,29 @@ use std::{marker::PhantomData, ops::Mul}; use arith::{ExtensionField, Field, FieldSerde, SimdField}; use ark_std::{log2, test_rng}; use gf2::{GF2x128, GF2x64, GF2x8, GF2}; -use gf2_128::GF2_128; -use mersenne31::{M31Ext3, M31x16, M31}; +use gf2_128::{GF2_128x8, GF2_128}; +use mersenne31::{M31Ext3, M31Ext3x16, M31x16, M31}; use polynomials::{EqPolynomial, MultiLinearPoly}; use transcript::{BytesHashTranscript, Keccak256hasher, Transcript}; use crate::{ - inner_product, transpose_in_place, OrionCode, OrionCommitment, ORION_CODE_PARAMETER_INSTANCE, + transpose_in_place, OrionCode, OrionCommitment, ORION_CODE_PARAMETER_INSTANCE, ORION_PCS_SOUNDNESS_BITS, }; use super::{OrionCommitmentWithData, OrionPublicParams}; +#[inline] +pub(crate) fn vanilla_inner_prod(l: &[F0], r: &[F1], mul: impl Fn(&F0, &F1) -> F1) -> F1 +where + F1: std::iter::Sum, +{ + l.iter().zip(r.iter()).map(|(li, ri)| mul(li, ri)).sum() +} + fn column_combination(mat: &[F], combination: &[F]) -> Vec { mat.chunks(combination.len()) - .map(|row_i| inner_product(row_i, combination, |a, b| *a * *b)) + .map(|row_i| vanilla_inner_prod(row_i, combination, |a, b| *a * *b)) .collect() } @@ -171,7 +179,7 @@ where .for_each(|p| random_poly_ext_half_evaluated.fix_top_variable(p)); let eq_linear_combination = EqPolynomial::build_eq_x_r(&random_point[..vars_for_col]); - let actual_eval: ExtF = inner_product( + let actual_eval: ExtF = vanilla_inner_prod( &random_poly_ext_half_evaluated.coeffs, &eq_linear_combination, |a, b| *a * *b, @@ -188,11 +196,13 @@ fn test_multilinear_poly_tensor_eval() { }) } -fn test_orion_pcs_open_generics(num_vars: usize) +fn test_orion_pcs_open_generics(num_vars: usize) where F: Field + FieldSerde, EvalF: Field + FieldSerde + From + Mul, PackF: SimdField, + IPPackF: SimdField, + IPPackEvalF: SimdField + Mul, { let mut rng = test_rng(); @@ -212,12 +222,12 @@ where let mut transcript: BytesHashTranscript = BytesHashTranscript::new(); let mut transcript_cloned = transcript.clone(); - let orion_pcs = + let orion_pp = OrionPublicParams::from_random::(num_vars, ORION_CODE_PARAMETER_INSTANCE, &mut rng); - let commit_with_data = orion_pcs.commit::(&random_poly).unwrap(); + let commit_with_data = orion_pp.commit::(&random_poly).unwrap(); - let (_, opening) = orion_pcs.open( + let (_, opening) = orion_pp.open::( &random_poly, &commit_with_data, &random_point, @@ -233,7 +243,7 @@ where assert_eq!(expected_eval, actual_eval); // NOTE: compute evaluation codeword - let eval_codeword = orion_pcs.code_instance.encode(&opening.eval_row).unwrap(); + let eval_codeword = orion_pp.code_instance.encode(&opening.eval_row).unwrap(); let eq_linear_combination = EqPolynomial::build_eq_x_r(&random_point[vars_for_col..]); let mut interleaved_codeword_ext = commit_with_data .interleaved_alphabet_tree @@ -241,7 +251,7 @@ where .iter() .map(|&f| EvalF::from(f)) .collect::>(); - interleaved_codeword_ext.resize(row_num * orion_pcs.code_len(), EvalF::ZERO); + interleaved_codeword_ext.resize(row_num * orion_pp.code_len(), EvalF::ZERO); let eq_combined_codeword = column_combination(&interleaved_codeword_ext, &eq_linear_combination); @@ -249,7 +259,7 @@ where // NOTE: compute proximity codewords let proximity_repetitions = - orion_pcs.proximity_repetition_num(ORION_PCS_SOUNDNESS_BITS, EvalF::FIELD_SIZE); + orion_pp.proximity_repetition_num(ORION_PCS_SOUNDNESS_BITS, EvalF::FIELD_SIZE); assert_eq!(proximity_repetitions, opening.proximity_rows.len()); opening.proximity_rows.iter().for_each(|proximity_row| { @@ -259,7 +269,7 @@ where let expected_proximity_codeword = column_combination(&interleaved_codeword_ext, &random_linear_combination); - let actual_proximity_codeword = orion_pcs.code_instance.encode(proximity_row).unwrap(); + let actual_proximity_codeword = orion_pp.code_instance.encode(proximity_row).unwrap(); assert_eq!(expected_proximity_codeword, actual_proximity_codeword) }); @@ -267,15 +277,21 @@ where #[test] fn test_orion_pcs_open() { - (19..=25).for_each(|num_vars| test_orion_pcs_open_generics::(num_vars)); - (9..=15).for_each(|num_vars| test_orion_pcs_open_generics::(num_vars)) + (19..=25).for_each(|num_vars| { + test_orion_pcs_open_generics::(num_vars) + }); + (9..=15).for_each(|num_vars| { + test_orion_pcs_open_generics::(num_vars) + }) } -fn test_orion_pcs_full_e2e_generics(num_vars: usize) +fn test_orion_pcs_full_e2e_generics(num_vars: usize) where F: Field + FieldSerde, EvalF: Field + FieldSerde + Mul + From, PackF: SimdField, + IPPackF: SimdField, + IPPackEvalF: SimdField + Mul, { let mut rng = test_rng(); @@ -301,14 +317,14 @@ where let commit_with_data = orion_pp.commit::(&random_poly).unwrap(); - let (_, opening) = orion_pp.open( + let (_, opening) = orion_pp.open::( &random_poly, &commit_with_data, &random_point, &mut transcript, ); - assert!(orion_pp.verify::( + assert!(orion_pp.verify::( &commit_with_data.into(), &random_point, expected_eval, @@ -320,12 +336,12 @@ where #[test] fn test_orion_pcs_full_e2e() { (19..=25).for_each(|num_vars| { - test_orion_pcs_full_e2e_generics::(num_vars); - test_orion_pcs_full_e2e_generics::(num_vars); - test_orion_pcs_full_e2e_generics::(num_vars); + test_orion_pcs_full_e2e_generics::(num_vars); + test_orion_pcs_full_e2e_generics::(num_vars); + test_orion_pcs_full_e2e_generics::(num_vars); }); (9..=15).for_each(|num_vars| { - test_orion_pcs_full_e2e_generics::(num_vars); - test_orion_pcs_full_e2e_generics::(num_vars); + test_orion_pcs_full_e2e_generics::(num_vars); + test_orion_pcs_full_e2e_generics::(num_vars); }) }