Skip to content

Commit

Permalink
Fix LUT preprocessing for LR variants
Browse files Browse the repository at this point in the history
  • Loading branch information
cgouert committed May 7, 2024
1 parent eee2517 commit d8e176f
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 11 deletions.
13 changes: 10 additions & 3 deletions src/lr_bior.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@ use ripple::common::*;
// use serde::{Deserialize, Serialize};
use tfhe::{
integer::{
// ciphertext::BaseRadixCiphertext,
gen_keys_radix,
wopbs::*,
IntegerCiphertext,
IntegerRadixCiphertext,
RadixCiphertext,
RadixCiphertext
},
shortint::parameters::{
parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
PARAM_MESSAGE_2_CARRY_2_KS_PBS,
Degree
},
};

Expand Down Expand Up @@ -115,6 +115,12 @@ fn main() {
[((nb_blocks as usize) - (nb_blocks_msb as usize))..(nb_blocks as usize)];
let dummy_msb = RadixCiphertext::from_blocks(dummy_blocks_msb.to_vec());
let dummy_msb = server_key.scalar_add_parallelized(&dummy_msb, 1);
let dummy_msb = wopbs_key.keyswitch_to_wopbs_params(&server_key, &dummy_msb);
let mut dummy_blocks = dummy_msb.clone().into_blocks().to_vec();
for block in &mut dummy_blocks {
block.degree = Degree::new(3);
}
let dummy_msb = RadixCiphertext::from_blocks(dummy_blocks);
let mut msb_luts = Vec::new();
msb_luts.push(wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| {
eval_lut_minus_1(x, &lut_lsb, 2u64.pow((lut_bit_width) as u32))
Expand Down Expand Up @@ -156,7 +162,8 @@ fn main() {
let prediction_blocks_lsb = &prediction_blocks[0..(nb_blocks_lsb as usize)];
let prediction_lsb =
RadixCiphertext::from_blocks(prediction_blocks_lsb.to_vec());
server_key.scalar_add_parallelized(&prediction_lsb, 1)
server_key.scalar_add_parallelized(&prediction_lsb,
1)
},
|| {
let prediction_blocks_msb = &prediction_blocks[((nb_blocks as usize)
Expand Down
17 changes: 13 additions & 4 deletions src/lr_db2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@ use std::time::Instant;
use clap::{App, Arg};
use rayon::prelude::*;
use ripple::common::*;
// use serde::{Deserialize, Serialize};
use tfhe::{
integer::{
// ciphertext::BaseRadixCiphertext,
gen_keys_radix,
wopbs::*,
IntegerCiphertext,
Expand All @@ -16,6 +14,7 @@ use tfhe::{
shortint::parameters::{
parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
PARAM_MESSAGE_2_CARRY_2_KS_PBS,
Degree
},
};

Expand Down Expand Up @@ -123,10 +122,20 @@ fn main() {
let dummy_blocks_msb = &dummy_blocks[((j >> 1) as usize)..((nb_blocks as usize) - 3)];
let dummy_lsb = RadixCiphertext::from_blocks(dummy_blocks_lsb.to_vec());
let dummy_msb = RadixCiphertext::from_blocks(dummy_blocks_msb.to_vec());
let dummy_msb = server_key.scalar_add_parallelized(&dummy_msb, 1);
let dummy_lsb = server_key.scalar_add_parallelized(&dummy_lsb, 1);
let mut dummy_blocks = dummy_lsb.clone().into_blocks().to_vec();
for block in &mut dummy_blocks {
block.degree = Degree::new(3);
}
let dummy_lsb = RadixCiphertext::from_blocks(dummy_blocks);
let mut dummy_blocks = dummy_msb.clone().into_blocks().to_vec();
for block in &mut dummy_blocks {
block.degree = Degree::new(3);
}
let dummy_msb = RadixCiphertext::from_blocks(dummy_blocks);
let mut lsb_luts = Vec::new();
let mut msb_luts = Vec::new();
let dummy_msb = server_key.scalar_add_parallelized(&dummy_msb, 1);
let dummy_lsb = server_key.scalar_add_parallelized(&dummy_lsb, 1);
for lut_lsb in lut_lsbs.iter() {
lsb_luts.push(wopbs_key.generate_lut_radix(&dummy_lsb, |x: u64| eval_lut(x, lut_lsb)));
}
Expand Down
10 changes: 7 additions & 3 deletions src/lr_haar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use tfhe::{
shortint::parameters::{
parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
PARAM_MESSAGE_2_CARRY_2_KS_PBS,
Degree
},
};

Expand Down Expand Up @@ -117,9 +118,14 @@ fn main() {
let dummy_2 = server_key.scalar_mul_parallelized(&dummy, 2_u64);
dummy = server_key.add_parallelized(&dummy_2, &dummy);
}
dummy = wopbs_key.keyswitch_to_wopbs_params(&server_key, &dummy);
let mut dummy_blocks = dummy.clone().into_blocks().to_vec();
for block in &mut dummy_blocks {
block.degree = Degree::new(3);
}
dummy = RadixCiphertext::from_blocks(dummy_blocks);
let dummy_blocks = &dummy.into_blocks()[(nb_blocks as usize)..((nb_blocks << 1) as usize)];
let dummy_msb = RadixCiphertext::from_blocks(dummy_blocks.to_vec());
let dummy_msb = server_key.scalar_add_parallelized(&dummy_msb, 1);
let exp_lut_lsb = wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| eval_exp(x, &lut_lsb));
println!(
"LUT generation done in {:?} sec.",
Expand All @@ -146,7 +152,6 @@ fn main() {
let prediction_blocks =
&prediction.into_blocks()[(nb_blocks as usize)..((nb_blocks << 1) as usize)];
let prediction_msb = RadixCiphertext::from_blocks(prediction_blocks.to_vec());
let prediction_msb = server_key.scalar_add_parallelized(&prediction_msb, 1);
// Keyswitch and Bootstrap
prediction = wopbs_key.keyswitch_to_wopbs_params(&server_key, &prediction_msb);
let activation_lsb = wopbs_key.wopbs(&prediction, &exp_lut_lsb);
Expand All @@ -172,7 +177,6 @@ fn main() {
let prediction_blocks =
&prediction.into_blocks()[(nb_blocks as usize)..((nb_blocks << 1) as usize)];
let prediction_msb = RadixCiphertext::from_blocks(prediction_blocks.to_vec());
let prediction_msb = server_key.scalar_add_parallelized(&prediction_msb, 1);
// Keyswitch and Bootstrap
prediction = wopbs_key.keyswitch_to_wopbs_params(&server_key, &prediction_msb);
let activation_lsb = wopbs_key.wopbs(&prediction, &exp_lut_lsb);
Expand Down
9 changes: 8 additions & 1 deletion src/lr_lut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ use clap::{App, Arg};
use rayon::prelude::*;
use ripple::common::*;
use tfhe::{
integer::{gen_keys_radix, wopbs::*, RadixCiphertext},
integer::{gen_keys_radix, wopbs::*, RadixCiphertext, IntegerCiphertext, IntegerRadixCiphertext},
shortint::parameters::{
parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
PARAM_MESSAGE_2_CARRY_2_KS_PBS,
Degree
},
};

Expand Down Expand Up @@ -92,6 +93,12 @@ fn main() {
let dummy_2 = server_key.scalar_mul_parallelized(&dummy, 2_u64);
dummy = server_key.add_parallelized(&dummy_2, &dummy);
}
dummy = wopbs_key.keyswitch_to_wopbs_params(&server_key, &dummy);
let mut dummy_blocks = dummy.clone().into_blocks().to_vec();
for block in &mut dummy_blocks {
block.degree = Degree::new(3);
}
dummy = RadixCiphertext::from_blocks(dummy_blocks);
let sigmoid_lut = wopbs_key.generate_lut_radix(&dummy, |x: u64| {
sigmoid(x, 2 * precision, precision, bit_width)
});
Expand Down

0 comments on commit d8e176f

Please sign in to comment.