diff --git a/src/lr_bior.rs b/src/lr_bior.rs index 9f3ab02..239c5e2 100644 --- a/src/lr_bior.rs +++ b/src/lr_bior.rs @@ -6,15 +6,10 @@ use ripple::common::*; // use serde::{Deserialize, Serialize}; use tfhe::{ integer::{ - // ciphertext::BaseRadixCiphertext, - gen_keys_radix, - wopbs::*, - IntegerCiphertext, - IntegerRadixCiphertext, - RadixCiphertext, + gen_keys_radix, wopbs::*, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext, }, shortint::parameters::{ - parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, + parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, Degree, PARAM_MESSAGE_2_CARRY_2_KS_PBS, }, }; @@ -115,6 +110,12 @@ fn main() { [((nb_blocks as usize) - (nb_blocks_msb as usize))..(nb_blocks as usize)]; let dummy_msb = RadixCiphertext::from_blocks(dummy_blocks_msb.to_vec()); let dummy_msb = server_key.scalar_add_parallelized(&dummy_msb, 1); + let dummy_msb = wopbs_key.keyswitch_to_wopbs_params(&server_key, &dummy_msb); + let mut dummy_blocks = dummy_msb.clone().into_blocks().to_vec(); + for block in &mut dummy_blocks { + block.degree = Degree::new(3); + } + let dummy_msb = RadixCiphertext::from_blocks(dummy_blocks); let mut msb_luts = Vec::new(); msb_luts.push(wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| { eval_lut_minus_1(x, &lut_lsb, 2u64.pow((lut_bit_width) as u32)) diff --git a/src/lr_db2.rs b/src/lr_db2.rs index f0fd559..b399a2d 100644 --- a/src/lr_db2.rs +++ b/src/lr_db2.rs @@ -3,18 +3,12 @@ use std::time::Instant; use clap::{App, Arg}; use rayon::prelude::*; use ripple::common::*; -// use serde::{Deserialize, Serialize}; use tfhe::{ integer::{ - // ciphertext::BaseRadixCiphertext, - gen_keys_radix, - wopbs::*, - IntegerCiphertext, - IntegerRadixCiphertext, - RadixCiphertext, + gen_keys_radix, wopbs::*, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext, }, shortint::parameters::{ - parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, + parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, Degree, PARAM_MESSAGE_2_CARRY_2_KS_PBS, }, }; @@ -123,10 +117,20 @@ fn main() { let dummy_blocks_msb = &dummy_blocks[((j >> 1) as usize)..((nb_blocks as usize) - 3)]; let dummy_lsb = RadixCiphertext::from_blocks(dummy_blocks_lsb.to_vec()); let dummy_msb = RadixCiphertext::from_blocks(dummy_blocks_msb.to_vec()); - let dummy_msb = server_key.scalar_add_parallelized(&dummy_msb, 1); - let dummy_lsb = server_key.scalar_add_parallelized(&dummy_lsb, 1); + let mut dummy_blocks = dummy_lsb.clone().into_blocks().to_vec(); + for block in &mut dummy_blocks { + block.degree = Degree::new(3); + } + let dummy_lsb = RadixCiphertext::from_blocks(dummy_blocks); + let mut dummy_blocks = dummy_msb.clone().into_blocks().to_vec(); + for block in &mut dummy_blocks { + block.degree = Degree::new(3); + } + let dummy_msb = RadixCiphertext::from_blocks(dummy_blocks); let mut lsb_luts = Vec::new(); let mut msb_luts = Vec::new(); + let dummy_msb = server_key.scalar_add_parallelized(&dummy_msb, 1); + let dummy_lsb = server_key.scalar_add_parallelized(&dummy_lsb, 1); for lut_lsb in lut_lsbs.iter() { lsb_luts.push(wopbs_key.generate_lut_radix(&dummy_lsb, |x: u64| eval_lut(x, lut_lsb))); } diff --git a/src/lr_haar.rs b/src/lr_haar.rs index 836a560..9ce58d2 100644 --- a/src/lr_haar.rs +++ b/src/lr_haar.rs @@ -14,7 +14,7 @@ use tfhe::{ RadixCiphertext, }, shortint::parameters::{ - parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, + parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, Degree, PARAM_MESSAGE_2_CARRY_2_KS_PBS, }, }; @@ -117,9 +117,14 @@ fn main() { let dummy_2 = server_key.scalar_mul_parallelized(&dummy, 2_u64); dummy = server_key.add_parallelized(&dummy_2, &dummy); } + dummy = wopbs_key.keyswitch_to_wopbs_params(&server_key, &dummy); + let mut dummy_blocks = dummy.clone().into_blocks().to_vec(); + for block in &mut dummy_blocks { + block.degree = Degree::new(3); + } + dummy = RadixCiphertext::from_blocks(dummy_blocks); let dummy_blocks = &dummy.into_blocks()[(nb_blocks as usize)..((nb_blocks << 1) as usize)]; let dummy_msb = RadixCiphertext::from_blocks(dummy_blocks.to_vec()); - let dummy_msb = server_key.scalar_add_parallelized(&dummy_msb, 1); let exp_lut_lsb = wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| eval_exp(x, &lut_lsb)); println!( "LUT generation done in {:?} sec.", @@ -146,7 +151,6 @@ fn main() { let prediction_blocks = &prediction.into_blocks()[(nb_blocks as usize)..((nb_blocks << 1) as usize)]; let prediction_msb = RadixCiphertext::from_blocks(prediction_blocks.to_vec()); - let prediction_msb = server_key.scalar_add_parallelized(&prediction_msb, 1); // Keyswitch and Bootstrap prediction = wopbs_key.keyswitch_to_wopbs_params(&server_key, &prediction_msb); let activation_lsb = wopbs_key.wopbs(&prediction, &exp_lut_lsb); @@ -172,7 +176,6 @@ fn main() { let prediction_blocks = &prediction.into_blocks()[(nb_blocks as usize)..((nb_blocks << 1) as usize)]; let prediction_msb = RadixCiphertext::from_blocks(prediction_blocks.to_vec()); - let prediction_msb = server_key.scalar_add_parallelized(&prediction_msb, 1); // Keyswitch and Bootstrap prediction = wopbs_key.keyswitch_to_wopbs_params(&server_key, &prediction_msb); let activation_lsb = wopbs_key.wopbs(&prediction, &exp_lut_lsb); diff --git a/src/lr_lut.rs b/src/lr_lut.rs index 3319427..167f70d 100644 --- a/src/lr_lut.rs +++ b/src/lr_lut.rs @@ -4,9 +4,11 @@ use clap::{App, Arg}; use rayon::prelude::*; use ripple::common::*; use tfhe::{ - integer::{gen_keys_radix, wopbs::*, RadixCiphertext}, + integer::{ + gen_keys_radix, wopbs::*, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext, + }, shortint::parameters::{ - parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, + parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, Degree, PARAM_MESSAGE_2_CARRY_2_KS_PBS, }, }; @@ -92,6 +94,12 @@ fn main() { let dummy_2 = server_key.scalar_mul_parallelized(&dummy, 2_u64); dummy = server_key.add_parallelized(&dummy_2, &dummy); } + dummy = wopbs_key.keyswitch_to_wopbs_params(&server_key, &dummy); + let mut dummy_blocks = dummy.clone().into_blocks().to_vec(); + for block in &mut dummy_blocks { + block.degree = Degree::new(3); + } + dummy = RadixCiphertext::from_blocks(dummy_blocks); let sigmoid_lut = wopbs_key.generate_lut_radix(&dummy, |x: u64| { sigmoid(x, 2 * precision, precision, bit_width) });