From dd1f49585113c5ddf149188ce8bb01c4c74384ae Mon Sep 17 00:00:00 2001 From: Dimitris Mouris Date: Thu, 28 Mar 2024 12:40:06 +0200 Subject: [PATCH] fix: change bitsizes and add prints --- src/encrypted_lr.rs | 13 ++++++---- src/encrypted_lr_dwt.rs | 55 ++++++++++++++++++++++++++--------------- 2 files changed, 43 insertions(+), 25 deletions(-) diff --git a/src/encrypted_lr.rs b/src/encrypted_lr.rs index bade1b6..24f85ff 100644 --- a/src/encrypted_lr.rs +++ b/src/encrypted_lr.rs @@ -12,7 +12,7 @@ use tfhe::{ fn main() { // ------- Client side ------- // - let bit_width = 32u8; + let bit_width = 16u8; let precision = bit_width >> 2; assert!(precision <= bit_width / 2); @@ -58,18 +58,23 @@ fn main() { // ------- Server side ------- // - // Build LUT for Sigmoid + // Build LUT for Sigmoid -- Offline cost + let start = Instant::now(); + println!("Generating LUT."); let sigmoid_lut = wopbs_key.generate_lut_radix(&encrypted_dataset[0][0], |x: u64| { sigmoid(x, 2 * precision, precision, bit_width) }); + println!("Generated LUT in {:?} sec.", start.elapsed().as_secs_f64()); - let encrypted_dataset_short = encrypted_dataset.get_mut(0..4).unwrap(); + let encrypted_dataset_short = encrypted_dataset.get_mut(0..8).unwrap(); + // Inference let all_probabilities = encrypted_dataset_short .par_iter_mut() .enumerate() .map(|(cnt, sample)| { let start = Instant::now(); + println!("Started inference #{:?}.", cnt); let mut prediction = server_key.create_trivial_radix(bias_int, nb_blocks.into()); for (s, &weight) in sample.iter_mut().zip(weights_int.iter()) { @@ -92,10 +97,8 @@ fn main() { // ------- Client side ------- // let mut total = 0; for (num, (target, probability)) in targets.iter().zip(all_probabilities.iter()).enumerate() { - // let ptxt_probability = client_key.decrypt(probability); let ptxt_probability: u64 = client_key.decrypt(probability); - println!("{:?}", ptxt_probability); let class = (ptxt_probability > quantize(0.5, precision, bit_width)) as usize; println!("[{}] predicted {:?}, target {:?}", num, class, target); if class == *target { diff --git a/src/encrypted_lr_dwt.rs b/src/encrypted_lr_dwt.rs index 1eb8a85..e043699 100644 --- a/src/encrypted_lr_dwt.rs +++ b/src/encrypted_lr_dwt.rs @@ -5,7 +5,12 @@ use rayon::prelude::*; // use serde::{Deserialize, Serialize}; use tfhe::{ integer::{ - gen_keys_radix, wopbs::*, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext, + // ciphertext::BaseRadixCiphertext, + gen_keys_radix, + wopbs::*, + IntegerCiphertext, + IntegerRadixCiphertext, + RadixCiphertext, }, shortint::parameters::{ parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, @@ -19,7 +24,7 @@ fn eval_exp(x: u64, exp_map: &Vec) -> u64 { fn main() { // ------- Client side ------- // - let bit_width = 32u8; + let bit_width = 16u8; let precision = bit_width >> 2; assert!(precision <= bit_width / 2); @@ -75,12 +80,27 @@ fn main() { // ------- Server side ------- // - let encrypted_dataset_short = encrypted_dataset.get_mut(0..1).unwrap(); + // let lut_gen_start = Instant::now(); + // println!("Generating LUT."); + // let dummy_blocks = + // &encrypted_dataset[0][0].clone().into_blocks()[(nb_blocks as usize)..((nb_blocks << 1) as usize)]; + // let dummy = RadixCiphertext::from_blocks(dummy_blocks.to_vec()); + // let exp_lut_lsb = + // wopbs_key.generate_lut_radix(&dummy, |x: u64| eval_exp(x, &lut_lsb)); + // let exp_lut_msb = + // wopbs_key.generate_lut_radix(&dummy, |x: u64| eval_exp(x, &lut_msb)); + // println!( + // "LUT generation done in {:?} sec.", + // lut_gen_start.elapsed().as_secs_f64() + // ); + + let encrypted_dataset_short = encrypted_dataset.get_mut(0..8).unwrap(); let all_probabilities = encrypted_dataset_short - .iter_mut() + .par_iter_mut() .enumerate() .map(|(cnt, sample)| { let start = Instant::now(); + println!("Started inference #{:?}.", cnt); let mut prediction = server_key.create_trivial_radix(bias_int, (nb_blocks << 1).into()); for (s, &weight) in sample.iter_mut().zip(weights_int.iter()) { @@ -88,13 +108,13 @@ fn main() { prediction = server_key.unchecked_add(&ct_prod, &prediction); } // Truncate - let prediction_blocks = &prediction.clone().into_blocks() - [(nb_blocks as usize)..((nb_blocks << 1) as usize)]; + let prediction_blocks = + &prediction.into_blocks()[(nb_blocks as usize)..((nb_blocks << 1) as usize)]; let prediction_msb = RadixCiphertext::from_blocks(prediction_blocks.to_vec()); - // For some reason, the truncation is off by 1... let prediction_msb = server_key.unchecked_scalar_add(&prediction_msb, 1); // Keyswitch and Bootstrap let lut_gen_start = Instant::now(); + println!("Generating LUT."); let exp_lut_lsb = wopbs_key.generate_lut_radix(&prediction_msb, |x: u64| eval_exp(x, &lut_lsb)); let exp_lut_msb = @@ -104,18 +124,14 @@ fn main() { lut_gen_start.elapsed().as_secs_f64() ); prediction = wopbs_key.keyswitch_to_wopbs_params(&server_key, &prediction_msb); - let (activation_lsb, activation_msb) = rayon::join( - || { - let activation_lsb = wopbs_key.wopbs(&prediction, &exp_lut_lsb); - wopbs_key.keyswitch_to_pbs_params(&activation_lsb) - }, - || { - let activation_msb = wopbs_key.wopbs(&prediction, &exp_lut_msb); - wopbs_key.keyswitch_to_pbs_params(&activation_msb) - }, - ); - let mut lsb_blocks = activation_lsb.clone().into_blocks(); - let msb_blocks = activation_msb.clone().into_blocks(); + let activation_lsb = wopbs_key.wopbs(&prediction, &exp_lut_lsb); + let mut lsb_blocks = wopbs_key + .keyswitch_to_pbs_params(&activation_lsb) + .into_blocks(); + let activation_msb = wopbs_key.wopbs(&prediction, &exp_lut_msb); + let msb_blocks = wopbs_key + .keyswitch_to_pbs_params(&activation_msb) + .into_blocks(); lsb_blocks.extend(msb_blocks); let probability = RadixCiphertext::from_blocks(lsb_blocks); @@ -132,7 +148,6 @@ fn main() { let mut total = 0; for (num, (target, probability)) in targets.iter().zip(all_probabilities.iter()).enumerate() { let ptxt_probability: u64 = client_key.decrypt(probability); - println!("{:?}", ptxt_probability); let class = (ptxt_probability > quantize(0.5, precision, bit_width)) as usize; println!("[{}] predicted {:?}, target {:?}", num, class, target); if class == *target {