Skip to content

Commit

Permalink
fix: change bitsizes and add prints
Browse files Browse the repository at this point in the history
  • Loading branch information
jimouris committed Mar 28, 2024
1 parent 7d593fe commit dd1f495
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 25 deletions.
13 changes: 8 additions & 5 deletions src/encrypted_lr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use tfhe::{

fn main() {
// ------- Client side ------- //
let bit_width = 32u8;
let bit_width = 16u8;
let precision = bit_width >> 2;
assert!(precision <= bit_width / 2);

Expand Down Expand Up @@ -58,18 +58,23 @@ fn main() {

// ------- Server side ------- //

// Build LUT for Sigmoid
// Build LUT for Sigmoid -- Offline cost
let start = Instant::now();
println!("Generating LUT.");
let sigmoid_lut = wopbs_key.generate_lut_radix(&encrypted_dataset[0][0], |x: u64| {
sigmoid(x, 2 * precision, precision, bit_width)
});
println!("Generated LUT in {:?} sec.", start.elapsed().as_secs_f64());

let encrypted_dataset_short = encrypted_dataset.get_mut(0..4).unwrap();
let encrypted_dataset_short = encrypted_dataset.get_mut(0..8).unwrap();

// Inference
let all_probabilities = encrypted_dataset_short
.par_iter_mut()
.enumerate()
.map(|(cnt, sample)| {
let start = Instant::now();
println!("Started inference #{:?}.", cnt);

let mut prediction = server_key.create_trivial_radix(bias_int, nb_blocks.into());
for (s, &weight) in sample.iter_mut().zip(weights_int.iter()) {
Expand All @@ -92,10 +97,8 @@ fn main() {
// ------- Client side ------- //
let mut total = 0;
for (num, (target, probability)) in targets.iter().zip(all_probabilities.iter()).enumerate() {
// let ptxt_probability = client_key.decrypt(probability);
let ptxt_probability: u64 = client_key.decrypt(probability);

println!("{:?}", ptxt_probability);
let class = (ptxt_probability > quantize(0.5, precision, bit_width)) as usize;
println!("[{}] predicted {:?}, target {:?}", num, class, target);
if class == *target {
Expand Down
55 changes: 35 additions & 20 deletions src/encrypted_lr_dwt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@ use rayon::prelude::*;
// use serde::{Deserialize, Serialize};
use tfhe::{
integer::{
gen_keys_radix, wopbs::*, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext,
// ciphertext::BaseRadixCiphertext,
gen_keys_radix,
wopbs::*,
IntegerCiphertext,
IntegerRadixCiphertext,
RadixCiphertext,
},
shortint::parameters::{
parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
Expand All @@ -19,7 +24,7 @@ fn eval_exp(x: u64, exp_map: &Vec<u64>) -> u64 {

fn main() {
// ------- Client side ------- //
let bit_width = 32u8;
let bit_width = 16u8;
let precision = bit_width >> 2;
assert!(precision <= bit_width / 2);

Expand Down Expand Up @@ -75,26 +80,41 @@ fn main() {

// ------- Server side ------- //

let encrypted_dataset_short = encrypted_dataset.get_mut(0..1).unwrap();
// let lut_gen_start = Instant::now();
// println!("Generating LUT.");
// let dummy_blocks =
// &encrypted_dataset[0][0].clone().into_blocks()[(nb_blocks as usize)..((nb_blocks << 1) as usize)];
// let dummy = RadixCiphertext::from_blocks(dummy_blocks.to_vec());
// let exp_lut_lsb =
// wopbs_key.generate_lut_radix(&dummy, |x: u64| eval_exp(x, &lut_lsb));
// let exp_lut_msb =
// wopbs_key.generate_lut_radix(&dummy, |x: u64| eval_exp(x, &lut_msb));
// println!(
// "LUT generation done in {:?} sec.",
// lut_gen_start.elapsed().as_secs_f64()
// );

let encrypted_dataset_short = encrypted_dataset.get_mut(0..8).unwrap();
let all_probabilities = encrypted_dataset_short
.iter_mut()
.par_iter_mut()
.enumerate()
.map(|(cnt, sample)| {
let start = Instant::now();
println!("Started inference #{:?}.", cnt);

let mut prediction = server_key.create_trivial_radix(bias_int, (nb_blocks << 1).into());
for (s, &weight) in sample.iter_mut().zip(weights_int.iter()) {
let ct_prod = server_key.smart_scalar_mul(s, weight);
prediction = server_key.unchecked_add(&ct_prod, &prediction);
}
// Truncate
let prediction_blocks = &prediction.clone().into_blocks()
[(nb_blocks as usize)..((nb_blocks << 1) as usize)];
let prediction_blocks =
&prediction.into_blocks()[(nb_blocks as usize)..((nb_blocks << 1) as usize)];
let prediction_msb = RadixCiphertext::from_blocks(prediction_blocks.to_vec());
// For some reason, the truncation is off by 1...
let prediction_msb = server_key.unchecked_scalar_add(&prediction_msb, 1);
// Keyswitch and Bootstrap
let lut_gen_start = Instant::now();
println!("Generating LUT.");
let exp_lut_lsb =
wopbs_key.generate_lut_radix(&prediction_msb, |x: u64| eval_exp(x, &lut_lsb));
let exp_lut_msb =
Expand All @@ -104,18 +124,14 @@ fn main() {
lut_gen_start.elapsed().as_secs_f64()
);
prediction = wopbs_key.keyswitch_to_wopbs_params(&server_key, &prediction_msb);
let (activation_lsb, activation_msb) = rayon::join(
|| {
let activation_lsb = wopbs_key.wopbs(&prediction, &exp_lut_lsb);
wopbs_key.keyswitch_to_pbs_params(&activation_lsb)
},
|| {
let activation_msb = wopbs_key.wopbs(&prediction, &exp_lut_msb);
wopbs_key.keyswitch_to_pbs_params(&activation_msb)
},
);
let mut lsb_blocks = activation_lsb.clone().into_blocks();
let msb_blocks = activation_msb.clone().into_blocks();
let activation_lsb = wopbs_key.wopbs(&prediction, &exp_lut_lsb);
let mut lsb_blocks = wopbs_key
.keyswitch_to_pbs_params(&activation_lsb)
.into_blocks();
let activation_msb = wopbs_key.wopbs(&prediction, &exp_lut_msb);
let msb_blocks = wopbs_key
.keyswitch_to_pbs_params(&activation_msb)
.into_blocks();
lsb_blocks.extend(msb_blocks);
let probability = RadixCiphertext::from_blocks(lsb_blocks);

Expand All @@ -132,7 +148,6 @@ fn main() {
let mut total = 0;
for (num, (target, probability)) in targets.iter().zip(all_probabilities.iter()).enumerate() {
let ptxt_probability: u64 = client_key.decrypt(probability);
println!("{:?}", ptxt_probability);
let class = (ptxt_probability > quantize(0.5, precision, bit_width)) as usize;
println!("[{}] predicted {:?}, target {:?}", num, class, target);
if class == *target {
Expand Down

0 comments on commit dd1f495

Please sign in to comment.