Skip to content

Commit

Permalink
Add LUT variants without encrypted LUT generation
Browse files Browse the repository at this point in the history
  • Loading branch information
cgouert committed May 25, 2024
1 parent 4ef548b commit 73154c5
Showing 1 changed file with 211 additions and 0 deletions.
211 changes: 211 additions & 0 deletions src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,18 @@ pub fn ct_lut_eval(
(lut_ct, start.elapsed().as_secs_f64())
}

pub fn ct_lut_eval_no_gen(
ct: RadixCiphertext,
wopbs_key: &WopbsKey,
server_key: &ServerKey,
func_lut: &IntegerWopbsLUT,
) -> RadixCiphertext {
let ct_ks = wopbs_key.keyswitch_to_wopbs_params(server_key, &ct);
let mut lut_ct = wopbs_key.wopbs(&ct_ks, &func_lut);
lut_ct = wopbs_key.keyswitch_to_pbs_params(&lut_ct);
lut_ct
}

pub fn ct_lut_eval_quantized(
ct: RadixCiphertext,
precision: u8,
Expand All @@ -297,6 +309,20 @@ pub fn ct_lut_eval_quantized(
)
}

pub fn ct_lut_eval_quantized_no_gen(
ct: RadixCiphertext,
nb_blocks: usize,
wopbs_key: &WopbsKey,
server_key: &ServerKey,
quantized_lut: &IntegerWopbsLUT,
) -> RadixCiphertext {
let quant_blocks = &ct.into_blocks()[(nb_blocks >> 1)..nb_blocks];
let quantized_ct = RadixCiphertext::from_blocks(quant_blocks.to_vec());
let quantized_ct = wopbs_key.keyswitch_to_wopbs_params(server_key, &quantized_ct);
let quantized_ct = wopbs_key.wopbs(&quantized_ct, &quantized_lut);
wopbs_key.keyswitch_to_pbs_params(&quantized_ct)
}

pub fn ct_lut_eval_haar(
ct: RadixCiphertext,
precision: u8,
Expand Down Expand Up @@ -346,6 +372,35 @@ pub fn ct_lut_eval_haar(
(haar_ct, start.elapsed().as_secs_f64())
}

pub fn ct_lut_eval_haar_no_gen(
ct: RadixCiphertext,
nb_blocks: usize,
wopbs_key: &WopbsKey,
server_key: &ServerKey,
haar_lsb_lut: &IntegerWopbsLUT,
haar_msb_lut: &IntegerWopbsLUT,
) -> RadixCiphertext {
// Truncate x
let x_truncated_blocks = &ct.into_blocks()[(nb_blocks >> 1)..nb_blocks];
let x_truncated = RadixCiphertext::from_blocks(x_truncated_blocks.to_vec());
let x_truncated_ks = wopbs_key.keyswitch_to_wopbs_params(server_key, &x_truncated);

let (haar_lsb, haar_msb) = rayon::join(
|| {
let haar_lsb = wopbs_key.wopbs(&x_truncated_ks, &haar_lsb_lut);
wopbs_key.keyswitch_to_pbs_params(&haar_lsb)
},
|| {
let haar_msb = wopbs_key.wopbs(&x_truncated_ks, &haar_msb_lut);
wopbs_key.keyswitch_to_pbs_params(&haar_msb)
},
);
let mut lsb_blocks = haar_lsb.into_blocks();
lsb_blocks.extend(haar_msb.into_blocks());
let haar_ct = RadixCiphertext::from_blocks(lsb_blocks.to_vec());
haar_ct
}

pub fn ct_lut_eval_haar_bounded(
ct: RadixCiphertext,
precision: u8,
Expand Down Expand Up @@ -428,6 +483,68 @@ pub fn ct_lut_eval_haar_bounded(
(haar_ct, start.elapsed().as_secs_f64())
}

pub fn ct_lut_eval_haar_bounded_no_gen(
ct: RadixCiphertext,
precision: u8,
bit_width: usize,
integer_size: u32,
nb_blocks: usize,
wopbs_key: &WopbsKey,
server_key: &ServerKey,
is_symmetric: bool,
haar_lsb_lut: &IntegerWopbsLUT,
haar_msb_lut: &IntegerWopbsLUT,
) -> RadixCiphertext {
let ltz = server_key.scalar_right_shift_parallelized(&ct, bit_width - 1);
let sign = server_key.sub_parallelized(
&server_key.create_trivial_radix(1, nb_blocks),
&server_key.scalar_left_shift_parallelized(&ltz, 1),
);
let abs = server_key.mul_parallelized(&sign, &ct);

// Truncate x
let tmp = (precision as usize) + (integer_size as usize);
let x_truncated_blocks = &abs.clone().into_blocks()[(tmp - (bit_width >> 1)) >> 1..tmp >> 1];
let x_truncated = RadixCiphertext::from_blocks(x_truncated_blocks.to_vec());
let x_truncated_ks = wopbs_key.keyswitch_to_wopbs_params(server_key, &x_truncated);

let (haar_lsb, haar_msb) = rayon::join(
|| {
let haar_lsb = wopbs_key.wopbs(&x_truncated_ks, &haar_lsb_lut);
wopbs_key.keyswitch_to_pbs_params(&haar_lsb)
},
|| {
let haar_msb = wopbs_key.wopbs(&x_truncated_ks, &haar_msb_lut);
wopbs_key.keyswitch_to_pbs_params(&haar_msb)
},
);
let mut lsb_blocks = haar_lsb.into_blocks();
lsb_blocks.extend(haar_msb.into_blocks());
let mut haar_ct = RadixCiphertext::from_blocks(lsb_blocks.to_vec());

// For non-symmetric (around zero) functions like Sigmoid.
if !is_symmetric {
// ltz = (msb == 1)
let precision_encoded =
server_key.create_trivial_radix(2_u64.pow(precision as u32), nb_blocks);
let ltz = server_key.mul_parallelized(&precision_encoded, &ltz);

// eval = sign * eval + ltz
let eval = server_key.add_parallelized(&server_key.mul_parallelized(&haar_ct, &sign), &ltz);
let check_value = 2_u64.pow(precision as u32 + integer_size);
let check = server_key.scalar_lt_parallelized(&abs, check_value); // abs < 2^{integer_size + precision}
let check = check.into_radix(nb_blocks, server_key);
// limit = 1 - ltz
let limit = server_key.sub_parallelized(&precision_encoded, &ltz);
// return limit + check * (eval - limit)
haar_ct = server_key.add_parallelized(
&limit,
&server_key.mul_parallelized(&check, &server_key.sub_parallelized(&eval, &limit)),
);
}
haar_ct
}

pub fn ct_lut_eval_bior(
ct: RadixCiphertext,
bit_width: usize,
Expand Down Expand Up @@ -535,3 +652,97 @@ pub fn ct_lut_eval_bior(
let probability = server_key.add_parallelized(&output_1, &output_2);
(probability, start.elapsed().as_secs_f64())
}

pub fn ct_lut_eval_bior_no_gen(
ct: RadixCiphertext,
bit_width: usize,
nb_blocks: usize,
wave_depth: usize,
wopbs_key: &WopbsKey,
offset: i32,
server_key: &ServerKey,
encoded_luts: &Vec<IntegerWopbsLUT>,
) -> RadixCiphertext {
let nb_blocks_lsb = (bit_width - wave_depth) >> 1;
// Split into wave_depth MSBs and n - wave_depth LSBs
let ct_blocks = &ct.into_blocks();
let (lsb, msb) = rayon::join(
|| {
let prediction_blocks_lsb = &ct_blocks[0..nb_blocks_lsb];
RadixCiphertext::from_blocks(prediction_blocks_lsb.to_vec())
},
|| {
let prediction_blocks_msb = &ct_blocks[nb_blocks_lsb..nb_blocks];
let prediction_msb = RadixCiphertext::from_blocks(prediction_blocks_msb.to_vec());
wopbs_key.keyswitch_to_wopbs_params(server_key, &prediction_msb)
},
);
let (output_1, output_2) = rayon::join(
|| {
// Eval LUT over MSBs
let lut_lsb = wopbs_key.wopbs(&msb, &encoded_luts[0]);
let mut lut_lsb_blocks = wopbs_key.keyswitch_to_pbs_params(&lut_lsb).into_blocks();
// Eval additional LUT if output bit-width is greater than
// wave_depth bits
if encoded_luts.len() > 2 {
let lut_msb = wopbs_key.wopbs(&msb, &encoded_luts[2]);
let lut_msb_blocks = wopbs_key.keyswitch_to_pbs_params(&lut_msb).into_blocks();
lut_lsb_blocks.extend(lut_msb_blocks);
}
// Pad LUT output and LSB by 6 bits to avoid overflows
let padding_ct_block = server_key
.create_trivial_zero_radix::<RadixCiphertext>(3)
.into_blocks();
lut_lsb_blocks.extend(padding_ct_block.clone());
let mut lsb_blocks = lsb.clone().into_blocks();
lsb_blocks.extend(padding_ct_block);
let mut lut_combined = RadixCiphertext::from_blocks(lut_lsb_blocks);
let lsb_extended = RadixCiphertext::from_blocks(lsb_blocks);

// subtract offset (if necessary)
if (offset.abs() as u64) > 0 {
lut_combined =
server_key.scalar_sub_parallelized(&lut_combined, offset.abs() as u64);
}

// l1 = 2^J - lsb
let scalar_l1: RadixCiphertext =
server_key.create_trivial_radix(2u64.pow(wave_depth as u32), nb_blocks_lsb + 1);
let scalar_l1 = server_key.sub_parallelized(&scalar_l1, &lsb_extended);

// Multiply l1 by LUT output
server_key.mul_parallelized(&lut_combined, &scalar_l1)
},
|| {
// Eval LUT over MSBs
let lut_lsb = wopbs_key.wopbs(&msb, &encoded_luts[1]);
let mut lut_lsb_blocks = wopbs_key.keyswitch_to_pbs_params(&lut_lsb).into_blocks();
// Eval additional LUT if output bit-width is greater than
// wave_depth bits
if encoded_luts.len() > 2 {
let lut_msb = wopbs_key.wopbs(&msb, &encoded_luts[3]);
let lut_msb_blocks = wopbs_key.keyswitch_to_pbs_params(&lut_msb).into_blocks();
lut_lsb_blocks.extend(lut_msb_blocks);
}
// Pad LUT output and LSB by 6 bits to avoid overflows
let padding_ct_block = server_key
.create_trivial_zero_radix::<RadixCiphertext>(3)
.into_blocks();
lut_lsb_blocks.extend(padding_ct_block.clone());
let mut lsb_blocks = lsb.clone().into_blocks();
lsb_blocks.extend(padding_ct_block);
let mut lut_combined = RadixCiphertext::from_blocks(lut_lsb_blocks);
let lsb_extended = RadixCiphertext::from_blocks(lsb_blocks);

// subtract offset (if necessary)
if (offset.abs() as u64) > 0 {
lut_combined =
server_key.scalar_sub_parallelized(&lut_combined, offset.abs() as u64);
}
// l2 = lsb
// Multiply MSBs and LSBs
server_key.mul_parallelized(&lut_combined, &lsb_extended)
},
);
server_key.add_parallelized(&output_1, &output_2)
}

0 comments on commit 73154c5

Please sign in to comment.