Skip to content

Commit

Permalink
Update encrypted variants
Browse files Browse the repository at this point in the history
  • Loading branch information
jimouris committed Mar 27, 2024
1 parent 315132d commit 0700524
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 95 deletions.
24 changes: 7 additions & 17 deletions src/encrypted_lr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use tfhe::{

fn main() {
// ------- Client side ------- //
let bit_width = 16u8;
let bit_width = 32u8;
let precision = bit_width >> 2;
assert!(precision <= bit_width / 2);

Expand All @@ -36,10 +36,10 @@ fn main() {

let (weights, bias) = load_weights_and_biases();
let (weights_int, bias_int) = quantize_weights_and_bias(&weights, bias, precision, bit_width);
let (iris_dataset, targets) = prepare_penguins_dataset();
let (dataset, targets) = prepare_penguins_dataset();

let start = Instant::now();
let mut encrypted_dataset: Vec<Vec<_>> = iris_dataset
let mut encrypted_dataset: Vec<Vec<_>> = dataset
.par_iter() // Use par_iter() for parallel iteration
.map(|sample| {
sample
Expand Down Expand Up @@ -73,24 +73,13 @@ fn main() {

let mut prediction = server_key.create_trivial_radix(bias_int, nb_blocks.into());
for (s, &weight) in sample.iter_mut().zip(weights_int.iter()) {
let mut d: u64 = client_key.decrypt(s);
println!("s: {:?}", d);
println!("weight: {:?}", weight);
let ct_prod = server_key.smart_scalar_mul(s, weight);
d = client_key.decrypt(&ct_prod);
println!("mul result: {:?}", d);
prediction = server_key.unchecked_add(&ct_prod, &prediction);
// FIXME: DEBUG
d = client_key.decrypt(&prediction);
println!("MAC result: {:?}", d);
println!();
}
println!();
prediction = wopbs_key.keyswitch_to_wopbs_params(&server_key, &prediction);
let activation = wopbs_key.wopbs(&prediction, &sigmoid_lut);
let probability = wopbs_key.keyswitch_to_pbs_params(&activation);
let d: u64 = client_key.decrypt(&probability);
println!("Sigmoid result: {:?}", d);

println!(
"Finished inference #{:?} in {:?} sec.",
cnt,
Expand All @@ -99,12 +88,13 @@ fn main() {
probability
})
.collect::<Vec<_>>();
// }

// ------- Client side ------- //
let mut total = 0;
for (num, (target, probability)) in targets.iter().zip(all_probabilities.iter()).enumerate() {
let ptxt_probability = client_key.decrypt(probability);
// let ptxt_probability = client_key.decrypt(probability);
let ptxt_probability: u64 = client_key.decrypt(probability);

println!("{:?}", ptxt_probability);
let class = (ptxt_probability > quantize(0.5, precision, bit_width)) as usize;
println!("[{}] predicted {:?}, target {:?}", num, class, target);
Expand Down
129 changes: 51 additions & 78 deletions src/encrypted_lr_dwt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::time::Instant;

use fhe_lut::common::*;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
// use serde::{Deserialize, Serialize};
use tfhe::{
integer::{
gen_keys_radix, wopbs::*, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext,
Expand All @@ -13,25 +13,13 @@ use tfhe::{
},
};

#[derive(Debug, Serialize, Deserialize)]
struct KeyValue {
key: u64,
value: u64,
}

fn eval_exp(x: u64, exp_map: &Vec<u64>) -> u64 {
exp_map[x as usize]
}

fn main() {
// let reader = BufReader::new(File::open("lut16_quantized_lsb.json").unwrap());
// let lut_lsb: HashMap<u64, u64> = serde_json::from_reader(reader).unwrap();

// let reader = BufReader::new(File::open("lut16_quantized_msb.json").unwrap());
// let lut_msb: HashMap<u64, u64> = serde_json::from_reader(reader).unwrap();

// ------- Client side ------- //
let bit_width = 16u8;
let bit_width = 32u8;
let precision = bit_width >> 2;
assert!(precision <= bit_width / 2);

Expand All @@ -56,22 +44,18 @@ fn main() {
start.elapsed().as_secs_f64()
);

let (weights, biases) = load_weights_and_biases();
let (weights_int, bias_int) =
quantize_weights_and_biases(&weights, &biases, precision, bit_width);
let (iris_dataset, targets) = prepare_iris_dataset();
let num_features = iris_dataset[0].len();
let (means, stds) = means_and_stds(&iris_dataset, num_features);
let (weights, bias) = load_weights_and_biases();
let (weights_int, bias_int) = quantize_weights_and_bias(&weights, bias, precision, bit_width);
let (dataset, targets) = prepare_penguins_dataset();

let start = Instant::now();
let mut encrypted_dataset: Vec<Vec<_>> = iris_dataset
let mut encrypted_dataset: Vec<Vec<_>> = dataset
.par_iter() // Use par_iter() for parallel iteration
.map(|sample| {
sample
.par_iter()
.zip(means.par_iter().zip(stds.par_iter()))
.map(|(&s, (mean, std))| {
let quantized = quantize((s - mean) / std, precision, bit_width);
.map(|&s| {
let quantized = quantize(s, precision, bit_width);
let mut lsb = client_key
.encrypt(quantized & (1 << ((nb_blocks << 1) - 1)))
.into_blocks(); // Get LSBs
Expand All @@ -97,70 +81,59 @@ fn main() {
.enumerate()
.map(|(cnt, sample)| {
let start = Instant::now();
let probabilities = weights_int
.iter()
.zip(bias_int.iter())
// .par_iter()
// .zip(bias_int.par_iter())
.map(|(model, &bias)| {
let mut prediction =
server_key.create_trivial_radix(bias, (nb_blocks << 1).into());
for (s, &weight) in sample.iter_mut().zip(model.iter()) {
let ct_prod = server_key.smart_scalar_mul(s, weight);
prediction = server_key.unchecked_add(&ct_prod, &prediction);
}
// Truncate
let prediction_blocks = &prediction.clone().into_blocks()
[(nb_blocks as usize)..((nb_blocks << 1) as usize)];
let prediction_msb = RadixCiphertext::from_blocks(prediction_blocks.to_vec());
// For some reason, the truncation is off by 1...
let prediction_msb = server_key.unchecked_scalar_add(&prediction_msb, 1);
// Keyswitch and Bootstrap
let lut_gen_start = Instant::now();
let exp_lut_lsb = wopbs_key
.generate_lut_radix(&prediction_msb, |x: u64| eval_exp(x, &lut_lsb));
let exp_lut_msb = wopbs_key
.generate_lut_radix(&prediction_msb, |x: u64| eval_exp(x, &lut_msb));
println!(
"LUT generation done in {:?} sec.",
lut_gen_start.elapsed().as_secs_f64()
);
prediction = wopbs_key.keyswitch_to_wopbs_params(&server_key, &prediction_msb);
let (activation_lsb, activation_msb) = rayon::join(
|| {
let activation_lsb = wopbs_key.wopbs(&prediction, &exp_lut_lsb);
wopbs_key.keyswitch_to_pbs_params(&activation_lsb)
},
|| {
let activation_msb = wopbs_key.wopbs(&prediction, &exp_lut_msb);
wopbs_key.keyswitch_to_pbs_params(&activation_msb)
},
);
let mut lsb_blocks = activation_lsb.clone().into_blocks();
let msb_blocks = activation_msb.clone().into_blocks();
lsb_blocks.extend(msb_blocks);
RadixCiphertext::from_blocks(lsb_blocks)
})
.collect::<Vec<_>>();

let mut prediction = server_key.create_trivial_radix(bias_int, (nb_blocks << 1).into());
for (s, &weight) in sample.iter_mut().zip(weights_int.iter()) {
let ct_prod = server_key.smart_scalar_mul(s, weight);
prediction = server_key.unchecked_add(&ct_prod, &prediction);
}
// Truncate
let prediction_blocks = &prediction.clone().into_blocks()
[(nb_blocks as usize)..((nb_blocks << 1) as usize)];
let prediction_msb = RadixCiphertext::from_blocks(prediction_blocks.to_vec());
// For some reason, the truncation is off by 1...
let prediction_msb = server_key.unchecked_scalar_add(&prediction_msb, 1);
// Keyswitch and Bootstrap
let lut_gen_start = Instant::now();
let exp_lut_lsb =
wopbs_key.generate_lut_radix(&prediction_msb, |x: u64| eval_exp(x, &lut_lsb));
let exp_lut_msb =
wopbs_key.generate_lut_radix(&prediction_msb, |x: u64| eval_exp(x, &lut_msb));
println!(
"LUT generation done in {:?} sec.",
lut_gen_start.elapsed().as_secs_f64()
);
prediction = wopbs_key.keyswitch_to_wopbs_params(&server_key, &prediction_msb);
let (activation_lsb, activation_msb) = rayon::join(
|| {
let activation_lsb = wopbs_key.wopbs(&prediction, &exp_lut_lsb);
wopbs_key.keyswitch_to_pbs_params(&activation_lsb)
},
|| {
let activation_msb = wopbs_key.wopbs(&prediction, &exp_lut_msb);
wopbs_key.keyswitch_to_pbs_params(&activation_msb)
},
);
let mut lsb_blocks = activation_lsb.clone().into_blocks();
let msb_blocks = activation_msb.clone().into_blocks();
lsb_blocks.extend(msb_blocks);
let probability = RadixCiphertext::from_blocks(lsb_blocks);

println!(
"Finished inference #{:?} in {:?} sec.",
cnt,
start.elapsed().as_secs_f64()
);
probabilities
probability
})
.collect::<Vec<_>>();
// }

// ------- Client side ------- //
let mut total = 0;
for (num, (target, probabilities)) in targets.iter().zip(all_probabilities.iter()).enumerate() {
let ptxt_probabilities = probabilities
.iter()
.map(|p| client_key.decrypt(p))
.collect::<Vec<u64>>();
println!("{:?}", ptxt_probabilities);
let class = argmax(&ptxt_probabilities).unwrap();
for (num, (target, probability)) in targets.iter().zip(all_probabilities.iter()).enumerate() {
let ptxt_probability: u64 = client_key.decrypt(probability);
println!("{:?}", ptxt_probability);
let class = (ptxt_probability > quantize(0.5, precision, bit_width)) as usize;
println!("[{}] predicted {:?}, target {:?}", num, class, target);
if class == *target {
total += 1;
Expand Down

0 comments on commit 0700524

Please sign in to comment.