diff --git a/src/correlation_haar.rs b/src/correlation_haar.rs index 572198e..538aa64 100644 --- a/src/correlation_haar.rs +++ b/src/correlation_haar.rs @@ -23,7 +23,7 @@ fn main() { // ------- Client side ------- // let bit_width = 16; - let precision = 6; + let precision = 0; // Number of blocks per ciphertext let nb_blocks = bit_width / 2; @@ -67,7 +67,7 @@ fn main() { .iter() .map(|&exp| ((exp as f64) - experience_mean).powi(2)) .sum(); - let experience_stddev = experience_variance.sqrt(); + let experience_stddev = quantize(experience_variance.sqrt(), precision, bit_width as u8); // Offline: LUT genaration is offline cost. let lut_gen_start = Instant::now(); @@ -86,11 +86,7 @@ fn main() { bit_width as u8, bit_width as u8, &|x: f64| { - if x.abs() < 0.05 { - 1.0 // avoid division with zero error. - } else { - scale as f64 / (x.sqrt() * experience_stddev) - } + scale as f64 / (x.sqrt() * experience_stddev as f64) }, ); let haar_lsb_lut_sqrt = wopbs_key.generate_lut_radix(&dummy, |x: u64| eval_lut(x, &haar_lsb)); @@ -100,7 +96,7 @@ fn main() { precision, bit_width as u8, bit_width as u8, - &|x: f64| x / dataset_size, + &|x: f64| x / dataset_size ); let haar_lsb_lut_div = wopbs_key.generate_lut_radix(&dummy, |x: u64| eval_lut(x, &haar_lsb)); let haar_msb_lut_div = wopbs_key.generate_lut_radix(&dummy, |x: u64| eval_lut(x, &haar_msb)); @@ -126,7 +122,6 @@ fn main() { &haar_lsb_lut_div, &haar_msb_lut_div, ); - // Cov = Sum_i^n (salary_i - mean(salary))(experience_i - mean(experience)) let covariance = encrypted_salaries .iter() @@ -151,7 +146,6 @@ fn main() { server_key.create_trivial_radix(0_u64, nb_blocks), |acc: RadixCiphertext, diff| server_key.add_parallelized(&acc, &diff), ); - // sigma_salary (or stddev) = sqrt(var_salary) // println!("salaries_variance_enc degree: {:?}", salaries_variance_enc.blocks()[0].degree); let salaries_stddev_enc = ct_lut_eval_haar_no_gen( @@ -172,6 +166,5 @@ fn main() { // ------- Client side ------- // let correlation: u64 = client_key.decrypt(&correlation_enc); let correlation_final: f64 = unquantize(correlation, precision, bit_width as u8); - println!("Correlation: {}", correlation_final / (scale as f64)); }