From 2da7696e6635e671b4def119e828e5060064f56b Mon Sep 17 00:00:00 2001 From: Dimitris Mouris Date: Wed, 28 Feb 2024 14:41:47 -0500 Subject: [PATCH] Update clear and encrypted LR Co-authored-by: Charles Gouert Co-authored-by: Memo --- src/common.rs | 153 +++++++++++++++++++++++++++ src/encrypted_lr.rs | 253 +++++++++++++++++--------------------------- src/float_lr.rs | 67 +----------- src/lib.rs | 1 + src/plain_lr.rs | 163 ++++------------------------ 5 files changed, 274 insertions(+), 363 deletions(-) create mode 100644 src/common.rs create mode 100644 src/lib.rs diff --git a/src/common.rs b/src/common.rs new file mode 100644 index 0000000..4400fdf --- /dev/null +++ b/src/common.rs @@ -0,0 +1,153 @@ +use std::fs::File; + +use rayon::prelude::*; + +pub fn to_signed(x: u64, bit_width: u8) -> i64 { + if x > (1u64 << (bit_width - 1)) { + (x as i128 - (1i128 << bit_width)) as i64 + } else { + x as i64 + } +} + +pub fn from_signed(x: i64, bit_width: u8) -> u64 { + (x as i128).rem_euclid(1i128 << bit_width) as u64 +} + +pub fn quantize(x: f64, precision: u8, bit_width: u8) -> u64 { + from_signed((x * ((1u128 << precision) as f64)) as i64, bit_width) +} + +pub fn quantize_encypted(x: f64, precision: u8) -> u64 { + let mut tmp = (x * ((1 << precision) as f64)) as i32; + tmp += 1 << (precision - 1); + tmp as u64 +} + +pub fn unquantize(x: u64, precision: u8, bit_width: u8) -> f64 { + to_signed(x, bit_width) as f64 / ((1u128 << precision) as f64) +} + +pub fn add(a: u64, b: u64, bit_width: u8) -> u64 { + (a as u128 + b as u128).rem_euclid(1u128 << bit_width) as u64 +} + +pub fn mul(a: u64, b: u64, bit_width: u8) -> u64 { + (a as u128 * b as u128).rem_euclid(1u128 << bit_width) as u64 +} + +pub fn truncate(x: u64, precision: u8, bit_width: u8) -> u64 { + from_signed(to_signed(x, bit_width) / (1i64 << precision), bit_width) +} + +pub fn exponential(x: u64, input_precision: u8, output_precision: u8, bit_width: u8) -> u64 { + let x = to_signed(x, bit_width) as f64; + let shift = (1u128 << input_precision) as f64; + let exp = (x / shift).exp(); + let ret = (exp * ((1u128 << output_precision) as f64)) as u64; + // println!("\t exp {x:?} --> {:?}", &ret); + ret + // ((1.0 / (1.0 + exp)) * (1 << output_precision) as f64) as u64 +} + +pub fn argmax(slice: &[T]) -> Option { + slice + .iter() + .enumerate() + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) + .map(|(index, _)| index) +} + +pub fn load_weights_and_biases() -> (Vec>, Vec) { + let weights_csv = File::open("iris_weights.csv").unwrap(); + let mut reader = csv::Reader::from_reader(weights_csv); + let mut weights = vec![]; + let mut biases = vec![]; + + for result in reader.deserialize() { + let res: Vec = result.expect("a CSV record"); + biases.push(res[0]); + weights.push(res[1..].to_vec()); + } + + (weights, biases) +} + +pub fn quantize_weights_and_biases( + weights: &[Vec], + biases: &[f64], + precision: u8, + bit_width: u8, +) -> (Vec>, Vec) { + let weights_int = weights + .iter() + .map(|row| { + row.iter() + .map(|&w| quantize(w.into(), precision, bit_width)) + .collect::>() + }) + .collect::>(); + let bias_int = biases + .iter() + .map(|&w| quantize(w.into(), precision, bit_width)) + .collect::>(); + + (weights_int, bias_int) +} + +pub fn prepare_iris_dataset() -> (Vec>, Vec) { + let iris = linfa_datasets::iris(); + let mut iris_dataset = vec![]; + let mut targets = vec![]; + + for (sample, target) in iris.sample_iter() { + iris_dataset.push(sample.to_vec()); + targets.push(*target.first().unwrap()); + } + + (iris_dataset, targets) +} + +pub fn means_and_stds(dataset: &[Vec], num_features: usize) -> (Vec, Vec) { + let mut means = vec![0f64; num_features]; + let mut stds = vec![0f64; num_features]; + + for sample in dataset.iter() { + for (feature, s) in sample.iter().enumerate() { + means[feature] += s; + } + } + for mean in means.iter_mut() { + *mean /= dataset.len() as f64; + } + for sample in dataset.iter() { + for (feature, s) in sample.iter().enumerate() { + let dev = s - means[feature]; + stds[feature] += dev * dev; + } + } + for std in stds.iter_mut() { + *std = (*std / dataset.len() as f64).sqrt(); + } + + (means, stds) +} + +pub fn quantize_dataset( + dataset: &Vec>, + means: &Vec, + stds: &Vec, + precision: u8, + bit_width: u8, +) -> Vec> { + dataset + .par_iter() // Use par_iter() for parallel iteration + .map(|sample| { + sample + .par_iter() + .zip(means.par_iter().zip(stds.par_iter())) + .map(|(&s, (mean, std))| quantize((s - mean) / std, precision, bit_width)) + .collect() + }) + .collect() +} diff --git a/src/encrypted_lr.rs b/src/encrypted_lr.rs index 415bf4d..aa46dfd 100644 --- a/src/encrypted_lr.rs +++ b/src/encrypted_lr.rs @@ -1,195 +1,136 @@ -use std::fs::File; +use std::time::Instant; +use fhe_lut::common::*; +use rayon::prelude::*; use tfhe::{ - integer::{gen_keys_radix, wopbs::*, RadixCiphertext}, + integer::{gen_keys_radix, wopbs::*}, shortint::parameters::{ parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, PARAM_MESSAGE_2_CARRY_2_KS_PBS, }, }; -fn quantize(x: f32, precision: u8) -> u64 { - let mut tmp = (x * ((1 << precision) as f32)) as i32; - tmp += 1 << (precision - 1); - tmp as u64 -} - -fn sigmoid(x: u64) -> u64 { - let x_f32 = x as f32; - let exp = (-x_f32 / ((1 << 16) as f32)).exp(); - ((1.0 / (1.0 + exp)) * (1 << 8) as f32) as u64 -} - -fn argmax(slice: &[T]) -> Option { - slice - .iter() - .enumerate() - .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) - .map(|(index, _)| index) -} - -fn load_weights_and_biases() -> (Vec>, Vec) { - let weights_csv = File::open("iris_weights.csv").unwrap(); - let mut reader = csv::Reader::from_reader(weights_csv); - let mut weights = vec![]; - let mut biases = vec![]; - - for result in reader.deserialize() { - let res: Vec = result.expect("a CSV record"); - biases.push(res[0]); - weights.push(res[1..].to_vec()); - } - - (weights, biases) -} - -fn quantize_weights_and_biases( - weights: &[Vec], - biases: &[f32], - precision: u8, -) -> (Vec>, Vec) { - let weights_int = weights - .iter() - .map(|row| { - row.iter() - .map(|&w| quantize(w, precision)) - .collect::>() - }) - .collect::>(); - let bias_int = biases - .iter() - .map(|&w| quantize(w, precision)) - .collect::>(); - - (weights_int, bias_int) -} - -fn prepare_iris_dataset() -> Vec<(Vec, usize)> { - let iris = linfa_datasets::iris(); - let mut iris_dataset = vec![]; - - for (sample, target) in iris.sample_iter() { - iris_dataset.push((sample.to_vec(), *target.first().unwrap())); - } - - iris_dataset -} - -fn means_and_stds(dataset: &[(Vec, usize)], num_features: usize) -> (Vec, Vec) { - let mut means = vec![0f64; num_features]; - let mut stds = vec![0f64; num_features]; - - for (sample, _) in dataset.iter() { - for (feature, s) in sample.iter().enumerate() { - means[feature] += s; - } - } - for mean in means.iter_mut() { - *mean /= dataset.len() as f64; - } - for (sample, _) in dataset.iter() { - for (feature, s) in sample.iter().enumerate() { - let dev = s - means[feature]; - stds[feature] += dev * dev; - } - } - for std in stds.iter_mut() { - *std = (*std / dataset.len() as f64).sqrt(); - } - - (means, stds) -} - fn main() { // ------- Client side ------- // - let precision = 8; + let bit_width = 16u8; + let precision = bit_width >> 2; + assert!(precision <= bit_width / 2); + // Number of blocks per ciphertext - let nb_blocks = 4; + let nb_blocks = bit_width / 2; + let start = Instant::now(); // Generate radix keys - let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks); + let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks.into()); // Generate key for PBS (without padding) - let wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS); - - // Get message modulus (i.e. max value representable by radix ctxt) - let mut modulus = 1_u64; - for _ in 0..nb_blocks { - modulus *= cks.parameters().message_modulus().0 as u64; - } - println!("Ptxt Modulus: {:?}", modulus); + let wopbs_key = WopbsKey::new_wopbs_key( + &client_key, + &server_key, + &WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, + ); + println!( + "Key generation done in {:?} sec.", + start.elapsed().as_secs_f64() + ); let (weights, biases) = load_weights_and_biases(); - let (weights_int, bias_int) = quantize_weights_and_biases(&weights, &biases, precision); - let iris_dataset = prepare_iris_dataset(); - let num_features = iris_dataset[0].0.len(); + let (weights_int, bias_int) = + quantize_weights_and_biases(&weights, &biases, precision, bit_width); + let (iris_dataset, targets) = prepare_iris_dataset(); + let num_features = iris_dataset[0].len(); let (means, stds) = means_and_stds(&iris_dataset, num_features); - // TODO(@jimouris): par_iter and map - let mut encrypted_dataset = vec![]; - for (sample, _) in iris_dataset.iter() { - let mut input = vec![]; - for (&s, (mean, std)) in sample.iter().zip(means.iter().zip(stds.iter())) { - let n = (s - mean) / std; - let quantized = quantize(n as f32, precision); - input.push(cks.encrypt(quantized)); - } - encrypted_dataset.push(input); - } + let start = Instant::now(); + let mut encrypted_dataset: Vec> = iris_dataset + .par_iter() // Use par_iter() for parallel iteration + .map(|sample| { + sample + .par_iter() + .zip(means.par_iter().zip(stds.par_iter())) + .map(|(&s, (mean, std))| { + let quantized = quantize((s - mean) / std, precision, bit_width); + client_key.encrypt(quantized) + }) + .collect() + }) + .collect(); + println!( + "Encryption done in {:?} sec.", + start.elapsed().as_secs_f64() + ); // ------- Server side ------- // // Build LUT for Sigmoid - let sigmoid_lut = wopbs_key.generate_lut_radix(&encrypted_dataset[0][0], |x: u64| sigmoid(x)); + let sigmoid_lut = wopbs_key.generate_lut_radix(&encrypted_dataset[0][0], |x: u64| { + exponential(x, 2 * precision, precision, bit_width) + }); - let mut all_probabilities = vec![]; - let mut cnt = 0; - for sample in encrypted_dataset.iter() { - let mut probabilities = vec![]; - for (model, bias) in weights_int.iter().zip(bias_int.iter()) { - let mut prediction: RadixCiphertext = sks.create_trivial_radix(*bias, nb_blocks); - for ((encrypted_value, weight), (_, _)) in sample + let encrypted_dataset_short = encrypted_dataset.get_mut(0..4).unwrap(); + + let all_probabilities = encrypted_dataset_short + .par_iter_mut() + .enumerate() + .map(|(cnt, sample)| { + let start = Instant::now(); + let probabilities = weights_int .iter() - .zip(model.iter()) - .zip(means.iter().zip(stds.iter())) - { - // prediction += weight * encrypted_value; - let ct_prod = sks.unchecked_small_scalar_mul(encrypted_value, *weight); - prediction = sks.unchecked_add(&ct_prod, &prediction); - } - prediction = wopbs_key.keyswitch_to_wopbs_params(&sks, &prediction); - let activation = wopbs_key.wopbs(&prediction, &sigmoid_lut); - prediction = wopbs_key.keyswitch_to_pbs_params(&activation); - probabilities.push(prediction); - } - println!("Finished inference #{:?}", cnt); - all_probabilities.push(probabilities); - cnt += 1; - if cnt == 2 { - break; - } - } + .zip(bias_int.iter()) + // .par_iter() + // .zip(bias_int.par_iter()) + .map(|(model, &bias)| { + let scaled_bias = mul(1 << precision, bias, bit_width); + let mut prediction = + server_key.create_trivial_radix(scaled_bias, nb_blocks.into()); + for (s, &weight) in sample.iter_mut().zip(model.iter()) { + let mut d: u64 = client_key.decrypt(&s); + println!("s: {:?}", d); + println!("weight: {:?}", weight); + let ct_prod = server_key.smart_scalar_mul(s, weight); + d = client_key.decrypt(&ct_prod); + println!("mul result: {:?}", d); + prediction = server_key.unchecked_add(&ct_prod, &prediction); + // FIXME: DEBUG + d = client_key.decrypt(&prediction); + println!("MAC result: {:?}", d); + println!(); + } + println!(); + prediction = wopbs_key.keyswitch_to_wopbs_params(&server_key, &prediction); + let activation = wopbs_key.wopbs(&prediction, &sigmoid_lut); + + let probability = wopbs_key.keyswitch_to_pbs_params(&activation); + let d: u64 = client_key.decrypt(&probability); + println!("Exponential result: {:?}", d); + + probability + }) + .collect::>(); + println!( + "Finished inference #{:?} in {:?} sec.", + cnt, + start.elapsed().as_secs_f64() + ); + probabilities + }) + .collect::>(); + // } // ------- Client side ------- // let mut total = 0; - for (num, ((_, target), probabilities)) in iris_dataset - .iter() - .zip(all_probabilities.iter()) - .enumerate() - { + for (num, (target, probabilities)) in targets.iter().zip(all_probabilities.iter()).enumerate() { let ptxt_probabilities = probabilities .iter() - .map(|p| cks.decrypt(p)) + .map(|p| client_key.decrypt(p)) .collect::>(); + println!("{:?}", ptxt_probabilities); let class = argmax(&ptxt_probabilities).unwrap(); println!("[{}] predicted {:?}, target {:?}", num, class, target); if class == *target { total += 1; } - if num == 2 { - break; - } } - let accuracy = (total as f32 / iris_dataset.len() as f32) * 100.0; + let accuracy = (total as f32 / encrypted_dataset_short.len() as f32) * 100.0; println!("Accuracy {accuracy}%"); } diff --git a/src/float_lr.rs b/src/float_lr.rs index 50b3802..7e7ea1e 100644 --- a/src/float_lr.rs +++ b/src/float_lr.rs @@ -1,73 +1,14 @@ -use std::fs::File; - -fn argmax(slice: &[T]) -> Option { - slice - .iter() - .enumerate() - .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) - .map(|(index, _)| index) -} - -fn load_weights_and_biases() -> (Vec>, Vec) { - let weights_csv = File::open("iris_weights.csv").unwrap(); - let mut reader = csv::Reader::from_reader(weights_csv); - let mut weights = vec![]; - let mut biases = vec![]; - - for result in reader.deserialize() { - let res: Vec = result.expect("a CSV record"); - biases.push(res[0]); - weights.push(res[1..].to_vec()); - } - - (weights, biases) -} - -fn prepare_iris_dataset() -> Vec<(Vec, usize)> { - let iris = linfa_datasets::iris(); - let mut iris_dataset = vec![]; - - for (sample, target) in iris.sample_iter() { - iris_dataset.push((sample.to_vec(), *target.first().unwrap())); - } - - iris_dataset -} - -fn means_and_stds(dataset: &[(Vec, usize)], num_features: usize) -> (Vec, Vec) { - let mut means = vec![0f64; num_features]; - let mut stds = vec![0f64; num_features]; - - for (sample, _) in dataset.iter() { - for (feature, s) in sample.iter().enumerate() { - means[feature] += s; - } - } - for mean in means.iter_mut() { - *mean /= dataset.len() as f64; - } - for (sample, _) in dataset.iter() { - for (feature, s) in sample.iter().enumerate() { - let dev = s - means[feature]; - stds[feature] += dev * dev; - } - } - for std in stds.iter_mut() { - *std = (*std / dataset.len() as f64).sqrt(); - } - - (means, stds) -} +use fhe_lut::common::*; fn main() { let (weights, biases) = load_weights_and_biases(); - let iris_dataset = prepare_iris_dataset(); - let num_features = iris_dataset[0].0.len(); + let (iris_dataset, targets) = prepare_iris_dataset(); + let num_features = iris_dataset[0].len(); let (means, stds) = means_and_stds(&iris_dataset, num_features); let mut total = 0; - for (num, (sample, target)) in iris_dataset.iter().enumerate() { + for (num, (sample, target)) in (iris_dataset.iter().zip(targets.iter())).enumerate() { let mut probabilities = vec![]; let mut sum_p = 0f64; for (model, bias) in weights.iter().zip(biases.iter()) { diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..34994bf --- /dev/null +++ b/src/lib.rs @@ -0,0 +1 @@ +pub mod common; diff --git a/src/plain_lr.rs b/src/plain_lr.rs index b72f323..9be1952 100644 --- a/src/plain_lr.rs +++ b/src/plain_lr.rs @@ -1,169 +1,43 @@ -use std::fs::File; - -use debug_print::debug_println; +// use debug_print::debug_println; +use fhe_lut::common::*; use rayon::prelude::*; -fn to_signed(x: u64) -> i64 { - if x > (1u64 << 63) { - (x as i128 - (1i128 << 64)) as i64 - } else { - x as i64 - } -} - -fn from_signed(x: i64) -> u64 { - (x as i128).rem_euclid(1i128 << 64) as u64 -} - -fn quantize(x: f64, precision: u8) -> u64 { - from_signed((x * ((1u128 << precision) as f64)) as i64) -} - -fn unquantize(x: u64, precision: u8) -> f64 { - to_signed(x) as f64 / ((1u128 << precision) as f64) -} - -fn add(a: u64, b: u64) -> u64 { - (a as u128 + b as u128).rem_euclid(1u128 << 64) as u64 -} - -fn mul(a: u64, b: u64) -> u64 { - (a as u128 * b as u128).rem_euclid(1u128 << 64) as u64 -} - -fn truncate(x: u64, precision: u8) -> u64 { - from_signed(to_signed(x) / (1i64 << precision)) -} - -fn sigmoid(x: u64, input_precision: u8, output_precision: u8) -> u64 { - let x = to_signed(x) as f64; - let shift = (1u128 << input_precision) as f64; - let exp = (x / shift).exp(); - (exp * ((1u128 << output_precision) as f64)) as u64 - // ((1.0 / (1.0 + exp)) * (1 << output_precision) as f64) as u64 -} - -fn argmax(slice: &[T]) -> Option { - slice - .iter() - .enumerate() - .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) - .map(|(index, _)| index) -} - -fn load_weights_and_biases() -> (Vec>, Vec) { - let weights_csv = File::open("iris_weights.csv").unwrap(); - let mut reader = csv::Reader::from_reader(weights_csv); - let mut weights = vec![]; - let mut biases = vec![]; - - for result in reader.deserialize() { - let res: Vec = result.expect("a CSV record"); - biases.push(res[0]); - weights.push(res[1..].to_vec()); - } - - (weights, biases) -} - -fn quantize_weights_and_biases( - weights: &[Vec], - biases: &[f64], - precision: u8, -) -> (Vec>, Vec) { - let weights_int = weights - .iter() - .map(|row| { - row.iter() - .map(|&w| quantize(w.into(), precision)) - .collect::>() - }) - .collect::>(); - let bias_int = biases - .iter() - .map(|&w| quantize(w.into(), precision)) - .collect::>(); - - (weights_int, bias_int) -} - -fn prepare_iris_dataset() -> (Vec>, Vec) { - let iris = linfa_datasets::iris(); - let mut iris_dataset = vec![]; - let mut targets = vec![]; - - for (sample, target) in iris.sample_iter() { - iris_dataset.push(sample.to_vec()); - targets.push(*target.first().unwrap()); - } - - (iris_dataset, targets) -} - -fn means_and_stds(dataset: &[Vec], num_features: usize) -> (Vec, Vec) { - let mut means = vec![0f64; num_features]; - let mut stds = vec![0f64; num_features]; - - for sample in dataset.iter() { - for (feature, s) in sample.iter().enumerate() { - means[feature] += s; - } - } - for mean in means.iter_mut() { - *mean /= dataset.len() as f64; - } - for sample in dataset.iter() { - for (feature, s) in sample.iter().enumerate() { - let dev = s - means[feature]; - stds[feature] += dev * dev; - } - } - for std in stds.iter_mut() { - *std = (*std / dataset.len() as f64).sqrt(); - } - - (means, stds) -} - fn main() { - let precision = 6; + let bit_width = 8u8; + let precision = bit_width >> 2; + let (weights, biases) = load_weights_and_biases(); - let (weights_int, bias_int) = quantize_weights_and_biases(&weights, &biases, precision); + let (weights_int, bias_int) = + quantize_weights_and_biases(&weights, &biases, precision, bit_width); let (iris_dataset, targets) = prepare_iris_dataset(); let num_features = iris_dataset[0].len(); let (means, stds) = means_and_stds(&iris_dataset, num_features); - let quantized_dataset: Vec> = iris_dataset - .par_iter() // Use par_iter() for parallel iteration - .map(|sample| { - sample - .par_iter() - .zip(means.par_iter().zip(stds.par_iter())) - .map(|(&s, (mean, std))| quantize((s - mean) / std, precision)) - .collect() - }) - .collect(); + let quantized_dataset = quantize_dataset(&iris_dataset, &means, &stds, precision, bit_width); let mut total = 0; for (target, sample) in targets.iter().zip(quantized_dataset.iter()) { // Server computation - let probabilities: Vec<_> = weights_int + let probabilities = weights_int .par_iter() .zip(bias_int.par_iter()) .map(|(model, &bias)| { - let mut prediction = bias; + let mut prediction = (1 << precision) * bias; for (&s, &w) in sample.iter().zip(model.iter()) { - let cur = truncate(mul(w, s), precision); - prediction = add(prediction, cur); + println!("s: {:?}", s); + println!("weight: {:?}", w); + prediction = add(prediction, mul(w, s, bit_width), bit_width); + println!("MAC result: {:?}", prediction); } - sigmoid(prediction, precision, precision) + println!(); + exponential(prediction, 2 * precision, precision, bit_width) }) - .collect(); + .collect::>(); // Client computation let class = argmax(&probabilities).unwrap(); - debug_println!("predicted {class:?}, target {target:?}"); + println!("predicted {class:?}, target {target:?}"); if class == *target { total += 1; } @@ -171,4 +45,5 @@ fn main() { let accuracy = (total as f64 / iris_dataset.len() as f64) * 100.0; println!("Accuracy {accuracy}%"); + println!("precision: {precision}, bit_width: {bit_width}"); }