Skip to content

Commit

Permalink
Add all variants for LR and Euclidean distance
Browse files Browse the repository at this point in the history
  • Loading branch information
cgouert committed May 25, 2024
1 parent 73154c5 commit 5cf803d
Show file tree
Hide file tree
Showing 11 changed files with 738 additions and 39 deletions.
16 changes: 16 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ path = "src/lr_ptxt_float.rs"
name = "lr_lut"
path = "src/lr_lut.rs"

[[bin]]
name = "lr_lut_quantized"
path = "src/lr_lut_quantized.rs"

[[bin]]
name = "lr_haar_ptxt"
path = "src/lr_haar_ptxt.rs"
Expand Down Expand Up @@ -79,6 +83,18 @@ path = "src/correlation_lut.rs"
name = "euclidean"
path = "src/euclidean.rs"

[[bin]]
name = "euclidean_quantized"
path = "src/euclidean_quantized.rs"

[[bin]]
name = "euclidean_haar"
path = "src/euclidean_haar.rs"

[[bin]]
name = "euclidean_bior"
path = "src/euclidean_bior.rs"

# Primitive Timings

[[bin]]
Expand Down
1 change: 1 addition & 0 deletions data/bior_lut_div_16.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"0": 0, "1": 5, "2": 10, "3": 16, "4": 21, "5": 26, "6": 32, "7": 37, "8": 42, "9": 48, "10": 53, "11": 58, "12": 64, "13": 69, "14": 74, "15": 80, "16": 85, "17": 90, "18": 96, "19": 101, "20": 106, "21": 112, "22": 117, "23": 122, "24": 128, "25": 133, "26": 138, "27": 144, "28": 149, "29": 154, "30": 160, "31": 165, "32": 170, "33": 176, "34": 181, "35": 186, "36": 192, "37": 197, "38": 202, "39": 208, "40": 213, "41": 218, "42": 224, "43": 229, "44": 234, "45": 240, "46": 245, "47": 250, "48": 0, "49": 5, "50": 10, "51": 16, "52": 21, "53": 26, "54": 32, "55": 37, "56": 42, "57": 48, "58": 53, "59": 58, "60": 64, "61": 69, "62": 74, "63": 80, "64": 85, "65": 90, "66": 96, "67": 101, "68": 106, "69": 112, "70": 117, "71": 122, "72": 128, "73": 133, "74": 138, "75": 144, "76": 149, "77": 154, "78": 160, "79": 165, "80": 170, "81": 176, "82": 181, "83": 186, "84": 192, "85": 197, "86": 202, "87": 208, "88": 213, "89": 218, "90": 224, "91": 229, "92": 234, "93": 240, "94": 245, "95": 250, "96": 0, "97": 5, "98": 10, "99": 16, "100": 21, "101": 26, "102": 32, "103": 37, "104": 42, "105": 48, "106": 53, "107": 58, "108": 64, "109": 69, "110": 74, "111": 80, "112": 85, "113": 90, "114": 96, "115": 101, "116": 106, "117": 112, "118": 117, "119": 122, "120": 128, "121": 133, "122": 138, "123": 144, "124": 149, "125": 154, "126": 160, "127": 165, "128": 170, "129": 176, "130": 181, "131": 186, "132": 192, "133": 197, "134": 202, "135": 208, "136": 213, "137": 218, "138": 224, "139": 229, "140": 234, "141": 240, "142": 245, "143": 250, "144": 0, "145": 5, "146": 10, "147": 16, "148": 21, "149": 26, "150": 32, "151": 37, "152": 42, "153": 48, "154": 53, "155": 58, "156": 64, "157": 69, "158": 74, "159": 80, "160": 85, "161": 90, "162": 96, "163": 101, "164": 106, "165": 112, "166": 117, "167": 122, "168": 128, "169": 133, "170": 138, "171": 144, "172": 149, "173": 154, "174": 160, "175": 165, "176": 170, "177": 176, "178": 181, "179": 186, "180": 192, "181": 197, "182": 202, "183": 208, "184": 213, "185": 218, "186": 224, "187": 229, "188": 234, "189": 240, "190": 245, "191": 250, "192": 0, "193": 5, "194": 10, "195": 16, "196": 21, "197": 26, "198": 32, "199": 37, "200": 42, "201": 48, "202": 53, "203": 58, "204": 64, "205": 69, "206": 74, "207": 80, "208": 85, "209": 90, "210": 96, "211": 101, "212": 106, "213": 112, "214": 117, "215": 122, "216": 128, "217": 133, "218": 138, "219": 144, "220": 149, "221": 154, "222": 160, "223": 165, "224": 170, "225": 176, "226": 181, "227": 186, "228": 192, "229": 197, "230": 202, "231": 208, "232": 213, "233": 218, "234": 224, "235": 229, "236": 234, "237": 240, "238": 245, "239": 250, "240": 0, "241": 5, "242": 10, "243": 16, "244": 21, "245": 26, "246": 32, "247": 37, "248": 42, "249": 48, "250": 53, "251": 58, "252": 64, "253": 69, "254": 74, "255": 80}
1 change: 1 addition & 0 deletions data/bior_lut_div_16_2.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"0": 5, "1": 10, "2": 16, "3": 21, "4": 26, "5": 32, "6": 37, "7": 42, "8": 48, "9": 53, "10": 58, "11": 64, "12": 69, "13": 74, "14": 80, "15": 85, "16": 90, "17": 96, "18": 101, "19": 106, "20": 112, "21": 117, "22": 122, "23": 128, "24": 133, "25": 138, "26": 144, "27": 149, "28": 154, "29": 160, "30": 165, "31": 170, "32": 176, "33": 181, "34": 186, "35": 192, "36": 197, "37": 202, "38": 208, "39": 213, "40": 218, "41": 224, "42": 229, "43": 234, "44": 240, "45": 245, "46": 250, "47": 0, "48": 5, "49": 10, "50": 16, "51": 21, "52": 26, "53": 32, "54": 37, "55": 42, "56": 48, "57": 53, "58": 58, "59": 64, "60": 69, "61": 74, "62": 80, "63": 85, "64": 90, "65": 96, "66": 101, "67": 106, "68": 112, "69": 117, "70": 122, "71": 128, "72": 133, "73": 138, "74": 144, "75": 149, "76": 154, "77": 160, "78": 165, "79": 170, "80": 176, "81": 181, "82": 186, "83": 192, "84": 197, "85": 202, "86": 208, "87": 213, "88": 218, "89": 224, "90": 229, "91": 234, "92": 240, "93": 245, "94": 250, "95": 0, "96": 5, "97": 10, "98": 16, "99": 21, "100": 26, "101": 32, "102": 37, "103": 42, "104": 48, "105": 53, "106": 58, "107": 64, "108": 69, "109": 74, "110": 80, "111": 85, "112": 90, "113": 96, "114": 101, "115": 106, "116": 112, "117": 117, "118": 122, "119": 128, "120": 133, "121": 138, "122": 144, "123": 149, "124": 154, "125": 160, "126": 165, "127": 170, "128": 176, "129": 181, "130": 186, "131": 192, "132": 197, "133": 202, "134": 208, "135": 213, "136": 218, "137": 224, "138": 229, "139": 234, "140": 240, "141": 245, "142": 250, "143": 0, "144": 5, "145": 10, "146": 16, "147": 21, "148": 26, "149": 32, "150": 37, "151": 42, "152": 48, "153": 53, "154": 58, "155": 64, "156": 69, "157": 74, "158": 80, "159": 85, "160": 90, "161": 96, "162": 101, "163": 106, "164": 112, "165": 117, "166": 122, "167": 128, "168": 133, "169": 138, "170": 144, "171": 149, "172": 154, "173": 160, "174": 165, "175": 170, "176": 176, "177": 181, "178": 186, "179": 192, "180": 197, "181": 202, "182": 208, "183": 213, "184": 218, "185": 224, "186": 229, "187": 234, "188": 240, "189": 245, "190": 250, "191": 0, "192": 5, "193": 10, "194": 16, "195": 21, "196": 26, "197": 32, "198": 37, "199": 42, "200": 48, "201": 53, "202": 58, "203": 64, "204": 69, "205": 74, "206": 80, "207": 85, "208": 90, "209": 96, "210": 101, "211": 106, "212": 112, "213": 117, "214": 122, "215": 128, "216": 133, "217": 138, "218": 144, "219": 149, "220": 154, "221": 160, "222": 165, "223": 170, "224": 176, "225": 181, "226": 186, "227": 192, "228": 197, "229": 202, "230": 208, "231": 213, "232": 218, "233": 224, "234": 229, "235": 234, "236": 240, "237": 245, "238": 250, "239": 0, "240": 5, "241": 10, "242": 16, "243": 21, "244": 26, "245": 32, "246": 37, "247": 42, "248": 48, "249": 53, "250": 58, "251": 64, "252": 69, "253": 74, "254": 80, "255": 85}
2 changes: 1 addition & 1 deletion src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ pub fn quantized_table(
(lsb, msb)
}

fn eval_lut(x: u64, lut_map: &Vec<u64>) -> u64 {
pub fn eval_lut(x: u64, lut_map: &Vec<u64>) -> u64 {
lut_map[x as usize]
}

Expand Down
69 changes: 42 additions & 27 deletions src/euclidean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,30 @@ use std::time::Instant;

use num_integer::Roots;
use rayon::prelude::*;
use ripple::common;
use ripple::common::{self, ct_lut_eval_no_gen};
use tfhe::{
integer::{gen_keys_radix, wopbs::*, RadixCiphertext},
integer::{
gen_keys_radix, wopbs::*, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext,
},
shortint::parameters::{
parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, Degree,
PARAM_MESSAGE_2_CARRY_2_KS_PBS,
},
};

/// d(x, y) = sqrt( sum((xi - yi)^2) )
fn euclidean(x: &[u32], y: &[u32]) -> f32 {
x.iter()
.zip(y.iter())
.map(|(&xi, &yi)| (xi - yi).pow(2) as f32)
.sum::<f32>()
.sqrt()
}
// fn euclidean(x: &[u32], y: &[u32]) -> f32 {
// x.iter()
// .zip(y.iter())
// .map(|(&xi, &yi)| (xi - yi).pow(2) as f32)
// .sum::<f32>()
// .sqrt()
// }

fn main() {
let data = common::read_csv("data/euclidean.csv");
let xs = &data[0];
let num_iter = 3;

// ------- Client side ------- //
let bit_width = 16;
Expand Down Expand Up @@ -59,28 +62,42 @@ fn main() {
);

// ------- Server side ------- //
// TODO: Move LUT gens up here
let lut_gen_start = Instant::now();
println!("Generating LUT.");
let mut dummy: RadixCiphertext = server_key.create_trivial_radix(2_u64, (nb_blocks).into());
dummy = wopbs_key.keyswitch_to_wopbs_params(&server_key, &dummy);
let mut dummy_blocks = dummy.clone().into_blocks().to_vec();
for block in &mut dummy_blocks {
block.degree = Degree::new(3);
}
dummy = RadixCiphertext::from_blocks(dummy_blocks);
let sqrt_lut = wopbs_key.generate_lut_radix(&dummy, |x: u64| x.sqrt());
let div_lut = wopbs_key.generate_lut_radix(&dummy, |x: u64| x / (num_iter as u64));
println!(
"LUT generation done in {:?} sec.",
lut_gen_start.elapsed().as_secs_f64()
);

let num_iter = 3;
assert!(
num_iter < data.len(),
"Not enough columns in CSV for that many iterations"
);

let mut sum_dists = (1..num_iter + 1)
let bench_start = Instant::now();
let sum_dists = (1..num_iter + 1)
.into_par_iter()
.map(|i| {
let ys = &data[i];

let distance = euclidean(xs, ys);
println!("{}) Ptxt Euclidean distance: {}", i, distance);
// let distance = euclidean(xs, ys);
// println!("{}) Ptxt Euclidean distance: {}", i, distance);

// Compute the encrypted euclidean distance

let start = Instant::now();
println!("{}) Starting computing Squared Euclidean distance", i);

let mut euclid_squared_enc = xs_enc
let euclid_squared_enc = xs_enc
.iter()
.zip(ys.iter())
.map(|(x_enc, &y)| {
Expand All @@ -96,13 +113,10 @@ fn main() {
i,
start.elapsed().as_secs_f64()
);

// println!("euclid_squared_enc degree: {:?}", euclid_squared_enc.blocks()[0].degree);
println!("{}) Starting computing square root", i);
let sqrt_lut = wopbs_key.generate_lut_radix(&euclid_squared_enc, |x: u64| x.sqrt());
euclid_squared_enc =
wopbs_key.keyswitch_to_wopbs_params(&server_key, &euclid_squared_enc);
let mut distance_enc = wopbs_key.wopbs(&euclid_squared_enc, &sqrt_lut);
distance_enc = wopbs_key.keyswitch_to_pbs_params(&distance_enc);
let distance_enc =
ct_lut_eval_no_gen(euclid_squared_enc, &wopbs_key, &server_key, &sqrt_lut);
println!(
"{}) Finished computing square root in {:?} sec.",
i,
Expand All @@ -118,11 +132,12 @@ fn main() {
|acc: RadixCiphertext, diff| server_key.add_parallelized(&acc, &diff),
);

let div_lut = wopbs_key.generate_lut_radix(&sum_dists, |x: u64| x / (num_iter as u64));
sum_dists = wopbs_key.keyswitch_to_wopbs_params(&server_key, &sum_dists);
let mut dists_mean_enc = wopbs_key.wopbs(&sum_dists, &div_lut);
dists_mean_enc = wopbs_key.keyswitch_to_pbs_params(&dists_mean_enc);

// println!("sum_dists degree: {:?}", sum_dists.blocks()[0].degree);
let dists_mean_enc = ct_lut_eval_no_gen(sum_dists, &wopbs_key, &server_key, &div_lut);
println!(
"Finished computing everything in {:?} sec.",
bench_start.elapsed().as_secs_f64()
);
// ------- Client side ------- //
let mean_distance: u64 = client_key.decrypt(&dists_mean_enc);
println!(
Expand Down
174 changes: 174 additions & 0 deletions src/euclidean_bior.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
use std::time::Instant;

use rayon::prelude::*;
use ripple::common::*;
use tfhe::{
integer::{
gen_keys_radix, wopbs::*, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext,
},
shortint::parameters::{
parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, Degree,
PARAM_MESSAGE_2_CARRY_2_KS_PBS,
},
};

fn main() {
let data = read_csv("data/euclidean.csv");
let xs = &data[0];
let num_iter = 3;

// ------- Client side ------- //
let bit_width = 16;
let precision = 12;
let j = 8;

let (sqrt_lut_lsb, sqrt_lut_msb) = bior("data/bior_lut_sqrt_16.json", j as u8, bit_width);
let (sqrt_lut_lsb_2, sqrt_lut_msb_2) = bior("data/bior_lut_sqrt_16_2.json", j as u8, bit_width);
let (div_lut_lsb, div_lut_msb) = bior("data/bior_lut_div_16.json", j as u8, bit_width);
let (div_lut_lsb_2, div_lut_msb_2) = bior("data/bior_lut_div_16_2.json", j as u8, bit_width);

let sqrt_luts = vec![
&sqrt_lut_lsb,
&sqrt_lut_lsb_2,
&sqrt_lut_msb,
&sqrt_lut_msb_2,
];
let div_luts = vec![&div_lut_lsb, &div_lut_lsb_2, &div_lut_msb, &div_lut_msb_2];

// Number of blocks per ciphertext
let nb_blocks = bit_width / 2;
println!(
"Number of blocks for the radix decomposition: {:?}",
nb_blocks
);

let start = Instant::now();
// Generate radix keys
let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks.into());
// Generate key for PBS (without padding)
let wopbs_key = WopbsKey::new_wopbs_key(
&client_key,
&server_key,
&WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
);
println!(
"Key generation done in {:?} sec.",
start.elapsed().as_secs_f64()
);

let start = Instant::now();
let xs_enc: Vec<_> = xs
.par_iter() // Use par_iter() for parallel iteration
.map(|&x| client_key.encrypt(quantize(x as f64, precision, bit_width as u8)))
.collect();
println!(
"Encryption done in {:?} sec.",
start.elapsed().as_secs_f64()
);

// ------- Server side ------- //
let lut_gen_start = Instant::now();
println!("Generating LUT.");
let dummy: RadixCiphertext = server_key.create_trivial_radix(0_u64, j >> 1);
let mut dummy_blocks = dummy.into_blocks().to_vec();
for block in &mut dummy_blocks {
block.degree = Degree::new(3);
}
let dummy = RadixCiphertext::from_blocks(dummy_blocks);
let dummy = wopbs_key.keyswitch_to_wopbs_params(&server_key, &dummy);
let encoded_sqrt_luts = sqrt_luts
.iter()
.map(|lut| wopbs_key.generate_lut_radix(&dummy, |x: u64| eval_lut(x, &lut.to_vec())))
.collect::<Vec<_>>();
let encoded_div_luts = div_luts
.iter()
.map(|lut| wopbs_key.generate_lut_radix(&dummy, |x: u64| eval_lut(x, &lut.to_vec())))
.collect::<Vec<_>>();
println!(
"LUT generation done in {:?} sec.",
lut_gen_start.elapsed().as_secs_f64()
);

assert!(
num_iter < data.len(),
"Not enough columns in CSV for that many iterations"
);

let bench_start = Instant::now();
let sum_dists = (1..num_iter + 1)
.into_par_iter()
.map(|i| {
let ys = &data[i];
// Compute the encrypted euclidean distance

let start = Instant::now();
println!("{}) Starting computing Squared Euclidean distance", i);

let euclid_squared_enc = xs_enc
.iter()
.zip(ys.iter())
.map(|(x_enc, &y)| {
let diff = server_key.scalar_sub_parallelized(x_enc, y);
server_key.mul_parallelized(&diff, &diff)
})
.fold(
server_key.create_trivial_radix(0_u64, nb_blocks.into()),
|acc: RadixCiphertext, diff| server_key.add_parallelized(&acc, &diff),
);
println!(
"{}) Finished computing Squared Euclidean distance in {:?} sec.",
i,
start.elapsed().as_secs_f64()
);
// println!("euclid_squared_enc degree: {:?}", euclid_squared_enc.blocks()[0].degree);
println!("{}) Starting computing square root", i);
let distance_enc = ct_lut_eval_bior_no_gen(
euclid_squared_enc,
bit_width.into(),
nb_blocks.into(),
j,
&wopbs_key,
0_i32,
&server_key,
&encoded_sqrt_luts,
);
println!(
"{}) Finished computing square root in {:?} sec.",
i,
start.elapsed().as_secs_f64()
);

distance_enc
})
.collect::<Vec<_>>()
.into_iter()
.fold(
server_key.create_trivial_radix(0_u64, nb_blocks.into()),
|acc: RadixCiphertext, diff| server_key.add_parallelized(&acc, &diff),
);

// println!("sum_dists degree: {:?}", sum_dists.blocks()[0].degree);
let dists_mean_enc = ct_lut_eval_bior_no_gen(
sum_dists,
bit_width.into(),
nb_blocks.into(),
j,
&wopbs_key,
0_i32,
&server_key,
&encoded_div_luts,
);
println!(
"Finished computing everything in {:?} sec.",
bench_start.elapsed().as_secs_f64()
);

// ------- Client side ------- //
let mean_distance: u64 = client_key.decrypt(&dists_mean_enc);
let mean_distance: f64 = unquantize(mean_distance, precision, bit_width as u8);

println!(
"Mean of {} Euclidean distances: {}",
num_iter, mean_distance
);
}
Loading

0 comments on commit 5cf803d

Please sign in to comment.