Skip to content

Commit

Permalink
Optimize encrypted correlation baseline and start adding timings for …
Browse files Browse the repository at this point in the history
…primitive operations

Co-authored-by: Dimitris Mouris <[email protected]>
  • Loading branch information
cgouert and jimouris committed May 1, 2024
1 parent ac379f4 commit d1af0d8
Show file tree
Hide file tree
Showing 4 changed files with 221 additions and 23 deletions.
10 changes: 8 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,17 @@ name = "correlation_ptxt"
path = "src/correlation_ptxt.rs"

[[bin]]
name = "correlation_haar"
path = "src/correlation_haar.rs"
name = "correlation_lut"
path = "src/correlation_lut.rs"

# Euclidean Distance

[[bin]]
name = "euclidean"
path = "src/euclidean.rs"

# Primitive Timings

[[bin]]
name = "primitive_ops"
path = "src/primitive_ops.rs"
85 changes: 66 additions & 19 deletions src/correlation_haar.rs → src/correlation_lut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ fn main() {
let salaries = &data[1];
let dataset_size = salaries.len() as f64;

let mut salary_sorted = salaries.clone();
salary_sorted.sort();
let mut salaries_sorted = salaries.clone();
salaries_sorted.sort();

// ------- Client side ------- //
let bit_width = 16;
let bit_width = 20;

// Number of blocks per ciphertext
let nb_blocks = bit_width / 2;
Expand All @@ -30,6 +30,9 @@ fn main() {
nb_blocks
);

// Scale factor
let scale = 100_000_u64;

let start = Instant::now();
// Generate radix keys
let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks);
Expand All @@ -45,7 +48,7 @@ fn main() {
);

let start = Instant::now();
let encrypted_salaries: Vec<_> = salary_sorted
let encrypted_salaries: Vec<_> = salaries_sorted
.par_iter() // Use par_iter() for parallel iteration
.map(|&salary| client_key.encrypt(salary))
.collect();
Expand All @@ -55,34 +58,60 @@ fn main() {
);

// ------- Server side ------- //
// TODO: Move LUT gens up here

// Compute the encrypted correlation
let start = Instant::now();
println!("Starting computing mean");

// The experience vector is known to the server.
let experience_mean = experience.iter().map(|&exp| exp as f64).sum::<f64>() / dataset_size;
let experience_variance: f64 = experience
.iter()
.map(|&exp| ((exp as f64) - experience_mean).powi(2))
.sum();
let experience_stddev = experience_variance.sqrt();

// Offline: LUT genaration is offline cost.
let mut dummy_ct: RadixCiphertext = server_key.create_trivial_radix(0_u64, nb_blocks);
let dummy_ct_2 = server_key.create_trivial_radix(0_u64, nb_blocks);
for _ in 0..encrypted_salaries.len() {
dummy_ct = server_key.add_parallelized(&dummy_ct, &dummy_ct_2);
}
let div_lut = wopbs_key.generate_lut_radix(&dummy_ct, |x: u64| x / (dataset_size as u64));
dummy_ct = wopbs_key.keyswitch_to_wopbs_params(&server_key, &dummy_ct);
dummy_ct = wopbs_key.wopbs(&dummy_ct, &div_lut);
dummy_ct = wopbs_key.keyswitch_to_pbs_params(&dummy_ct);
let mut dummy_acc: RadixCiphertext = server_key.create_trivial_radix(0_u64, nb_blocks);
for _ in 0..encrypted_salaries.len() {
let mut dummy_ct_3 = server_key.sub_parallelized(&dummy_ct_2, &dummy_ct);
dummy_ct_3 = server_key.mul_parallelized(&dummy_ct_3, &dummy_ct_3);
dummy_acc = server_key.add_parallelized(&dummy_acc, &dummy_ct_3);
}
let sqrt_lut = wopbs_key.generate_lut_radix(&dummy_acc, |x: u64| {
if x == 0 {
1 // avoid division with zero error.
} else {
scale / (x.sqrt() * experience_stddev as u64)
}
});

// Online: Compute the encrypted correlation

// The salaries vector is encrypted.
println!("- Starting computing mean");
let start = Instant::now();
let total = start;
let mut salaries_sum_enc = encrypted_salaries.iter().fold(
server_key.create_trivial_radix(0_u64, nb_blocks),
|acc: RadixCiphertext, salary| server_key.add_parallelized(&acc, salary),
);
let div_lut =
wopbs_key.generate_lut_radix(&salaries_sum_enc, |x: u64| x / (dataset_size as u64));
salaries_sum_enc = wopbs_key.keyswitch_to_wopbs_params(&server_key, &salaries_sum_enc);
let mut salaries_mean_enc = wopbs_key.wopbs(&salaries_sum_enc, &div_lut);
salaries_mean_enc = wopbs_key.keyswitch_to_pbs_params(&salaries_mean_enc);

println!(
"Finished computing mean in {:?} sec.",
"- Finished computing mean in {:?} sec.",
start.elapsed().as_secs_f64()
);

// Cov = Sum_i^n (salary_i - mean(salary))(experience_i - mean(experience))
println!("- Starting computing covariance");
let start = Instant::now();
let covariance = encrypted_salaries
.iter()
.zip(experience.iter())
Expand All @@ -94,7 +123,14 @@ fn main() {
server_key.create_trivial_radix(0_u64, nb_blocks),
|acc: RadixCiphertext, diff| server_key.add_parallelized(&acc, &diff),
);
println!(
"- Finished computing covariance in {:?} sec.",
start.elapsed().as_secs_f64()
);

// Var_salary = Sum_i^n (salary_i - mean(salary))^2
println!("- Starting computing variance");
let start = Instant::now();
let mut salaries_variance_enc = encrypted_salaries
.iter()
.map(|salary_enc| {
Expand All @@ -105,20 +141,31 @@ fn main() {
server_key.create_trivial_radix(0_u64, nb_blocks),
|acc: RadixCiphertext, diff| server_key.add_parallelized(&acc, &diff),
);
println!(
"- Finished computing variance in {:?} sec.",
start.elapsed().as_secs_f64()
);

let sqrt_lut = wopbs_key.generate_lut_radix(&salaries_variance_enc, |x: u64| x.sqrt());

// sigma_salary (or stddev) = sqrt(var_salary)
salaries_variance_enc =
wopbs_key.keyswitch_to_wopbs_params(&server_key, &salaries_variance_enc);
println!("- Starting computing LUT");
let start = Instant::now();
let mut salaries_stddev_enc = wopbs_key.wopbs(&salaries_variance_enc, &sqrt_lut);
salaries_stddev_enc = wopbs_key.keyswitch_to_pbs_params(&salaries_stddev_enc);
println!(
"- Finished computing LUT in {:?} sec.",
start.elapsed().as_secs_f64()
);
let correlation_enc = server_key.mul_parallelized(&salaries_stddev_enc, &covariance);

let divisor_enc =
server_key.scalar_mul_parallelized(&salaries_stddev_enc, experience_stddev as u32);
let correlation_enc = server_key.div_parallelized(&covariance, &divisor_enc);
println!(
"Finished computing correlation in {:?} sec.",
total.elapsed().as_secs_f64()
);

// ------- Client side ------- //
let correlation: u64 = client_key.decrypt(&correlation_enc);

println!("correlation: {}", correlation);
println!("Correlation: {}", (correlation as f64) / (scale as f64));
}
4 changes: 2 additions & 2 deletions src/correlation_ptxt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ fn pearson_correlation(x: &[u32], y: &[u32]) -> f64 {
let x_mean = x.iter().map(|&xi| xi as f64).sum::<f64>() / n;
let y_mean = y.iter().map(|&yi| yi as f64).sum::<f64>() / n;

let sum_xy: f64 = x
let covariance: f64 = x
.iter()
.zip(y.iter())
.map(|(&xi, &yi)| ((xi as f64) - x_mean) * ((yi as f64) - y_mean))
.sum();
let variance_x: f64 = x.iter().map(|&xi| ((xi as f64) - x_mean).powi(2)).sum();
let variance_y: f64 = y.iter().map(|&yi| ((yi as f64) - y_mean).powi(2)).sum();

sum_xy / (variance_x.sqrt() * variance_y.sqrt())
covariance / (variance_x.sqrt() * variance_y.sqrt())
}

fn main() {
Expand Down
145 changes: 145 additions & 0 deletions src/primitive_ops.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
use std::time::Instant;

use dwt::{transform, wavelet::Haar, Operation};
use ripple::common::*;
use tfhe::{
integer::{
gen_keys_radix, wopbs::*, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext,
},
shortint::parameters::{
parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
PARAM_MESSAGE_2_CARRY_2_KS_PBS,
},
};

pub fn haar_square(
table_size: u8,
input_precision: u8,
output_precision: u8,
bit_width: u8,
) -> (Vec<u64>, Vec<u64>) {
let max = 1 << bit_width;
let mut data = Vec::new();
for x in 0..max {
let x = unquantize(x, input_precision, bit_width);
let square = x * x;
data.push(square);
}
data.rotate_right(1 << (bit_width - 1));
transform(
&mut data,
Operation::Forward,
&Haar::new(),
(bit_width - table_size) as usize,
);
let coef_len = 1 << table_size;
let scalar = 2f64.powf(-((bit_width - table_size) as f64) / 2f64);
let mut haar: Vec<u64> = data
.get(0..coef_len)
.unwrap()
.iter()
.map(|x| quantize(scalar * x, output_precision, bit_width))
.collect();
haar.rotate_right(1 << (table_size - 1));
let mask = (1 << (bit_width / 2)) - 1;
let lsb = haar.iter().map(|x| x & mask).collect();
let msb = haar.iter().map(|x| x >> (bit_width / 2) & mask).collect();
(lsb, msb)
}

fn eval_lut(x: u64, lut_map: &Vec<u64>) -> u64 {
lut_map[x as usize]
}

fn main() {
// ------- Client side ------- //
let bit_width = 20;

// Number of blocks per ciphertext
let nb_blocks = bit_width / 2;
println!(
"Number of blocks for the radix decomposition: {:?}",
nb_blocks
);

let start = Instant::now();
// Generate radix keys
let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks);
// Generate key for PBS (without padding)
let wopbs_key = WopbsKey::new_wopbs_key(
&client_key,
&server_key,
&WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
);
println!(
"Key generation done in {:?} sec.",
start.elapsed().as_secs_f64()
);

let x = 5_u64;
// let y = 10_u64;
let x_ct = client_key.encrypt(x);
// let y_ct = client_key.encrypt(y);

// ------- Server side ------- //

// 1. Square
println!("\n1. Square");

// 1.1. Square using multiplication
let start = Instant::now();
let square_ct = server_key.mul_parallelized(&x_ct, &x_ct);
println!("Ct-Ct Mult in {:?} sec.", start.elapsed().as_secs_f64());
let prod: u64 = client_key.decrypt(&square_ct);

// 1.2. Square using LUT
let square_lut = wopbs_key.generate_lut_radix(&x_ct, |x: u64| x * x);
let start = Instant::now();
let x_ct_ks = wopbs_key.keyswitch_to_wopbs_params(&server_key, &x_ct);
let mut square_ct = wopbs_key.wopbs(&x_ct_ks, &square_lut);
square_ct = wopbs_key.keyswitch_to_pbs_params(&square_ct);
println!("LUT Square in {:?} sec.", start.elapsed().as_secs_f64());
let lut_prod: u64 = client_key.decrypt(&square_ct);

// 1.3. Square using Haar DWT LUT
let (haar_lsb, haar_msb) = haar_square((bit_width >> 1) as u8, 8_u8, 16_u8, bit_width as u8);
dbg!(&haar_lsb);
dbg!(&haar_msb);
let dummy: RadixCiphertext = server_key.create_trivial_radix(0_u64, nb_blocks);
let dummy_blocks = &dummy.into_blocks()[(nb_blocks >> 1)..nb_blocks];
let dummy_msb = RadixCiphertext::from_blocks(dummy_blocks.to_vec());
let dummy_msb = server_key.scalar_add_parallelized(&dummy_msb, 1);
let haar_lsb_lut = wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| eval_lut(x, &haar_lsb));
let haar_msb_lut = wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| eval_lut(x, &haar_msb));

let start = Instant::now();
// Truncate x
let x_truncated_blocks = &x_ct.into_blocks()[(nb_blocks >> 1)..nb_blocks];
let x_truncated = RadixCiphertext::from_blocks(x_truncated_blocks.to_vec());
let x_truncated = server_key.scalar_add_parallelized(&x_truncated, 1);
let x_truncated_ks = wopbs_key.keyswitch_to_wopbs_params(&server_key, &x_truncated);
let (square_lsb, square_msb) = rayon::join(
|| {
let square_lsb = wopbs_key.wopbs(&x_truncated_ks, &haar_lsb_lut);
wopbs_key.keyswitch_to_pbs_params(&square_lsb)
},
|| {
let square_msb = wopbs_key.wopbs(&x_truncated_ks, &haar_msb_lut);
wopbs_key.keyswitch_to_pbs_params(&square_msb)
},
);
let mut square_lsb_blocks = square_lsb.into_blocks();
square_lsb_blocks.extend(square_msb.into_blocks());
let square_ct_haar = RadixCiphertext::from_blocks(square_lsb_blocks.to_vec());

println!(
"Haar LUT Square in {:?} sec.",
start.elapsed().as_secs_f64()
);
let dwt_lut_prod: u64 = client_key.decrypt(&square_ct_haar);

println!(
"--- Exact: {:?}, LUT: {:?}, DWT LUT: {:?}",
prod, lut_prod, dwt_lut_prod
);
}

0 comments on commit d1af0d8

Please sign in to comment.