From d01660b096ae85afde0e590ecc048636f207bf41 Mon Sep 17 00:00:00 2001 From: Dimitris Mouris Date: Wed, 27 Mar 2024 10:41:55 +0200 Subject: [PATCH] chore: Minor fixes and remove unused files --- Cargo.toml | 12 ---------- src/common.rs | 5 +++-- src/encrypted_lr_dwt.rs | 1 - src/lut_split_test.rs | 49 ----------------------------------------- src/lut_test.rs | 25 --------------------- 5 files changed, 3 insertions(+), 89 deletions(-) delete mode 100644 src/lut_split_test.rs delete mode 100644 src/lut_test.rs diff --git a/Cargo.toml b/Cargo.toml index 488fffb..f1f4334 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,10 +20,6 @@ assert = "0.7" x86 = ["tfhe/x86_64-unix"] aarch64 = ["tfhe/aarch64-unix"] -[[bin]] -name = "haar" -path = "src/haar.rs" - [[bin]] name = "float_lr" path = "src/float_lr.rs" @@ -39,11 +35,3 @@ path = "src/encrypted_lr.rs" [[bin]] name = "encrypted_lr_dwt" path = "src/encrypted_lr_dwt.rs" - -[[bin]] -name = "lut_test" -path = "src/lut_test.rs" - -[[bin]] -name = "lut_split_test" -path = "src/lut_split_test.rs" \ No newline at end of file diff --git a/src/common.rs b/src/common.rs index ee4f18f..a4fddcf 100644 --- a/src/common.rs +++ b/src/common.rs @@ -1,6 +1,7 @@ -use dwt::{transform, wavelet::Haar, Operation}; use std::fs::File; +use dwt::{transform, wavelet::Haar, Operation}; + pub fn to_signed(x: u64, bit_width: u8) -> i64 { if x > (1u64 << (bit_width - 1)) { (x as i128 - (1i128 << bit_width)) as i64 @@ -132,7 +133,7 @@ pub fn haar(precision: u8, bit_width: u8) -> (Vec, Vec) { let haar = data .get(0..coef_len) .unwrap() - .into_iter() + .iter() .map(|x| quantize(*x, precision, bit_width)); let lsb = haar.clone().map(|x| x & 0xFF).collect(); let msb = haar.map(|x| x >> (bit_width / 2) & 0xFF).collect(); diff --git a/src/encrypted_lr_dwt.rs b/src/encrypted_lr_dwt.rs index 877020f..2ad2027 100644 --- a/src/encrypted_lr_dwt.rs +++ b/src/encrypted_lr_dwt.rs @@ -1,5 +1,4 @@ use std::time::Instant; -// collections::HashMap, fs::File, io::BufReader, use fhe_lut::common::*; use rayon::prelude::*; diff --git a/src/lut_split_test.rs b/src/lut_split_test.rs deleted file mode 100644 index 08a3639..0000000 --- a/src/lut_split_test.rs +++ /dev/null @@ -1,49 +0,0 @@ -use std::time::Instant; - -use tfhe::{ - integer::{ - gen_keys_radix, wopbs::*, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext, - }, - shortint::parameters::{ - parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, - PARAM_MESSAGE_2_CARRY_2_KS_PBS, - }, -}; - -fn main() { - let nb_block = 4; - let msg = 14; - let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_block); - let wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS); - - let ct = cks.encrypt(msg); - let ct = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct); - - let lut_lsb = wopbs_key.generate_lut_radix(&ct, |x| u64::pow(x, 2) % (1 << 4)); - let lut_msb = wopbs_key.generate_lut_radix(&ct, |x| (u64::pow(x, 2) >> 4) % (1 << 4)); - - let start = Instant::now(); - let (ct_res_lsb, ct_res_msb) = rayon::join( - || { - let ct_res_lsb = wopbs_key.wopbs(&ct, &lut_lsb); - wopbs_key.keyswitch_to_pbs_params(&ct_res_lsb) - }, - || { - let ct_res_msb = wopbs_key.wopbs(&ct, &lut_msb); - wopbs_key.keyswitch_to_pbs_params(&ct_res_msb) - }, - ); - - let mut lsb_blocks = ct_res_lsb.clone().into_blocks(); - let msb_blocks = ct_res_msb.clone().into_blocks(); - lsb_blocks.extend(msb_blocks); - let _ct_res = RadixCiphertext::from_blocks(lsb_blocks); - let duration = start.elapsed(); - println!("PBS time: {:?}", duration); - let res_lsb: u64 = cks.decrypt(&ct_res_lsb); - let res_msb: u64 = cks.decrypt(&ct_res_msb); - - assert_eq!(res_lsb, u64::pow(msg, 2) % (1 << 4)); - assert_eq!(res_msb, (u64::pow(msg, 2) >> 4) % (1 << 4)); - assert_eq!((res_msb << 4) + res_lsb, u64::pow(msg, 2)); -} diff --git a/src/lut_test.rs b/src/lut_test.rs deleted file mode 100644 index 7bef123..0000000 --- a/src/lut_test.rs +++ /dev/null @@ -1,25 +0,0 @@ -use std::time::Instant; - -use tfhe::{ - integer::{gen_keys_radix, wopbs::*}, - shortint::parameters::{ - parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, - PARAM_MESSAGE_2_CARRY_2_KS_PBS, - }, -}; - -fn main() { - let nb_block = 8; - let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_block); - let wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS); - let ct = cks.encrypt(2_u64); - let ct = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct); - let lut = wopbs_key.generate_lut_radix(&ct, |x| 5 + 2 * x); - let start = Instant::now(); - let ct_res = wopbs_key.wopbs(&ct, &lut); - let ct_res = wopbs_key.keyswitch_to_pbs_params(&ct_res); - let duration = start.elapsed(); - println!("PBS time: {:?}", duration); - let res: u64 = cks.decrypt(&ct_res); - assert_eq!(res, 9); -}