From 5cf803d40b61650c3b246300db5bd60141aad7a7 Mon Sep 17 00:00:00 2001 From: cgouert Date: Sat, 25 May 2024 18:15:57 -0400 Subject: [PATCH] Add all variants for LR and Euclidean distance --- Cargo.toml | 16 ++++ data/bior_lut_div_16.json | 1 + data/bior_lut_div_16_2.json | 1 + src/common.rs | 2 +- src/euclidean.rs | 69 ++++++++------ src/euclidean_bior.rs | 174 ++++++++++++++++++++++++++++++++++ src/euclidean_haar.rs | 176 ++++++++++++++++++++++++++++++++++ src/euclidean_quantized.rs | 140 +++++++++++++++++++++++++++ src/lr_lut.rs | 9 +- src/lr_lut_quantized.rs | 184 ++++++++++++++++++++++++++++++++++++ src/lr_ptxt.rs | 5 - 11 files changed, 738 insertions(+), 39 deletions(-) create mode 100644 data/bior_lut_div_16.json create mode 100644 data/bior_lut_div_16_2.json create mode 100644 src/euclidean_bior.rs create mode 100644 src/euclidean_haar.rs create mode 100644 src/euclidean_quantized.rs create mode 100644 src/lr_lut_quantized.rs diff --git a/Cargo.toml b/Cargo.toml index dc25740..b18a9a9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,10 @@ path = "src/lr_ptxt_float.rs" name = "lr_lut" path = "src/lr_lut.rs" +[[bin]] +name = "lr_lut_quantized" +path = "src/lr_lut_quantized.rs" + [[bin]] name = "lr_haar_ptxt" path = "src/lr_haar_ptxt.rs" @@ -79,6 +83,18 @@ path = "src/correlation_lut.rs" name = "euclidean" path = "src/euclidean.rs" +[[bin]] +name = "euclidean_quantized" +path = "src/euclidean_quantized.rs" + +[[bin]] +name = "euclidean_haar" +path = "src/euclidean_haar.rs" + +[[bin]] +name = "euclidean_bior" +path = "src/euclidean_bior.rs" + # Primitive Timings [[bin]] diff --git a/data/bior_lut_div_16.json b/data/bior_lut_div_16.json new file mode 100644 index 0000000..801ddcf --- /dev/null +++ b/data/bior_lut_div_16.json @@ -0,0 +1 @@ +{"0": 0, "1": 5, "2": 10, "3": 16, "4": 21, "5": 26, "6": 32, "7": 37, "8": 42, "9": 48, "10": 53, "11": 58, "12": 64, "13": 69, "14": 74, "15": 80, "16": 85, "17": 90, "18": 96, "19": 101, "20": 106, "21": 112, "22": 117, "23": 122, "24": 128, "25": 133, "26": 138, "27": 144, "28": 149, "29": 154, "30": 160, "31": 165, "32": 170, "33": 176, "34": 181, "35": 186, "36": 192, "37": 197, "38": 202, "39": 208, "40": 213, "41": 218, "42": 224, "43": 229, "44": 234, "45": 240, "46": 245, "47": 250, "48": 0, "49": 5, "50": 10, "51": 16, "52": 21, "53": 26, "54": 32, "55": 37, "56": 42, "57": 48, "58": 53, "59": 58, "60": 64, "61": 69, "62": 74, "63": 80, "64": 85, "65": 90, "66": 96, "67": 101, "68": 106, "69": 112, "70": 117, "71": 122, "72": 128, "73": 133, "74": 138, "75": 144, "76": 149, "77": 154, "78": 160, "79": 165, "80": 170, "81": 176, "82": 181, "83": 186, "84": 192, "85": 197, "86": 202, "87": 208, "88": 213, "89": 218, "90": 224, "91": 229, "92": 234, "93": 240, "94": 245, "95": 250, "96": 0, "97": 5, "98": 10, "99": 16, "100": 21, "101": 26, "102": 32, "103": 37, "104": 42, "105": 48, "106": 53, "107": 58, "108": 64, "109": 69, "110": 74, "111": 80, "112": 85, "113": 90, "114": 96, "115": 101, "116": 106, "117": 112, "118": 117, "119": 122, "120": 128, "121": 133, "122": 138, "123": 144, "124": 149, "125": 154, "126": 160, "127": 165, "128": 170, "129": 176, "130": 181, "131": 186, "132": 192, "133": 197, "134": 202, "135": 208, "136": 213, "137": 218, "138": 224, "139": 229, "140": 234, "141": 240, "142": 245, "143": 250, "144": 0, "145": 5, "146": 10, "147": 16, "148": 21, "149": 26, "150": 32, "151": 37, "152": 42, "153": 48, "154": 53, "155": 58, "156": 64, "157": 69, "158": 74, "159": 80, "160": 85, "161": 90, "162": 96, "163": 101, "164": 106, "165": 112, "166": 117, "167": 122, "168": 128, "169": 133, "170": 138, "171": 144, "172": 149, "173": 154, "174": 160, "175": 165, "176": 170, "177": 176, "178": 181, "179": 186, "180": 192, "181": 197, "182": 202, "183": 208, "184": 213, "185": 218, "186": 224, "187": 229, "188": 234, "189": 240, "190": 245, "191": 250, "192": 0, "193": 5, "194": 10, "195": 16, "196": 21, "197": 26, "198": 32, "199": 37, "200": 42, "201": 48, "202": 53, "203": 58, "204": 64, "205": 69, "206": 74, "207": 80, "208": 85, "209": 90, "210": 96, "211": 101, "212": 106, "213": 112, "214": 117, "215": 122, "216": 128, "217": 133, "218": 138, "219": 144, "220": 149, "221": 154, "222": 160, "223": 165, "224": 170, "225": 176, "226": 181, "227": 186, "228": 192, "229": 197, "230": 202, "231": 208, "232": 213, "233": 218, "234": 224, "235": 229, "236": 234, "237": 240, "238": 245, "239": 250, "240": 0, "241": 5, "242": 10, "243": 16, "244": 21, "245": 26, "246": 32, "247": 37, "248": 42, "249": 48, "250": 53, "251": 58, "252": 64, "253": 69, "254": 74, "255": 80} \ No newline at end of file diff --git a/data/bior_lut_div_16_2.json b/data/bior_lut_div_16_2.json new file mode 100644 index 0000000..8ab8cd2 --- /dev/null +++ b/data/bior_lut_div_16_2.json @@ -0,0 +1 @@ +{"0": 5, "1": 10, "2": 16, "3": 21, "4": 26, "5": 32, "6": 37, "7": 42, "8": 48, "9": 53, "10": 58, "11": 64, "12": 69, "13": 74, "14": 80, "15": 85, "16": 90, "17": 96, "18": 101, "19": 106, "20": 112, "21": 117, "22": 122, "23": 128, "24": 133, "25": 138, "26": 144, "27": 149, "28": 154, "29": 160, "30": 165, "31": 170, "32": 176, "33": 181, "34": 186, "35": 192, "36": 197, "37": 202, "38": 208, "39": 213, "40": 218, "41": 224, "42": 229, "43": 234, "44": 240, "45": 245, "46": 250, "47": 0, "48": 5, "49": 10, "50": 16, "51": 21, "52": 26, "53": 32, "54": 37, "55": 42, "56": 48, "57": 53, "58": 58, "59": 64, "60": 69, "61": 74, "62": 80, "63": 85, "64": 90, "65": 96, "66": 101, "67": 106, "68": 112, "69": 117, "70": 122, "71": 128, "72": 133, "73": 138, "74": 144, "75": 149, "76": 154, "77": 160, "78": 165, "79": 170, "80": 176, "81": 181, "82": 186, "83": 192, "84": 197, "85": 202, "86": 208, "87": 213, "88": 218, "89": 224, "90": 229, "91": 234, "92": 240, "93": 245, "94": 250, "95": 0, "96": 5, "97": 10, "98": 16, "99": 21, "100": 26, "101": 32, "102": 37, "103": 42, "104": 48, "105": 53, "106": 58, "107": 64, "108": 69, "109": 74, "110": 80, "111": 85, "112": 90, "113": 96, "114": 101, "115": 106, "116": 112, "117": 117, "118": 122, "119": 128, "120": 133, "121": 138, "122": 144, "123": 149, "124": 154, "125": 160, "126": 165, "127": 170, "128": 176, "129": 181, "130": 186, "131": 192, "132": 197, "133": 202, "134": 208, "135": 213, "136": 218, "137": 224, "138": 229, "139": 234, "140": 240, "141": 245, "142": 250, "143": 0, "144": 5, "145": 10, "146": 16, "147": 21, "148": 26, "149": 32, "150": 37, "151": 42, "152": 48, "153": 53, "154": 58, "155": 64, "156": 69, "157": 74, "158": 80, "159": 85, "160": 90, "161": 96, "162": 101, "163": 106, "164": 112, "165": 117, "166": 122, "167": 128, "168": 133, "169": 138, "170": 144, "171": 149, "172": 154, "173": 160, "174": 165, "175": 170, "176": 176, "177": 181, "178": 186, "179": 192, "180": 197, "181": 202, "182": 208, "183": 213, "184": 218, "185": 224, "186": 229, "187": 234, "188": 240, "189": 245, "190": 250, "191": 0, "192": 5, "193": 10, "194": 16, "195": 21, "196": 26, "197": 32, "198": 37, "199": 42, "200": 48, "201": 53, "202": 58, "203": 64, "204": 69, "205": 74, "206": 80, "207": 85, "208": 90, "209": 96, "210": 101, "211": 106, "212": 112, "213": 117, "214": 122, "215": 128, "216": 133, "217": 138, "218": 144, "219": 149, "220": 154, "221": 160, "222": 165, "223": 170, "224": 176, "225": 181, "226": 186, "227": 192, "228": 197, "229": 202, "230": 208, "231": 213, "232": 218, "233": 224, "234": 229, "235": 234, "236": 240, "237": 245, "238": 250, "239": 0, "240": 5, "241": 10, "242": 16, "243": 21, "244": 26, "245": 32, "246": 37, "247": 42, "248": 48, "249": 53, "250": 58, "251": 64, "252": 69, "253": 74, "254": 80, "255": 85} \ No newline at end of file diff --git a/src/common.rs b/src/common.rs index 747ed9b..e66a626 100644 --- a/src/common.rs +++ b/src/common.rs @@ -248,7 +248,7 @@ pub fn quantized_table( (lsb, msb) } -fn eval_lut(x: u64, lut_map: &Vec) -> u64 { +pub fn eval_lut(x: u64, lut_map: &Vec) -> u64 { lut_map[x as usize] } diff --git a/src/euclidean.rs b/src/euclidean.rs index b86eca4..8ddd049 100644 --- a/src/euclidean.rs +++ b/src/euclidean.rs @@ -2,27 +2,30 @@ use std::time::Instant; use num_integer::Roots; use rayon::prelude::*; -use ripple::common; +use ripple::common::{self, ct_lut_eval_no_gen}; use tfhe::{ - integer::{gen_keys_radix, wopbs::*, RadixCiphertext}, + integer::{ + gen_keys_radix, wopbs::*, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext, + }, shortint::parameters::{ - parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, + parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, Degree, PARAM_MESSAGE_2_CARRY_2_KS_PBS, }, }; /// d(x, y) = sqrt( sum((xi - yi)^2) ) -fn euclidean(x: &[u32], y: &[u32]) -> f32 { - x.iter() - .zip(y.iter()) - .map(|(&xi, &yi)| (xi - yi).pow(2) as f32) - .sum::() - .sqrt() -} +// fn euclidean(x: &[u32], y: &[u32]) -> f32 { +// x.iter() +// .zip(y.iter()) +// .map(|(&xi, &yi)| (xi - yi).pow(2) as f32) +// .sum::() +// .sqrt() +// } fn main() { let data = common::read_csv("data/euclidean.csv"); let xs = &data[0]; + let num_iter = 3; // ------- Client side ------- // let bit_width = 16; @@ -59,28 +62,42 @@ fn main() { ); // ------- Server side ------- // - // TODO: Move LUT gens up here + let lut_gen_start = Instant::now(); + println!("Generating LUT."); + let mut dummy: RadixCiphertext = server_key.create_trivial_radix(2_u64, (nb_blocks).into()); + dummy = wopbs_key.keyswitch_to_wopbs_params(&server_key, &dummy); + let mut dummy_blocks = dummy.clone().into_blocks().to_vec(); + for block in &mut dummy_blocks { + block.degree = Degree::new(3); + } + dummy = RadixCiphertext::from_blocks(dummy_blocks); + let sqrt_lut = wopbs_key.generate_lut_radix(&dummy, |x: u64| x.sqrt()); + let div_lut = wopbs_key.generate_lut_radix(&dummy, |x: u64| x / (num_iter as u64)); + println!( + "LUT generation done in {:?} sec.", + lut_gen_start.elapsed().as_secs_f64() + ); - let num_iter = 3; assert!( num_iter < data.len(), "Not enough columns in CSV for that many iterations" ); - let mut sum_dists = (1..num_iter + 1) + let bench_start = Instant::now(); + let sum_dists = (1..num_iter + 1) .into_par_iter() .map(|i| { let ys = &data[i]; - let distance = euclidean(xs, ys); - println!("{}) Ptxt Euclidean distance: {}", i, distance); + // let distance = euclidean(xs, ys); + // println!("{}) Ptxt Euclidean distance: {}", i, distance); // Compute the encrypted euclidean distance let start = Instant::now(); println!("{}) Starting computing Squared Euclidean distance", i); - let mut euclid_squared_enc = xs_enc + let euclid_squared_enc = xs_enc .iter() .zip(ys.iter()) .map(|(x_enc, &y)| { @@ -96,13 +113,10 @@ fn main() { i, start.elapsed().as_secs_f64() ); - + // println!("euclid_squared_enc degree: {:?}", euclid_squared_enc.blocks()[0].degree); println!("{}) Starting computing square root", i); - let sqrt_lut = wopbs_key.generate_lut_radix(&euclid_squared_enc, |x: u64| x.sqrt()); - euclid_squared_enc = - wopbs_key.keyswitch_to_wopbs_params(&server_key, &euclid_squared_enc); - let mut distance_enc = wopbs_key.wopbs(&euclid_squared_enc, &sqrt_lut); - distance_enc = wopbs_key.keyswitch_to_pbs_params(&distance_enc); + let distance_enc = + ct_lut_eval_no_gen(euclid_squared_enc, &wopbs_key, &server_key, &sqrt_lut); println!( "{}) Finished computing square root in {:?} sec.", i, @@ -118,11 +132,12 @@ fn main() { |acc: RadixCiphertext, diff| server_key.add_parallelized(&acc, &diff), ); - let div_lut = wopbs_key.generate_lut_radix(&sum_dists, |x: u64| x / (num_iter as u64)); - sum_dists = wopbs_key.keyswitch_to_wopbs_params(&server_key, &sum_dists); - let mut dists_mean_enc = wopbs_key.wopbs(&sum_dists, &div_lut); - dists_mean_enc = wopbs_key.keyswitch_to_pbs_params(&dists_mean_enc); - + // println!("sum_dists degree: {:?}", sum_dists.blocks()[0].degree); + let dists_mean_enc = ct_lut_eval_no_gen(sum_dists, &wopbs_key, &server_key, &div_lut); + println!( + "Finished computing everything in {:?} sec.", + bench_start.elapsed().as_secs_f64() + ); // ------- Client side ------- // let mean_distance: u64 = client_key.decrypt(&dists_mean_enc); println!( diff --git a/src/euclidean_bior.rs b/src/euclidean_bior.rs new file mode 100644 index 0000000..4c3aa15 --- /dev/null +++ b/src/euclidean_bior.rs @@ -0,0 +1,174 @@ +use std::time::Instant; + +use rayon::prelude::*; +use ripple::common::*; +use tfhe::{ + integer::{ + gen_keys_radix, wopbs::*, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext, + }, + shortint::parameters::{ + parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, Degree, + PARAM_MESSAGE_2_CARRY_2_KS_PBS, + }, +}; + +fn main() { + let data = read_csv("data/euclidean.csv"); + let xs = &data[0]; + let num_iter = 3; + + // ------- Client side ------- // + let bit_width = 16; + let precision = 12; + let j = 8; + + let (sqrt_lut_lsb, sqrt_lut_msb) = bior("data/bior_lut_sqrt_16.json", j as u8, bit_width); + let (sqrt_lut_lsb_2, sqrt_lut_msb_2) = bior("data/bior_lut_sqrt_16_2.json", j as u8, bit_width); + let (div_lut_lsb, div_lut_msb) = bior("data/bior_lut_div_16.json", j as u8, bit_width); + let (div_lut_lsb_2, div_lut_msb_2) = bior("data/bior_lut_div_16_2.json", j as u8, bit_width); + + let sqrt_luts = vec![ + &sqrt_lut_lsb, + &sqrt_lut_lsb_2, + &sqrt_lut_msb, + &sqrt_lut_msb_2, + ]; + let div_luts = vec![&div_lut_lsb, &div_lut_lsb_2, &div_lut_msb, &div_lut_msb_2]; + + // Number of blocks per ciphertext + let nb_blocks = bit_width / 2; + println!( + "Number of blocks for the radix decomposition: {:?}", + nb_blocks + ); + + let start = Instant::now(); + // Generate radix keys + let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks.into()); + // Generate key for PBS (without padding) + let wopbs_key = WopbsKey::new_wopbs_key( + &client_key, + &server_key, + &WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, + ); + println!( + "Key generation done in {:?} sec.", + start.elapsed().as_secs_f64() + ); + + let start = Instant::now(); + let xs_enc: Vec<_> = xs + .par_iter() // Use par_iter() for parallel iteration + .map(|&x| client_key.encrypt(quantize(x as f64, precision, bit_width as u8))) + .collect(); + println!( + "Encryption done in {:?} sec.", + start.elapsed().as_secs_f64() + ); + + // ------- Server side ------- // + let lut_gen_start = Instant::now(); + println!("Generating LUT."); + let dummy: RadixCiphertext = server_key.create_trivial_radix(0_u64, j >> 1); + let mut dummy_blocks = dummy.into_blocks().to_vec(); + for block in &mut dummy_blocks { + block.degree = Degree::new(3); + } + let dummy = RadixCiphertext::from_blocks(dummy_blocks); + let dummy = wopbs_key.keyswitch_to_wopbs_params(&server_key, &dummy); + let encoded_sqrt_luts = sqrt_luts + .iter() + .map(|lut| wopbs_key.generate_lut_radix(&dummy, |x: u64| eval_lut(x, &lut.to_vec()))) + .collect::>(); + let encoded_div_luts = div_luts + .iter() + .map(|lut| wopbs_key.generate_lut_radix(&dummy, |x: u64| eval_lut(x, &lut.to_vec()))) + .collect::>(); + println!( + "LUT generation done in {:?} sec.", + lut_gen_start.elapsed().as_secs_f64() + ); + + assert!( + num_iter < data.len(), + "Not enough columns in CSV for that many iterations" + ); + + let bench_start = Instant::now(); + let sum_dists = (1..num_iter + 1) + .into_par_iter() + .map(|i| { + let ys = &data[i]; + // Compute the encrypted euclidean distance + + let start = Instant::now(); + println!("{}) Starting computing Squared Euclidean distance", i); + + let euclid_squared_enc = xs_enc + .iter() + .zip(ys.iter()) + .map(|(x_enc, &y)| { + let diff = server_key.scalar_sub_parallelized(x_enc, y); + server_key.mul_parallelized(&diff, &diff) + }) + .fold( + server_key.create_trivial_radix(0_u64, nb_blocks.into()), + |acc: RadixCiphertext, diff| server_key.add_parallelized(&acc, &diff), + ); + println!( + "{}) Finished computing Squared Euclidean distance in {:?} sec.", + i, + start.elapsed().as_secs_f64() + ); + // println!("euclid_squared_enc degree: {:?}", euclid_squared_enc.blocks()[0].degree); + println!("{}) Starting computing square root", i); + let distance_enc = ct_lut_eval_bior_no_gen( + euclid_squared_enc, + bit_width.into(), + nb_blocks.into(), + j, + &wopbs_key, + 0_i32, + &server_key, + &encoded_sqrt_luts, + ); + println!( + "{}) Finished computing square root in {:?} sec.", + i, + start.elapsed().as_secs_f64() + ); + + distance_enc + }) + .collect::>() + .into_iter() + .fold( + server_key.create_trivial_radix(0_u64, nb_blocks.into()), + |acc: RadixCiphertext, diff| server_key.add_parallelized(&acc, &diff), + ); + + // println!("sum_dists degree: {:?}", sum_dists.blocks()[0].degree); + let dists_mean_enc = ct_lut_eval_bior_no_gen( + sum_dists, + bit_width.into(), + nb_blocks.into(), + j, + &wopbs_key, + 0_i32, + &server_key, + &encoded_div_luts, + ); + println!( + "Finished computing everything in {:?} sec.", + bench_start.elapsed().as_secs_f64() + ); + + // ------- Client side ------- // + let mean_distance: u64 = client_key.decrypt(&dists_mean_enc); + let mean_distance: f64 = unquantize(mean_distance, precision, bit_width as u8); + + println!( + "Mean of {} Euclidean distances: {}", + num_iter, mean_distance + ); +} diff --git a/src/euclidean_haar.rs b/src/euclidean_haar.rs new file mode 100644 index 0000000..42c7641 --- /dev/null +++ b/src/euclidean_haar.rs @@ -0,0 +1,176 @@ +use std::time::Instant; + +use rayon::prelude::*; +use ripple::common::*; +use tfhe::{ + integer::{ + gen_keys_radix, wopbs::*, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext, + }, + shortint::parameters::{ + parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, Degree, + PARAM_MESSAGE_2_CARRY_2_KS_PBS, + }, +}; + +fn my_sqrt(value: f64) -> f64 { + value.sqrt() +} + +fn my_div(value: f64) -> f64 { + value / 3_f64 +} + +fn main() { + let data = read_csv("data/euclidean.csv"); + let xs = &data[0]; + let num_iter = 3; + + // ------- Client side ------- // + let bit_width = 16; + let precision = 6; + // Number of blocks per ciphertext + let nb_blocks = bit_width / 2; + println!( + "Number of blocks for the radix decomposition: {:?}", + nb_blocks + ); + + let start = Instant::now(); + // Generate radix keys + let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks); + // Generate key for PBS (without padding) + let wopbs_key = WopbsKey::new_wopbs_key( + &client_key, + &server_key, + &WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, + ); + println!( + "Key generation done in {:?} sec.", + start.elapsed().as_secs_f64() + ); + + let start = Instant::now(); + let xs_enc: Vec<_> = xs + .par_iter() // Use par_iter() for parallel iteration + .map(|&x| client_key.encrypt(quantize(x as f64, precision, bit_width as u8))) + .collect(); + println!( + "Encryption done in {:?} sec.", + start.elapsed().as_secs_f64() + ); + + // ------- Server side ------- // + let lut_gen_start = Instant::now(); + println!("Generating LUT."); + let dummy: RadixCiphertext = server_key.create_trivial_radix(0_u64, nb_blocks >> 1); + let mut dummy_blocks = dummy.into_blocks().to_vec(); + for block in &mut dummy_blocks { + block.degree = Degree::new(3); + } + let dummy = RadixCiphertext::from_blocks(dummy_blocks); + let dummy = wopbs_key.keyswitch_to_wopbs_params(&server_key, &dummy); + + let (haar_lsb, haar_msb) = haar( + precision, + precision, + bit_width as u8, + bit_width as u8, + &my_sqrt, + ); + let haar_lsb_lut_sqrt = wopbs_key.generate_lut_radix(&dummy, |x: u64| eval_lut(x, &haar_lsb)); + let haar_msb_lut_sqrt = wopbs_key.generate_lut_radix(&dummy, |x: u64| eval_lut(x, &haar_msb)); + let (haar_lsb, haar_msb) = haar( + precision, + precision, + bit_width as u8, + bit_width as u8, + &my_div, + ); + let haar_lsb_lut_div = wopbs_key.generate_lut_radix(&dummy, |x: u64| eval_lut(x, &haar_lsb)); + let haar_msb_lut_div = wopbs_key.generate_lut_radix(&dummy, |x: u64| eval_lut(x, &haar_msb)); + + println!( + "LUT generation done in {:?} sec.", + lut_gen_start.elapsed().as_secs_f64() + ); + + assert!( + num_iter < data.len(), + "Not enough columns in CSV for that many iterations" + ); + + let bench_start = Instant::now(); + let sum_dists = (1..num_iter + 1) + .into_par_iter() + .map(|i| { + let ys = &data[i]; + + // Compute the encrypted euclidean distance + + let start = Instant::now(); + println!("{}) Starting computing Squared Euclidean distance", i); + + let euclid_squared_enc = xs_enc + .iter() + .zip(ys.iter()) + .map(|(x_enc, &y)| { + let diff = server_key.scalar_sub_parallelized(x_enc, y); + server_key.mul_parallelized(&diff, &diff) + }) + .fold( + server_key.create_trivial_radix(0_u64, nb_blocks), + |acc: RadixCiphertext, diff| server_key.add_parallelized(&acc, &diff), + ); + println!( + "{}) Finished computing Squared Euclidean distance in {:?} sec.", + i, + start.elapsed().as_secs_f64() + ); + // println!("euclid_squared_enc degree: {:?}", euclid_squared_enc.blocks()[0].degree); + println!("{}) Starting computing square root", i); + let distance_enc = ct_lut_eval_haar_no_gen( + euclid_squared_enc, + nb_blocks, + &wopbs_key, + &server_key, + &haar_lsb_lut_sqrt, + &haar_msb_lut_sqrt, + ); + println!( + "{}) Finished computing square root in {:?} sec.", + i, + start.elapsed().as_secs_f64() + ); + + distance_enc + }) + .collect::>() + .into_iter() + .fold( + server_key.create_trivial_radix(0_u64, nb_blocks), + |acc: RadixCiphertext, diff| server_key.add_parallelized(&acc, &diff), + ); + + // println!("sum_dists degree: {:?}", sum_dists.blocks()[0].degree); + let dists_mean_enc = ct_lut_eval_haar_no_gen( + sum_dists, + nb_blocks, + &wopbs_key, + &server_key, + &haar_lsb_lut_div, + &haar_msb_lut_div, + ); + println!( + "Finished computing everything in {:?} sec.", + bench_start.elapsed().as_secs_f64() + ); + + // ------- Client side ------- // + let mean_distance: u64 = client_key.decrypt(&dists_mean_enc); + let mean_distance: f64 = unquantize(mean_distance, precision, bit_width as u8); + + println!( + "Mean of {} Euclidean distances: {}", + num_iter, mean_distance + ); +} diff --git a/src/euclidean_quantized.rs b/src/euclidean_quantized.rs new file mode 100644 index 0000000..caf5bdf --- /dev/null +++ b/src/euclidean_quantized.rs @@ -0,0 +1,140 @@ +use std::time::Instant; + +use num_integer::Roots; +use rayon::prelude::*; +use ripple::common::{self, ct_lut_eval_quantized_no_gen}; +use tfhe::{ + integer::{ + gen_keys_radix, wopbs::*, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext, + }, + shortint::parameters::{ + parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, Degree, + PARAM_MESSAGE_2_CARRY_2_KS_PBS, + }, +}; + +fn main() { + let data = common::read_csv("data/euclidean.csv"); + let xs = &data[0]; + let num_iter = 3; + + // ------- Client side ------- // + let bit_width = 16; + + // Number of blocks per ciphertext + let nb_blocks = bit_width / 2; + println!( + "Number of blocks for the radix decomposition: {:?}", + nb_blocks + ); + + let start = Instant::now(); + // Generate radix keys + let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks); + // Generate key for PBS (without padding) + let wopbs_key = WopbsKey::new_wopbs_key( + &client_key, + &server_key, + &WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, + ); + println!( + "Key generation done in {:?} sec.", + start.elapsed().as_secs_f64() + ); + + let start = Instant::now(); + let xs_enc: Vec<_> = xs + .par_iter() // Use par_iter() for parallel iteration + .map(|&x| client_key.encrypt(x)) + .collect(); + println!( + "Encryption done in {:?} sec.", + start.elapsed().as_secs_f64() + ); + + // ------- Server side ------- // + let lut_gen_start = Instant::now(); + println!("Generating LUT."); + let mut dummy: RadixCiphertext = server_key.create_trivial_radix(2_u64, (nb_blocks).into()); + dummy = wopbs_key.keyswitch_to_wopbs_params(&server_key, &dummy); + let mut dummy_blocks = dummy.clone().into_blocks().to_vec(); + for block in &mut dummy_blocks { + block.degree = Degree::new(3); + } + dummy = RadixCiphertext::from_blocks(dummy_blocks[0..(nb_blocks as usize >> 1)].to_vec()); + let sqrt_lut = wopbs_key.generate_lut_radix(&dummy, |x: u64| x.sqrt()); + let div_lut = wopbs_key.generate_lut_radix(&dummy, |x: u64| x / (num_iter as u64)); + println!( + "LUT generation done in {:?} sec.", + lut_gen_start.elapsed().as_secs_f64() + ); + + assert!( + num_iter < data.len(), + "Not enough columns in CSV for that many iterations" + ); + + let bench_start = Instant::now(); + let sum_dists = (1..num_iter + 1) + .into_par_iter() + .map(|i| { + let ys = &data[i]; + // Compute the encrypted euclidean distance + + let start = Instant::now(); + println!("{}) Starting computing Squared Euclidean distance", i); + + let euclid_squared_enc = xs_enc + .iter() + .zip(ys.iter()) + .map(|(x_enc, &y)| { + let diff = server_key.scalar_sub_parallelized(x_enc, y); + server_key.mul_parallelized(&diff, &diff) + }) + .fold( + server_key.create_trivial_radix(0_u64, nb_blocks), + |acc: RadixCiphertext, diff| server_key.add_parallelized(&acc, &diff), + ); + println!( + "{}) Finished computing Squared Euclidean distance in {:?} sec.", + i, + start.elapsed().as_secs_f64() + ); + // println!("euclid_squared_enc degree: {:?}", euclid_squared_enc.blocks()[0].degree); + println!("{}) Starting computing square root", i); + let distance_enc = ct_lut_eval_quantized_no_gen( + euclid_squared_enc, + nb_blocks, + &wopbs_key, + &server_key, + &sqrt_lut, + ); + println!( + "{}) Finished computing square root in {:?} sec.", + i, + start.elapsed().as_secs_f64() + ); + + distance_enc + }) + .collect::>() + .into_iter() + .fold( + server_key.create_trivial_radix(0_u64, nb_blocks), + |acc: RadixCiphertext, diff| server_key.add_parallelized(&acc, &diff), + ); + + // println!("sum_dists degree: {:?}", sum_dists.blocks()[0].degree); + let dists_mean_enc = + ct_lut_eval_quantized_no_gen(sum_dists, nb_blocks, &wopbs_key, &server_key, &div_lut); + println!( + "Finished computing everything in {:?} sec.", + bench_start.elapsed().as_secs_f64() + ); + // ------- Client side ------- // + let mean_distance: u64 = client_key.decrypt(&dists_mean_enc); + println!( + "Mean of {} Euclidean distances: {}", + num_iter, mean_distance + ); +} diff --git a/src/lr_lut.rs b/src/lr_lut.rs index 167f70d..c393ddd 100644 --- a/src/lr_lut.rs +++ b/src/lr_lut.rs @@ -124,9 +124,8 @@ fn main() { let ct_prod = server_key.scalar_mul_parallelized(s, weight); prediction = server_key.add_parallelized(&ct_prod, &prediction); } - prediction = wopbs_key.keyswitch_to_wopbs_params(&server_key, &prediction); - let activation = wopbs_key.wopbs(&prediction, &sigmoid_lut); - let probability = wopbs_key.keyswitch_to_pbs_params(&activation); + let probability = + ct_lut_eval_no_gen(prediction, &wopbs_key, &server_key, &sigmoid_lut); println!( "Finished inference #{:?} in {:?} sec.", @@ -145,9 +144,7 @@ fn main() { let ct_prod = server_key.scalar_mul_parallelized(s, weight); prediction = server_key.add_parallelized(&ct_prod, &prediction); } - prediction = wopbs_key.keyswitch_to_wopbs_params(&server_key, &prediction); - let activation = wopbs_key.wopbs(&prediction, &sigmoid_lut); - let probability = wopbs_key.keyswitch_to_pbs_params(&activation); + let probability = ct_lut_eval_no_gen(prediction, &wopbs_key, &server_key, &sigmoid_lut); println!( "Finished inference in {:?} sec.", diff --git a/src/lr_lut_quantized.rs b/src/lr_lut_quantized.rs new file mode 100644 index 0000000..eb3d326 --- /dev/null +++ b/src/lr_lut_quantized.rs @@ -0,0 +1,184 @@ +use std::time::Instant; + +use clap::{App, Arg}; +use rayon::prelude::*; +use ripple::common::*; +use tfhe::{ + integer::{ + gen_keys_radix, wopbs::*, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext, + }, + shortint::parameters::{ + parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, Degree, + PARAM_MESSAGE_2_CARRY_2_KS_PBS, + }, +}; + +fn main() { + println!("Encrypted Logistic Regression"); + + let matches = App::new("Ripple") + .about("Vanilla Encrypted Logistic Regression") + .arg( + Arg::new("num-samples") + .long("num-samples") + .short('n') + .takes_value(true) + .value_name("INT") + .help("Number of samples") + .default_value("1") + .required(false), + ) + .get_matches(); + + let num_samples = matches + .value_of("num-samples") + .unwrap_or("1") + .parse::() + .expect("Number of samples must be an integer"); + + // ------- Client side ------- // + let bit_width = 24; + let precision = 8; + assert!(precision <= bit_width / 2); + + // Number of blocks per ciphertext + let nb_blocks = bit_width / 2; + println!( + "Number of blocks for the radix decomposition: {:?}", + nb_blocks + ); + + let start = Instant::now(); + // Generate radix keys + let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks.into()); + // Generate key for PBS (without padding) + let wopbs_key = WopbsKey::new_wopbs_key( + &client_key, + &server_key, + &WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, + ); + println!( + "Key generation done in {:?} sec.", + start.elapsed().as_secs_f64() + ); + + let (weights, bias) = load_weights_and_biases(); + let (weights_int, bias_int) = quantize_weights_and_bias(&weights, bias, precision, bit_width); + let (dataset, targets) = prepare_penguins_dataset(); + + let start = Instant::now(); + let mut encrypted_dataset: Vec> = dataset + .par_iter() // Use par_iter() for parallel iteration + .map(|sample| { + sample + .par_iter() + .map(|&s| { + let quantized = quantize(s, precision, bit_width); + client_key.encrypt(quantized) + }) + .collect() + }) + .collect(); + println!( + "Encryption done in {:?} sec.", + start.elapsed().as_secs_f64() + ); + + // ------- Server side ------- // + + // Build LUT for Sigmoid -- Offline cost + let lut_gen_start = Instant::now(); + println!("Generating LUT."); + let mut dummy: RadixCiphertext = server_key.create_trivial_radix(2_u64, (nb_blocks).into()); + for _ in 0..weights_int.len() { + let dummy_2 = server_key.scalar_mul_parallelized(&dummy, 2_u64); + dummy = server_key.add_parallelized(&dummy_2, &dummy); + } + dummy = wopbs_key.keyswitch_to_wopbs_params(&server_key, &dummy); + let mut dummy_blocks = dummy.clone().into_blocks().to_vec(); + for block in &mut dummy_blocks { + block.degree = Degree::new(3); + } + dummy = RadixCiphertext::from_blocks(dummy_blocks[0..(nb_blocks as usize >> 1)].to_vec()); + let sigmoid_lut = wopbs_key.generate_lut_radix(&dummy, |x: u64| { + sigmoid(x, 2 * precision, precision, bit_width) + }); + println!( + "LUT generation done in {:?} sec.", + lut_gen_start.elapsed().as_secs_f64() + ); + + // Inference + assert!(num_samples <= encrypted_dataset.len()); + let all_probabilities = if num_samples > 1 { + encrypted_dataset + .par_iter_mut() + .enumerate() + .take(num_samples) + .map(|(cnt, sample)| { + let start = Instant::now(); + println!("Starting inference #{:?}.", cnt); + + let mut prediction = server_key.create_trivial_radix(bias_int, nb_blocks.into()); + for (s, &weight) in sample.iter_mut().zip(weights_int.iter()) { + let ct_prod = server_key.scalar_mul_parallelized(s, weight); + prediction = server_key.add_parallelized(&ct_prod, &prediction); + } + let probability = ct_lut_eval_quantized_no_gen( + prediction, + nb_blocks as usize, + &wopbs_key, + &server_key, + &sigmoid_lut, + ); + + println!( + "Finished inference #{:?} in {:?} sec.", + cnt, + start.elapsed().as_secs_f64() + ); + probability + }) + .collect::>() + } else { + let start = Instant::now(); + println!("Starting inference."); + + let mut prediction = server_key.create_trivial_radix(bias_int, nb_blocks.into()); + for (s, &weight) in encrypted_dataset[0].iter_mut().zip(weights_int.iter()) { + let ct_prod = server_key.scalar_mul_parallelized(s, weight); + prediction = server_key.add_parallelized(&ct_prod, &prediction); + } + let probability = ct_lut_eval_quantized_no_gen( + prediction, + nb_blocks as usize, + &wopbs_key, + &server_key, + &sigmoid_lut, + ); + + println!( + "Finished inference in {:?} sec.", + start.elapsed().as_secs_f64() + ); + vec![probability] + }; + + // ------- Client side ------- // + let mut total = 0; + for (num, (target, probability)) in targets.iter().zip(all_probabilities.iter()).enumerate() { + let ptxt_probability: u64 = client_key.decrypt(probability); + let pr = (ptxt_probability as f64) / ((1 << precision) as f64); + + let class = (ptxt_probability > quantize(0.5, precision, bit_width)) as usize; + println!( + "[{}] predicted {:?}, target {:?} (prediction probability {:?})", + num, class, target, pr + ); + if class == *target { + total += 1; + } + } + let accuracy = (total as f32 / num_samples as f32) * 100.0; + println!("Accuracy {accuracy}%"); +} diff --git a/src/lr_ptxt.rs b/src/lr_ptxt.rs index dcb3a37..f6c280f 100644 --- a/src/lr_ptxt.rs +++ b/src/lr_ptxt.rs @@ -29,13 +29,9 @@ fn main() { // Server computation let mut prediction = bias_int; for (&s, &w) in sample.iter().zip(weights_int.iter()) { - println!("s: {:?}", s); - println!("weight: {:?}", w); prediction = add(prediction, mul(w, s, bit_width), bit_width); - println!("MAC result: {:?}", prediction); } let probability = sigmoid(prediction, 2 * precision, precision, bit_width); - println!("probability {probability}"); let class = (probability > quantize(0.5, precision, bit_width)) as usize; // Client computation @@ -51,5 +47,4 @@ fn main() { let accuracy = (total as f64 / dataset.len() as f64) * 100.0; println!("Accuracy {accuracy}%"); - println!("precision: {precision}, bit_width: {bit_width}"); }