From 82e1b9793faab1b884cb8ec6be0ee272a135e275 Mon Sep 17 00:00:00 2001 From: cgouert Date: Sun, 26 May 2024 00:18:08 -0400 Subject: [PATCH] Add all correlation coefficient variants --- Cargo.toml | 12 +++ src/correlation_bior.rs | 184 +++++++++++++++++++++++++++++++++++ src/correlation_haar.rs | 177 +++++++++++++++++++++++++++++++++ src/correlation_lut.rs | 79 +++++---------- src/correlation_quantized.rs | 170 ++++++++++++++++++++++++++++++++ 5 files changed, 570 insertions(+), 52 deletions(-) create mode 100644 src/correlation_bior.rs create mode 100644 src/correlation_haar.rs create mode 100644 src/correlation_quantized.rs diff --git a/Cargo.toml b/Cargo.toml index ff789a7..3cdd7ca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -77,6 +77,18 @@ path = "src/correlation_ptxt.rs" name = "correlation_lut" path = "src/correlation_lut.rs" +[[bin]] +name = "correlation_quantized" +path = "src/correlation_quantized.rs" + +[[bin]] +name = "correlation_haar" +path = "src/correlation_haar.rs" + +[[bin]] +name = "correlation_bior" +path = "src/correlation_bior.rs" + # Euclidean Distance [[bin]] diff --git a/src/correlation_bior.rs b/src/correlation_bior.rs new file mode 100644 index 0000000..6f31e5a --- /dev/null +++ b/src/correlation_bior.rs @@ -0,0 +1,184 @@ +use std::time::Instant; + +use rayon::prelude::*; +use ripple::common::*; +use tfhe::{ + integer::{ + gen_keys_radix, wopbs::*, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext, + }, + shortint::parameters::{ + parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, Degree, + PARAM_MESSAGE_2_CARRY_2_KS_PBS, + }, +}; + +fn main() { + let data = read_csv("data/correlation.csv"); + let experience = &data[0]; + let salaries = &data[1]; + let dataset_size = salaries.len() as f64; + + let mut salaries_sorted = salaries.clone(); + salaries_sorted.sort(); + + // ------- Client side ------- // + let bit_width = 16; + let precision = 12; + let wave_depth = 8; + + // TODO: Replace with actual custom functions + let (sqrt_lut_lsb, sqrt_lut_msb) = + bior("data/bior_lut_sqrt_16.json", wave_depth as u8, bit_width); + let (sqrt_lut_lsb_2, sqrt_lut_msb_2) = + bior("data/bior_lut_sqrt_16_2.json", wave_depth as u8, bit_width); + let (div_lut_lsb, div_lut_msb) = bior("data/bior_lut_div_16.json", wave_depth as u8, bit_width); + let (div_lut_lsb_2, div_lut_msb_2) = + bior("data/bior_lut_div_16_2.json", wave_depth as u8, bit_width); + + let sqrt_luts = vec![ + &sqrt_lut_lsb, + &sqrt_lut_lsb_2, + &sqrt_lut_msb, + &sqrt_lut_msb_2, + ]; + let div_luts = vec![&div_lut_lsb, &div_lut_lsb_2, &div_lut_msb, &div_lut_msb_2]; + + // Number of blocks per ciphertext + let nb_blocks = bit_width / 2; + println!( + "Number of blocks for the radix decomposition: {:?}", + nb_blocks + ); + + // Scale factor + let scale = 100_000_u64; + + let start = Instant::now(); + // Generate radix keys + let (client_key, server_key) = + gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks as usize); + // Generate key for PBS (without padding) + let wopbs_key = WopbsKey::new_wopbs_key( + &client_key, + &server_key, + &WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, + ); + println!( + "Key generation done in {:?} sec.", + start.elapsed().as_secs_f64() + ); + + let start = Instant::now(); + let encrypted_salaries: Vec<_> = salaries_sorted + .par_iter() // Use par_iter() for parallel iteration + .map(|&salary| client_key.encrypt(quantize(salary as f64, precision, bit_width as u8))) + .collect(); + println!( + "Encryption done in {:?} sec.", + start.elapsed().as_secs_f64() + ); + + // ------- Server side ------- // + + // The experience vector is known to the server. + let experience_mean = experience.iter().map(|&exp| exp as f64).sum::() / dataset_size; + // let experience_variance: f64 = experience + // .iter() + // .map(|&exp| ((exp as f64) - experience_mean).powi(2)) + // .sum(); + // let experience_stddev = experience_variance.sqrt(); + + // Offline: LUT genaration is offline cost. + let lut_gen_start = Instant::now(); + println!("Generating LUT."); + let mut dummy: RadixCiphertext = + server_key.create_trivial_radix(2_u64, (nb_blocks >> 1).into()); + dummy = wopbs_key.keyswitch_to_wopbs_params(&server_key, &dummy); + let mut dummy_blocks = dummy.clone().into_blocks().to_vec(); + for block in &mut dummy_blocks { + block.degree = Degree::new(3); + } + dummy = RadixCiphertext::from_blocks(dummy_blocks); + let encoded_sqrt_luts = sqrt_luts + .iter() + .map(|lut| wopbs_key.generate_lut_radix(&dummy, |x: u64| eval_lut(x, &lut.to_vec()))) + .collect::>(); + let encoded_div_luts = div_luts + .iter() + .map(|lut| wopbs_key.generate_lut_radix(&dummy, |x: u64| eval_lut(x, &lut.to_vec()))) + .collect::>(); + println!( + "LUT generation done in {:?} sec.", + lut_gen_start.elapsed().as_secs_f64() + ); + // Online: Compute the encrypted correlation + + // The salaries vector is encrypted. + let start = Instant::now(); + let total = start; + let salaries_sum_enc = encrypted_salaries.iter().fold( + server_key.create_trivial_radix(0_u64, nb_blocks as usize), + |acc: RadixCiphertext, salary| server_key.add_parallelized(&acc, salary), + ); + // println!("salaries_sum_enc degree: {:?}", salaries_sum_enc.blocks()[0].degree); + let salaries_mean_enc = ct_lut_eval_bior_no_gen( + salaries_sum_enc, + bit_width as usize, + nb_blocks as usize, + wave_depth, + &wopbs_key, + 0_i32, + &server_key, + &encoded_div_luts, + ); + + // Cov = Sum_i^n (salary_i - mean(salary))(experience_i - mean(experience)) + let covariance = encrypted_salaries + .iter() + .zip(experience.iter()) + .map(|(salary_enc, &exp)| { + let x = server_key.sub_parallelized(salary_enc, &salaries_mean_enc); + server_key.scalar_mul_parallelized(&x, exp - (experience_mean as u32)) + }) + .fold( + server_key.create_trivial_radix(0_u64, nb_blocks as usize), + |acc: RadixCiphertext, diff| server_key.add_parallelized(&acc, &diff), + ); + + // Var_salary = Sum_i^n (salary_i - mean(salary))^2 + let salaries_variance_enc = encrypted_salaries + .iter() + .map(|salary_enc| { + let x = server_key.sub_parallelized(salary_enc, &salaries_mean_enc); + server_key.mul_parallelized(&x, &x) + }) + .fold( + server_key.create_trivial_radix(0_u64, nb_blocks as usize), + |acc: RadixCiphertext, diff| server_key.add_parallelized(&acc, &diff), + ); + + // sigma_salary (or stddev) = sqrt(var_salary) + // println!("salaries_variance_enc degree: {:?}", salaries_variance_enc.blocks()[0].degree); + let salaries_stddev_enc = ct_lut_eval_bior_no_gen( + salaries_variance_enc, + bit_width as usize, + nb_blocks as usize, + wave_depth, + &wopbs_key, + 0_i32, + &server_key, + &encoded_sqrt_luts, + ); + let correlation_enc = server_key.mul_parallelized(&salaries_stddev_enc, &covariance); + + println!( + "Finished computing correlation in {:?} sec.", + total.elapsed().as_secs_f64() + ); + + // ------- Client side ------- // + let correlation: u64 = client_key.decrypt(&correlation_enc); + let correlation_final: f64 = unquantize(correlation, precision, bit_width as u8); + + println!("Correlation: {}", correlation_final / (scale as f64)); +} diff --git a/src/correlation_haar.rs b/src/correlation_haar.rs new file mode 100644 index 0000000..572198e --- /dev/null +++ b/src/correlation_haar.rs @@ -0,0 +1,177 @@ +use std::time::Instant; + +use rayon::prelude::*; +use ripple::common::*; +use tfhe::{ + integer::{ + gen_keys_radix, wopbs::*, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext, + }, + shortint::parameters::{ + parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, Degree, + PARAM_MESSAGE_2_CARRY_2_KS_PBS, + }, +}; + +fn main() { + let data = read_csv("data/correlation.csv"); + let experience = &data[0]; + let salaries = &data[1]; + let dataset_size = salaries.len() as f64; + + let mut salaries_sorted = salaries.clone(); + salaries_sorted.sort(); + + // ------- Client side ------- // + let bit_width = 16; + let precision = 6; + + // Number of blocks per ciphertext + let nb_blocks = bit_width / 2; + println!( + "Number of blocks for the radix decomposition: {:?}", + nb_blocks + ); + + // Scale factor + let scale = 100_000_u64; + + let start = Instant::now(); + // Generate radix keys + let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks); + // Generate key for PBS (without padding) + let wopbs_key = WopbsKey::new_wopbs_key( + &client_key, + &server_key, + &WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, + ); + println!( + "Key generation done in {:?} sec.", + start.elapsed().as_secs_f64() + ); + + let start = Instant::now(); + let encrypted_salaries: Vec<_> = salaries_sorted + .par_iter() // Use par_iter() for parallel iteration + .map(|&salary| client_key.encrypt(quantize(salary as f64, precision, bit_width as u8))) + .collect(); + println!( + "Encryption done in {:?} sec.", + start.elapsed().as_secs_f64() + ); + + // ------- Server side ------- // + + // The experience vector is known to the server. + let experience_mean = experience.iter().map(|&exp| exp as f64).sum::() / dataset_size; + let experience_variance: f64 = experience + .iter() + .map(|&exp| ((exp as f64) - experience_mean).powi(2)) + .sum(); + let experience_stddev = experience_variance.sqrt(); + + // Offline: LUT genaration is offline cost. + let lut_gen_start = Instant::now(); + println!("Generating LUT."); + let mut dummy: RadixCiphertext = + server_key.create_trivial_radix(2_u64, (nb_blocks >> 1).into()); + dummy = wopbs_key.keyswitch_to_wopbs_params(&server_key, &dummy); + let mut dummy_blocks = dummy.clone().into_blocks().to_vec(); + for block in &mut dummy_blocks { + block.degree = Degree::new(3); + } + dummy = RadixCiphertext::from_blocks(dummy_blocks); + let (haar_lsb, haar_msb) = haar( + precision, + precision, + bit_width as u8, + bit_width as u8, + &|x: f64| { + if x.abs() < 0.05 { + 1.0 // avoid division with zero error. + } else { + scale as f64 / (x.sqrt() * experience_stddev) + } + }, + ); + let haar_lsb_lut_sqrt = wopbs_key.generate_lut_radix(&dummy, |x: u64| eval_lut(x, &haar_lsb)); + let haar_msb_lut_sqrt = wopbs_key.generate_lut_radix(&dummy, |x: u64| eval_lut(x, &haar_msb)); + let (haar_lsb, haar_msb) = haar( + precision, + precision, + bit_width as u8, + bit_width as u8, + &|x: f64| x / dataset_size, + ); + let haar_lsb_lut_div = wopbs_key.generate_lut_radix(&dummy, |x: u64| eval_lut(x, &haar_lsb)); + let haar_msb_lut_div = wopbs_key.generate_lut_radix(&dummy, |x: u64| eval_lut(x, &haar_msb)); + println!( + "LUT generation done in {:?} sec.", + lut_gen_start.elapsed().as_secs_f64() + ); + // Online: Compute the encrypted correlation + + // The salaries vector is encrypted. + let start = Instant::now(); + let total = start; + let salaries_sum_enc = encrypted_salaries.iter().fold( + server_key.create_trivial_radix(0_u64, nb_blocks), + |acc: RadixCiphertext, salary| server_key.add_parallelized(&acc, salary), + ); + // println!("salaries_sum_enc degree: {:?}", salaries_sum_enc.blocks()[0].degree); + let salaries_mean_enc = ct_lut_eval_haar_no_gen( + salaries_sum_enc, + nb_blocks, + &wopbs_key, + &server_key, + &haar_lsb_lut_div, + &haar_msb_lut_div, + ); + + // Cov = Sum_i^n (salary_i - mean(salary))(experience_i - mean(experience)) + let covariance = encrypted_salaries + .iter() + .zip(experience.iter()) + .map(|(salary_enc, &exp)| { + let x = server_key.sub_parallelized(salary_enc, &salaries_mean_enc); + server_key.scalar_mul_parallelized(&x, exp - (experience_mean as u32)) + }) + .fold( + server_key.create_trivial_radix(0_u64, nb_blocks), + |acc: RadixCiphertext, diff| server_key.add_parallelized(&acc, &diff), + ); + + // Var_salary = Sum_i^n (salary_i - mean(salary))^2 + let salaries_variance_enc = encrypted_salaries + .iter() + .map(|salary_enc| { + let x = server_key.sub_parallelized(salary_enc, &salaries_mean_enc); + server_key.mul_parallelized(&x, &x) + }) + .fold( + server_key.create_trivial_radix(0_u64, nb_blocks), + |acc: RadixCiphertext, diff| server_key.add_parallelized(&acc, &diff), + ); + + // sigma_salary (or stddev) = sqrt(var_salary) + // println!("salaries_variance_enc degree: {:?}", salaries_variance_enc.blocks()[0].degree); + let salaries_stddev_enc = ct_lut_eval_haar_no_gen( + salaries_variance_enc, + nb_blocks, + &wopbs_key, + &server_key, + &haar_lsb_lut_sqrt, + &haar_msb_lut_sqrt, + ); + let correlation_enc = server_key.mul_parallelized(&salaries_stddev_enc, &covariance); + + println!( + "Finished computing correlation in {:?} sec.", + total.elapsed().as_secs_f64() + ); + + // ------- Client side ------- // + let correlation: u64 = client_key.decrypt(&correlation_enc); + let correlation_final: f64 = unquantize(correlation, precision, bit_width as u8); + + println!("Correlation: {}", correlation_final / (scale as f64)); +} diff --git a/src/correlation_lut.rs b/src/correlation_lut.rs index 99b0951..089e3dc 100644 --- a/src/correlation_lut.rs +++ b/src/correlation_lut.rs @@ -2,11 +2,13 @@ use std::time::Instant; use num_integer::Roots; use rayon::prelude::*; -use ripple::common; +use ripple::common::{self, ct_lut_eval_no_gen}; use tfhe::{ - integer::{gen_keys_radix, wopbs::*, RadixCiphertext}, + integer::{ + gen_keys_radix, wopbs::*, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext, + }, shortint::parameters::{ - parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, + parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, Degree, PARAM_MESSAGE_2_CARRY_2_KS_PBS, }, }; @@ -21,7 +23,7 @@ fn main() { salaries_sorted.sort(); // ------- Client side ------- // - let bit_width = 20; + let bit_width = 16; // Number of blocks per ciphertext let nb_blocks = bit_width / 2; @@ -68,50 +70,40 @@ fn main() { let experience_stddev = experience_variance.sqrt(); // Offline: LUT genaration is offline cost. - let mut dummy_ct: RadixCiphertext = server_key.create_trivial_radix(0_u64, nb_blocks); - let dummy_ct_2 = server_key.create_trivial_radix(0_u64, nb_blocks); - for _ in 0..encrypted_salaries.len() { - dummy_ct = server_key.add_parallelized(&dummy_ct, &dummy_ct_2); - } - let div_lut = wopbs_key.generate_lut_radix(&dummy_ct, |x: u64| x / (dataset_size as u64)); - dummy_ct = wopbs_key.keyswitch_to_wopbs_params(&server_key, &dummy_ct); - dummy_ct = wopbs_key.wopbs(&dummy_ct, &div_lut); - dummy_ct = wopbs_key.keyswitch_to_pbs_params(&dummy_ct); - let mut dummy_acc: RadixCiphertext = server_key.create_trivial_radix(0_u64, nb_blocks); - for _ in 0..encrypted_salaries.len() { - let mut dummy_ct_3 = server_key.sub_parallelized(&dummy_ct_2, &dummy_ct); - dummy_ct_3 = server_key.mul_parallelized(&dummy_ct_3, &dummy_ct_3); - dummy_acc = server_key.add_parallelized(&dummy_acc, &dummy_ct_3); + let lut_gen_start = Instant::now(); + println!("Generating LUT."); + let mut dummy: RadixCiphertext = server_key.create_trivial_radix(2_u64, (nb_blocks).into()); + dummy = wopbs_key.keyswitch_to_wopbs_params(&server_key, &dummy); + let mut dummy_blocks = dummy.clone().into_blocks().to_vec(); + for block in &mut dummy_blocks { + block.degree = Degree::new(3); } - let sqrt_lut = wopbs_key.generate_lut_radix(&dummy_acc, |x: u64| { + dummy = RadixCiphertext::from_blocks(dummy_blocks); + let sqrt_lut = wopbs_key.generate_lut_radix(&dummy, |x: u64| { if x == 0 { 1 // avoid division with zero error. } else { scale / (x.sqrt() * experience_stddev as u64) } }); - + let div_lut = wopbs_key.generate_lut_radix(&dummy, |x: u64| x / (dataset_size as u64)); + println!( + "LUT generation done in {:?} sec.", + lut_gen_start.elapsed().as_secs_f64() + ); // Online: Compute the encrypted correlation // The salaries vector is encrypted. - println!("- Starting computing mean"); let start = Instant::now(); let total = start; - let mut salaries_sum_enc = encrypted_salaries.iter().fold( + let salaries_sum_enc = encrypted_salaries.iter().fold( server_key.create_trivial_radix(0_u64, nb_blocks), |acc: RadixCiphertext, salary| server_key.add_parallelized(&acc, salary), ); - salaries_sum_enc = wopbs_key.keyswitch_to_wopbs_params(&server_key, &salaries_sum_enc); - let mut salaries_mean_enc = wopbs_key.wopbs(&salaries_sum_enc, &div_lut); - salaries_mean_enc = wopbs_key.keyswitch_to_pbs_params(&salaries_mean_enc); - println!( - "- Finished computing mean in {:?} sec.", - start.elapsed().as_secs_f64() - ); + // println!("salaries_sum_enc degree: {:?}", salaries_sum_enc.blocks()[0].degree); + let salaries_mean_enc = ct_lut_eval_no_gen(salaries_sum_enc, &wopbs_key, &server_key, &div_lut); // Cov = Sum_i^n (salary_i - mean(salary))(experience_i - mean(experience)) - println!("- Starting computing covariance"); - let start = Instant::now(); let covariance = encrypted_salaries .iter() .zip(experience.iter()) @@ -123,15 +115,9 @@ fn main() { server_key.create_trivial_radix(0_u64, nb_blocks), |acc: RadixCiphertext, diff| server_key.add_parallelized(&acc, &diff), ); - println!( - "- Finished computing covariance in {:?} sec.", - start.elapsed().as_secs_f64() - ); // Var_salary = Sum_i^n (salary_i - mean(salary))^2 - println!("- Starting computing variance"); - let start = Instant::now(); - let mut salaries_variance_enc = encrypted_salaries + let salaries_variance_enc = encrypted_salaries .iter() .map(|salary_enc| { let x = server_key.sub_parallelized(salary_enc, &salaries_mean_enc); @@ -141,22 +127,11 @@ fn main() { server_key.create_trivial_radix(0_u64, nb_blocks), |acc: RadixCiphertext, diff| server_key.add_parallelized(&acc, &diff), ); - println!( - "- Finished computing variance in {:?} sec.", - start.elapsed().as_secs_f64() - ); // sigma_salary (or stddev) = sqrt(var_salary) - salaries_variance_enc = - wopbs_key.keyswitch_to_wopbs_params(&server_key, &salaries_variance_enc); - println!("- Starting computing LUT"); - let start = Instant::now(); - let mut salaries_stddev_enc = wopbs_key.wopbs(&salaries_variance_enc, &sqrt_lut); - salaries_stddev_enc = wopbs_key.keyswitch_to_pbs_params(&salaries_stddev_enc); - println!( - "- Finished computing LUT in {:?} sec.", - start.elapsed().as_secs_f64() - ); + // println!("salaries_variance_enc degree: {:?}", salaries_variance_enc.blocks()[0].degree); + let salaries_stddev_enc = + ct_lut_eval_no_gen(salaries_variance_enc, &wopbs_key, &server_key, &sqrt_lut); let correlation_enc = server_key.mul_parallelized(&salaries_stddev_enc, &covariance); println!( diff --git a/src/correlation_quantized.rs b/src/correlation_quantized.rs new file mode 100644 index 0000000..a2ee480 --- /dev/null +++ b/src/correlation_quantized.rs @@ -0,0 +1,170 @@ +use std::time::Instant; + +use num_integer::Roots; +use rayon::prelude::*; +use ripple::common::*; +use tfhe::{ + integer::{ + ciphertext::BaseRadixCiphertext, gen_keys_radix, wopbs::*, IntegerCiphertext, + IntegerRadixCiphertext, RadixCiphertext, + }, + shortint::{ + parameters::{ + parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, Degree, + PARAM_MESSAGE_2_CARRY_2_KS_PBS, + }, + Ciphertext, + }, +}; + +fn main() { + let data = read_csv("data/correlation.csv"); + let experience = &data[0]; + let salaries = &data[1]; + let dataset_size = salaries.len() as f64; + + let mut salaries_sorted = salaries.clone(); + salaries_sorted.sort(); + + // ------- Client side ------- // + let bit_width = 16; + + // Number of blocks per ciphertext + let nb_blocks = bit_width / 2; + println!( + "Number of blocks for the radix decomposition: {:?}", + nb_blocks + ); + + // Scale factor + let scale = 100_000_u64; + + let start = Instant::now(); + // Generate radix keys + let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks); + // Generate key for PBS (without padding) + let wopbs_key = WopbsKey::new_wopbs_key( + &client_key, + &server_key, + &WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, + ); + println!( + "Key generation done in {:?} sec.", + start.elapsed().as_secs_f64() + ); + + let start = Instant::now(); + let encrypted_salaries: Vec<_> = salaries_sorted + .par_iter() // Use par_iter() for parallel iteration + .map(|&salary| client_key.encrypt(salary)) + .collect(); + println!( + "Encryption done in {:?} sec.", + start.elapsed().as_secs_f64() + ); + + // ------- Server side ------- // + + // The experience vector is known to the server. + let experience_mean = experience.iter().map(|&exp| exp as f64).sum::() / dataset_size; + let experience_variance: f64 = experience + .iter() + .map(|&exp| ((exp as f64) - experience_mean).powi(2)) + .sum(); + let experience_stddev = experience_variance.sqrt(); + + // Offline: LUT genaration is offline cost. + let lut_gen_start = Instant::now(); + println!("Generating LUT."); + let mut dummy: RadixCiphertext = server_key.create_trivial_radix(2_u64, (nb_blocks).into()); + dummy = wopbs_key.keyswitch_to_wopbs_params(&server_key, &dummy); + let mut dummy_blocks = dummy.clone().into_blocks().to_vec(); + for block in &mut dummy_blocks { + block.degree = Degree::new(3); + } + dummy = RadixCiphertext::from_blocks(dummy_blocks[0..(nb_blocks as usize >> 1)].to_vec()); + let sqrt_lut = wopbs_key.generate_lut_radix(&dummy, |x: u64| { + if x == 0 { + 1 // avoid division with zero error. + } else { + scale / (x.sqrt() * experience_stddev as u64) + } + }); + let div_lut = wopbs_key.generate_lut_radix(&dummy, |x: u64| x / (dataset_size as u64)); + println!( + "LUT generation done in {:?} sec.", + lut_gen_start.elapsed().as_secs_f64() + ); + // Online: Compute the encrypted correlation + + // The salaries vector is encrypted. + let start = Instant::now(); + let total = start; + let salaries_sum_enc = encrypted_salaries.iter().fold( + server_key.create_trivial_radix(0_u64, nb_blocks), + |acc: RadixCiphertext, salary| server_key.add_parallelized(&acc, salary), + ); + // println!("salaries_sum_enc degree: {:?}", salaries_sum_enc.blocks()[0].degree); + let salaries_mean_enc = ct_lut_eval_quantized_no_gen( + salaries_sum_enc, + nb_blocks, + &wopbs_key, + &server_key, + &div_lut, + ); + let mut salaries_mean_enc_blocks = salaries_mean_enc.clone().blocks().to_vec(); + let zero_ct: BaseRadixCiphertext = + server_key.create_trivial_radix(0_u64, nb_blocks >> 1); + let zero_ct_blocks = zero_ct.clone().blocks().to_vec(); + salaries_mean_enc_blocks.extend(zero_ct_blocks.clone()); + let salaries_mean_enc = RadixCiphertext::from_blocks(salaries_mean_enc_blocks); + // Cov = Sum_i^n (salary_i - mean(salary))(experience_i - mean(experience)) + let covariance = encrypted_salaries + .iter() + .zip(experience.iter()) + .map(|(salary_enc, &exp)| { + let x = server_key.sub_parallelized(salary_enc, &salaries_mean_enc); + server_key.scalar_mul_parallelized(&x, exp - (experience_mean as u32)) + }) + .fold( + server_key.create_trivial_radix(0_u64, nb_blocks), + |acc: RadixCiphertext, diff| server_key.add_parallelized(&acc, &diff), + ); + + // Var_salary = Sum_i^n (salary_i - mean(salary))^2 + let salaries_variance_enc = encrypted_salaries + .iter() + .map(|salary_enc| { + let x = server_key.sub_parallelized(salary_enc, &salaries_mean_enc); + server_key.mul_parallelized(&x, &x) + }) + .fold( + server_key.create_trivial_radix(0_u64, nb_blocks), + |acc: RadixCiphertext, diff| server_key.add_parallelized(&acc, &diff), + ); + + // sigma_salary (or stddev) = sqrt(var_salary) + // println!("salaries_variance_enc degree: {:?}", + // salaries_variance_enc.blocks()[0].degree); + let salaries_stddev_enc = ct_lut_eval_quantized_no_gen( + salaries_variance_enc, + nb_blocks, + &wopbs_key, + &server_key, + &sqrt_lut, + ); + let mut salaries_stddev_enc_blocks = salaries_stddev_enc.clone().blocks().to_vec(); + salaries_stddev_enc_blocks.extend(zero_ct_blocks); + let salaries_stddev_enc = RadixCiphertext::from_blocks(salaries_stddev_enc_blocks); + let correlation_enc = server_key.mul_parallelized(&salaries_stddev_enc, &covariance); + + println!( + "Finished computing correlation in {:?} sec.", + total.elapsed().as_secs_f64() + ); + + // ------- Client side ------- // + let correlation: u64 = client_key.decrypt(&correlation_enc); + + println!("Correlation: {}", (correlation as f64) / (scale as f64)); +}