Skip to content

Commit

Permalink
Fix LUT preprocessing for LR variants
Browse files Browse the repository at this point in the history
  • Loading branch information
cgouert committed May 7, 2024
1 parent eee2517 commit 0ecb809
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 23 deletions.
15 changes: 8 additions & 7 deletions src/lr_bior.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,10 @@ use ripple::common::*;
// use serde::{Deserialize, Serialize};
use tfhe::{
integer::{
// ciphertext::BaseRadixCiphertext,
gen_keys_radix,
wopbs::*,
IntegerCiphertext,
IntegerRadixCiphertext,
RadixCiphertext,
gen_keys_radix, wopbs::*, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext,
},
shortint::parameters::{
parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, Degree,
PARAM_MESSAGE_2_CARRY_2_KS_PBS,
},
};
Expand Down Expand Up @@ -115,6 +110,12 @@ fn main() {
[((nb_blocks as usize) - (nb_blocks_msb as usize))..(nb_blocks as usize)];
let dummy_msb = RadixCiphertext::from_blocks(dummy_blocks_msb.to_vec());
let dummy_msb = server_key.scalar_add_parallelized(&dummy_msb, 1);
let dummy_msb = wopbs_key.keyswitch_to_wopbs_params(&server_key, &dummy_msb);
let mut dummy_blocks = dummy_msb.clone().into_blocks().to_vec();
for block in &mut dummy_blocks {
block.degree = Degree::new(3);
}
let dummy_msb = RadixCiphertext::from_blocks(dummy_blocks);
let mut msb_luts = Vec::new();
msb_luts.push(wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| {
eval_lut_minus_1(x, &lut_lsb, 2u64.pow((lut_bit_width) as u32))
Expand Down
24 changes: 14 additions & 10 deletions src/lr_db2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,12 @@ use std::time::Instant;
use clap::{App, Arg};
use rayon::prelude::*;
use ripple::common::*;
// use serde::{Deserialize, Serialize};
use tfhe::{
integer::{
// ciphertext::BaseRadixCiphertext,
gen_keys_radix,
wopbs::*,
IntegerCiphertext,
IntegerRadixCiphertext,
RadixCiphertext,
gen_keys_radix, wopbs::*, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext,
},
shortint::parameters::{
parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, Degree,
PARAM_MESSAGE_2_CARRY_2_KS_PBS,
},
};
Expand Down Expand Up @@ -123,10 +117,20 @@ fn main() {
let dummy_blocks_msb = &dummy_blocks[((j >> 1) as usize)..((nb_blocks as usize) - 3)];
let dummy_lsb = RadixCiphertext::from_blocks(dummy_blocks_lsb.to_vec());
let dummy_msb = RadixCiphertext::from_blocks(dummy_blocks_msb.to_vec());
let dummy_msb = server_key.scalar_add_parallelized(&dummy_msb, 1);
let dummy_lsb = server_key.scalar_add_parallelized(&dummy_lsb, 1);
let mut dummy_blocks = dummy_lsb.clone().into_blocks().to_vec();
for block in &mut dummy_blocks {
block.degree = Degree::new(3);
}
let dummy_lsb = RadixCiphertext::from_blocks(dummy_blocks);
let mut dummy_blocks = dummy_msb.clone().into_blocks().to_vec();
for block in &mut dummy_blocks {
block.degree = Degree::new(3);
}
let dummy_msb = RadixCiphertext::from_blocks(dummy_blocks);
let mut lsb_luts = Vec::new();
let mut msb_luts = Vec::new();
let dummy_msb = server_key.scalar_add_parallelized(&dummy_msb, 1);
let dummy_lsb = server_key.scalar_add_parallelized(&dummy_lsb, 1);
for lut_lsb in lut_lsbs.iter() {
lsb_luts.push(wopbs_key.generate_lut_radix(&dummy_lsb, |x: u64| eval_lut(x, lut_lsb)));
}
Expand Down
11 changes: 7 additions & 4 deletions src/lr_haar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use tfhe::{
RadixCiphertext,
},
shortint::parameters::{
parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, Degree,
PARAM_MESSAGE_2_CARRY_2_KS_PBS,
},
};
Expand Down Expand Up @@ -117,9 +117,14 @@ fn main() {
let dummy_2 = server_key.scalar_mul_parallelized(&dummy, 2_u64);
dummy = server_key.add_parallelized(&dummy_2, &dummy);
}
dummy = wopbs_key.keyswitch_to_wopbs_params(&server_key, &dummy);
let mut dummy_blocks = dummy.clone().into_blocks().to_vec();
for block in &mut dummy_blocks {
block.degree = Degree::new(3);
}
dummy = RadixCiphertext::from_blocks(dummy_blocks);
let dummy_blocks = &dummy.into_blocks()[(nb_blocks as usize)..((nb_blocks << 1) as usize)];
let dummy_msb = RadixCiphertext::from_blocks(dummy_blocks.to_vec());
let dummy_msb = server_key.scalar_add_parallelized(&dummy_msb, 1);
let exp_lut_lsb = wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| eval_exp(x, &lut_lsb));
println!(
"LUT generation done in {:?} sec.",
Expand All @@ -146,7 +151,6 @@ fn main() {
let prediction_blocks =
&prediction.into_blocks()[(nb_blocks as usize)..((nb_blocks << 1) as usize)];
let prediction_msb = RadixCiphertext::from_blocks(prediction_blocks.to_vec());
let prediction_msb = server_key.scalar_add_parallelized(&prediction_msb, 1);
// Keyswitch and Bootstrap
prediction = wopbs_key.keyswitch_to_wopbs_params(&server_key, &prediction_msb);
let activation_lsb = wopbs_key.wopbs(&prediction, &exp_lut_lsb);
Expand All @@ -172,7 +176,6 @@ fn main() {
let prediction_blocks =
&prediction.into_blocks()[(nb_blocks as usize)..((nb_blocks << 1) as usize)];
let prediction_msb = RadixCiphertext::from_blocks(prediction_blocks.to_vec());
let prediction_msb = server_key.scalar_add_parallelized(&prediction_msb, 1);
// Keyswitch and Bootstrap
prediction = wopbs_key.keyswitch_to_wopbs_params(&server_key, &prediction_msb);
let activation_lsb = wopbs_key.wopbs(&prediction, &exp_lut_lsb);
Expand Down
12 changes: 10 additions & 2 deletions src/lr_lut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ use clap::{App, Arg};
use rayon::prelude::*;
use ripple::common::*;
use tfhe::{
integer::{gen_keys_radix, wopbs::*, RadixCiphertext},
integer::{
gen_keys_radix, wopbs::*, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext,
},
shortint::parameters::{
parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, Degree,
PARAM_MESSAGE_2_CARRY_2_KS_PBS,
},
};
Expand Down Expand Up @@ -92,6 +94,12 @@ fn main() {
let dummy_2 = server_key.scalar_mul_parallelized(&dummy, 2_u64);
dummy = server_key.add_parallelized(&dummy_2, &dummy);
}
dummy = wopbs_key.keyswitch_to_wopbs_params(&server_key, &dummy);
let mut dummy_blocks = dummy.clone().into_blocks().to_vec();
for block in &mut dummy_blocks {
block.degree = Degree::new(3);
}
dummy = RadixCiphertext::from_blocks(dummy_blocks);
let sigmoid_lut = wopbs_key.generate_lut_radix(&dummy, |x: u64| {
sigmoid(x, 2 * precision, precision, bit_width)
});
Expand Down

0 comments on commit 0ecb809

Please sign in to comment.