From d1af0d840e79f8d4f1e38082823ad12ac323971d Mon Sep 17 00:00:00 2001 From: cgouert Date: Wed, 1 May 2024 17:30:14 -0400 Subject: [PATCH] Optimize encrypted correlation baseline and start adding timings for primitive operations Co-authored-by: Dimitris Mouris --- Cargo.toml | 10 +- ...correlation_haar.rs => correlation_lut.rs} | 85 +++++++--- src/correlation_ptxt.rs | 4 +- src/primitive_ops.rs | 145 ++++++++++++++++++ 4 files changed, 221 insertions(+), 23 deletions(-) rename src/{correlation_haar.rs => correlation_lut.rs} (56%) create mode 100644 src/primitive_ops.rs diff --git a/Cargo.toml b/Cargo.toml index 275550b..dcff35a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -66,11 +66,17 @@ name = "correlation_ptxt" path = "src/correlation_ptxt.rs" [[bin]] -name = "correlation_haar" -path = "src/correlation_haar.rs" +name = "correlation_lut" +path = "src/correlation_lut.rs" # Euclidean Distance [[bin]] name = "euclidean" path = "src/euclidean.rs" + +# Primitive Timings + +[[bin]] +name = "primitive_ops" +path = "src/primitive_ops.rs" diff --git a/src/correlation_haar.rs b/src/correlation_lut.rs similarity index 56% rename from src/correlation_haar.rs rename to src/correlation_lut.rs index 354bc6e..99b0951 100644 --- a/src/correlation_haar.rs +++ b/src/correlation_lut.rs @@ -17,11 +17,11 @@ fn main() { let salaries = &data[1]; let dataset_size = salaries.len() as f64; - let mut salary_sorted = salaries.clone(); - salary_sorted.sort(); + let mut salaries_sorted = salaries.clone(); + salaries_sorted.sort(); // ------- Client side ------- // - let bit_width = 16; + let bit_width = 20; // Number of blocks per ciphertext let nb_blocks = bit_width / 2; @@ -30,6 +30,9 @@ fn main() { nb_blocks ); + // Scale factor + let scale = 100_000_u64; + let start = Instant::now(); // Generate radix keys let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks); @@ -45,7 +48,7 @@ fn main() { ); let start = Instant::now(); - let encrypted_salaries: Vec<_> = salary_sorted + let encrypted_salaries: Vec<_> = salaries_sorted .par_iter() // Use par_iter() for parallel iteration .map(|&salary| client_key.encrypt(salary)) .collect(); @@ -55,12 +58,8 @@ fn main() { ); // ------- Server side ------- // - // TODO: Move LUT gens up here - - // Compute the encrypted correlation - let start = Instant::now(); - println!("Starting computing mean"); + // The experience vector is known to the server. let experience_mean = experience.iter().map(|&exp| exp as f64).sum::() / dataset_size; let experience_variance: f64 = experience .iter() @@ -68,21 +67,51 @@ fn main() { .sum(); let experience_stddev = experience_variance.sqrt(); + // Offline: LUT genaration is offline cost. + let mut dummy_ct: RadixCiphertext = server_key.create_trivial_radix(0_u64, nb_blocks); + let dummy_ct_2 = server_key.create_trivial_radix(0_u64, nb_blocks); + for _ in 0..encrypted_salaries.len() { + dummy_ct = server_key.add_parallelized(&dummy_ct, &dummy_ct_2); + } + let div_lut = wopbs_key.generate_lut_radix(&dummy_ct, |x: u64| x / (dataset_size as u64)); + dummy_ct = wopbs_key.keyswitch_to_wopbs_params(&server_key, &dummy_ct); + dummy_ct = wopbs_key.wopbs(&dummy_ct, &div_lut); + dummy_ct = wopbs_key.keyswitch_to_pbs_params(&dummy_ct); + let mut dummy_acc: RadixCiphertext = server_key.create_trivial_radix(0_u64, nb_blocks); + for _ in 0..encrypted_salaries.len() { + let mut dummy_ct_3 = server_key.sub_parallelized(&dummy_ct_2, &dummy_ct); + dummy_ct_3 = server_key.mul_parallelized(&dummy_ct_3, &dummy_ct_3); + dummy_acc = server_key.add_parallelized(&dummy_acc, &dummy_ct_3); + } + let sqrt_lut = wopbs_key.generate_lut_radix(&dummy_acc, |x: u64| { + if x == 0 { + 1 // avoid division with zero error. + } else { + scale / (x.sqrt() * experience_stddev as u64) + } + }); + + // Online: Compute the encrypted correlation + + // The salaries vector is encrypted. + println!("- Starting computing mean"); + let start = Instant::now(); + let total = start; let mut salaries_sum_enc = encrypted_salaries.iter().fold( server_key.create_trivial_radix(0_u64, nb_blocks), |acc: RadixCiphertext, salary| server_key.add_parallelized(&acc, salary), ); - let div_lut = - wopbs_key.generate_lut_radix(&salaries_sum_enc, |x: u64| x / (dataset_size as u64)); salaries_sum_enc = wopbs_key.keyswitch_to_wopbs_params(&server_key, &salaries_sum_enc); let mut salaries_mean_enc = wopbs_key.wopbs(&salaries_sum_enc, &div_lut); salaries_mean_enc = wopbs_key.keyswitch_to_pbs_params(&salaries_mean_enc); - println!( - "Finished computing mean in {:?} sec.", + "- Finished computing mean in {:?} sec.", start.elapsed().as_secs_f64() ); + // Cov = Sum_i^n (salary_i - mean(salary))(experience_i - mean(experience)) + println!("- Starting computing covariance"); + let start = Instant::now(); let covariance = encrypted_salaries .iter() .zip(experience.iter()) @@ -94,7 +123,14 @@ fn main() { server_key.create_trivial_radix(0_u64, nb_blocks), |acc: RadixCiphertext, diff| server_key.add_parallelized(&acc, &diff), ); + println!( + "- Finished computing covariance in {:?} sec.", + start.elapsed().as_secs_f64() + ); + // Var_salary = Sum_i^n (salary_i - mean(salary))^2 + println!("- Starting computing variance"); + let start = Instant::now(); let mut salaries_variance_enc = encrypted_salaries .iter() .map(|salary_enc| { @@ -105,20 +141,31 @@ fn main() { server_key.create_trivial_radix(0_u64, nb_blocks), |acc: RadixCiphertext, diff| server_key.add_parallelized(&acc, &diff), ); + println!( + "- Finished computing variance in {:?} sec.", + start.elapsed().as_secs_f64() + ); - let sqrt_lut = wopbs_key.generate_lut_radix(&salaries_variance_enc, |x: u64| x.sqrt()); - + // sigma_salary (or stddev) = sqrt(var_salary) salaries_variance_enc = wopbs_key.keyswitch_to_wopbs_params(&server_key, &salaries_variance_enc); + println!("- Starting computing LUT"); + let start = Instant::now(); let mut salaries_stddev_enc = wopbs_key.wopbs(&salaries_variance_enc, &sqrt_lut); salaries_stddev_enc = wopbs_key.keyswitch_to_pbs_params(&salaries_stddev_enc); + println!( + "- Finished computing LUT in {:?} sec.", + start.elapsed().as_secs_f64() + ); + let correlation_enc = server_key.mul_parallelized(&salaries_stddev_enc, &covariance); - let divisor_enc = - server_key.scalar_mul_parallelized(&salaries_stddev_enc, experience_stddev as u32); - let correlation_enc = server_key.div_parallelized(&covariance, &divisor_enc); + println!( + "Finished computing correlation in {:?} sec.", + total.elapsed().as_secs_f64() + ); // ------- Client side ------- // let correlation: u64 = client_key.decrypt(&correlation_enc); - println!("correlation: {}", correlation); + println!("Correlation: {}", (correlation as f64) / (scale as f64)); } diff --git a/src/correlation_ptxt.rs b/src/correlation_ptxt.rs index 97420a5..66650fc 100644 --- a/src/correlation_ptxt.rs +++ b/src/correlation_ptxt.rs @@ -6,7 +6,7 @@ fn pearson_correlation(x: &[u32], y: &[u32]) -> f64 { let x_mean = x.iter().map(|&xi| xi as f64).sum::() / n; let y_mean = y.iter().map(|&yi| yi as f64).sum::() / n; - let sum_xy: f64 = x + let covariance: f64 = x .iter() .zip(y.iter()) .map(|(&xi, &yi)| ((xi as f64) - x_mean) * ((yi as f64) - y_mean)) @@ -14,7 +14,7 @@ fn pearson_correlation(x: &[u32], y: &[u32]) -> f64 { let variance_x: f64 = x.iter().map(|&xi| ((xi as f64) - x_mean).powi(2)).sum(); let variance_y: f64 = y.iter().map(|&yi| ((yi as f64) - y_mean).powi(2)).sum(); - sum_xy / (variance_x.sqrt() * variance_y.sqrt()) + covariance / (variance_x.sqrt() * variance_y.sqrt()) } fn main() { diff --git a/src/primitive_ops.rs b/src/primitive_ops.rs new file mode 100644 index 0000000..3b79d47 --- /dev/null +++ b/src/primitive_ops.rs @@ -0,0 +1,145 @@ +use std::time::Instant; + +use dwt::{transform, wavelet::Haar, Operation}; +use ripple::common::*; +use tfhe::{ + integer::{ + gen_keys_radix, wopbs::*, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext, + }, + shortint::parameters::{ + parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, + PARAM_MESSAGE_2_CARRY_2_KS_PBS, + }, +}; + +pub fn haar_square( + table_size: u8, + input_precision: u8, + output_precision: u8, + bit_width: u8, +) -> (Vec, Vec) { + let max = 1 << bit_width; + let mut data = Vec::new(); + for x in 0..max { + let x = unquantize(x, input_precision, bit_width); + let square = x * x; + data.push(square); + } + data.rotate_right(1 << (bit_width - 1)); + transform( + &mut data, + Operation::Forward, + &Haar::new(), + (bit_width - table_size) as usize, + ); + let coef_len = 1 << table_size; + let scalar = 2f64.powf(-((bit_width - table_size) as f64) / 2f64); + let mut haar: Vec = data + .get(0..coef_len) + .unwrap() + .iter() + .map(|x| quantize(scalar * x, output_precision, bit_width)) + .collect(); + haar.rotate_right(1 << (table_size - 1)); + let mask = (1 << (bit_width / 2)) - 1; + let lsb = haar.iter().map(|x| x & mask).collect(); + let msb = haar.iter().map(|x| x >> (bit_width / 2) & mask).collect(); + (lsb, msb) +} + +fn eval_lut(x: u64, lut_map: &Vec) -> u64 { + lut_map[x as usize] +} + +fn main() { + // ------- Client side ------- // + let bit_width = 20; + + // Number of blocks per ciphertext + let nb_blocks = bit_width / 2; + println!( + "Number of blocks for the radix decomposition: {:?}", + nb_blocks + ); + + let start = Instant::now(); + // Generate radix keys + let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks); + // Generate key for PBS (without padding) + let wopbs_key = WopbsKey::new_wopbs_key( + &client_key, + &server_key, + &WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, + ); + println!( + "Key generation done in {:?} sec.", + start.elapsed().as_secs_f64() + ); + + let x = 5_u64; + // let y = 10_u64; + let x_ct = client_key.encrypt(x); + // let y_ct = client_key.encrypt(y); + + // ------- Server side ------- // + + // 1. Square + println!("\n1. Square"); + + // 1.1. Square using multiplication + let start = Instant::now(); + let square_ct = server_key.mul_parallelized(&x_ct, &x_ct); + println!("Ct-Ct Mult in {:?} sec.", start.elapsed().as_secs_f64()); + let prod: u64 = client_key.decrypt(&square_ct); + + // 1.2. Square using LUT + let square_lut = wopbs_key.generate_lut_radix(&x_ct, |x: u64| x * x); + let start = Instant::now(); + let x_ct_ks = wopbs_key.keyswitch_to_wopbs_params(&server_key, &x_ct); + let mut square_ct = wopbs_key.wopbs(&x_ct_ks, &square_lut); + square_ct = wopbs_key.keyswitch_to_pbs_params(&square_ct); + println!("LUT Square in {:?} sec.", start.elapsed().as_secs_f64()); + let lut_prod: u64 = client_key.decrypt(&square_ct); + + // 1.3. Square using Haar DWT LUT + let (haar_lsb, haar_msb) = haar_square((bit_width >> 1) as u8, 8_u8, 16_u8, bit_width as u8); + dbg!(&haar_lsb); + dbg!(&haar_msb); + let dummy: RadixCiphertext = server_key.create_trivial_radix(0_u64, nb_blocks); + let dummy_blocks = &dummy.into_blocks()[(nb_blocks >> 1)..nb_blocks]; + let dummy_msb = RadixCiphertext::from_blocks(dummy_blocks.to_vec()); + let dummy_msb = server_key.scalar_add_parallelized(&dummy_msb, 1); + let haar_lsb_lut = wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| eval_lut(x, &haar_lsb)); + let haar_msb_lut = wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| eval_lut(x, &haar_msb)); + + let start = Instant::now(); + // Truncate x + let x_truncated_blocks = &x_ct.into_blocks()[(nb_blocks >> 1)..nb_blocks]; + let x_truncated = RadixCiphertext::from_blocks(x_truncated_blocks.to_vec()); + let x_truncated = server_key.scalar_add_parallelized(&x_truncated, 1); + let x_truncated_ks = wopbs_key.keyswitch_to_wopbs_params(&server_key, &x_truncated); + let (square_lsb, square_msb) = rayon::join( + || { + let square_lsb = wopbs_key.wopbs(&x_truncated_ks, &haar_lsb_lut); + wopbs_key.keyswitch_to_pbs_params(&square_lsb) + }, + || { + let square_msb = wopbs_key.wopbs(&x_truncated_ks, &haar_msb_lut); + wopbs_key.keyswitch_to_pbs_params(&square_msb) + }, + ); + let mut square_lsb_blocks = square_lsb.into_blocks(); + square_lsb_blocks.extend(square_msb.into_blocks()); + let square_ct_haar = RadixCiphertext::from_blocks(square_lsb_blocks.to_vec()); + + println!( + "Haar LUT Square in {:?} sec.", + start.elapsed().as_secs_f64() + ); + let dwt_lut_prod: u64 = client_key.decrypt(&square_ct_haar); + + println!( + "--- Exact: {:?}, LUT: {:?}, DWT LUT: {:?}", + prod, lut_prod, dwt_lut_prod + ); +}