diff --git a/src/common.rs b/src/common.rs
index aa67224..747ed9b 100644
--- a/src/common.rs
+++ b/src/common.rs
@@ -271,6 +271,18 @@ pub fn ct_lut_eval(
     (lut_ct, start.elapsed().as_secs_f64())
 }
 
+pub fn ct_lut_eval_no_gen(
+    ct: RadixCiphertext,
+    wopbs_key: &WopbsKey,
+    server_key: &ServerKey,
+    func_lut: &IntegerWopbsLUT,
+) -> RadixCiphertext {
+    let ct_ks = wopbs_key.keyswitch_to_wopbs_params(server_key, &ct);
+    let mut lut_ct = wopbs_key.wopbs(&ct_ks, &func_lut);
+    lut_ct = wopbs_key.keyswitch_to_pbs_params(&lut_ct);
+    lut_ct
+}
+
 pub fn ct_lut_eval_quantized(
     ct: RadixCiphertext,
     precision: u8,
@@ -297,6 +309,20 @@ pub fn ct_lut_eval_quantized(
     )
 }
 
+pub fn ct_lut_eval_quantized_no_gen(
+    ct: RadixCiphertext,
+    nb_blocks: usize,
+    wopbs_key: &WopbsKey,
+    server_key: &ServerKey,
+    quantized_lut: &IntegerWopbsLUT,
+) -> RadixCiphertext {
+    let quant_blocks = &ct.into_blocks()[(nb_blocks >> 1)..nb_blocks];
+    let quantized_ct = RadixCiphertext::from_blocks(quant_blocks.to_vec());
+    let quantized_ct = wopbs_key.keyswitch_to_wopbs_params(server_key, &quantized_ct);
+    let quantized_ct = wopbs_key.wopbs(&quantized_ct, &quantized_lut);
+    wopbs_key.keyswitch_to_pbs_params(&quantized_ct)
+}
+
 pub fn ct_lut_eval_haar(
     ct: RadixCiphertext,
     precision: u8,
@@ -346,6 +372,35 @@ pub fn ct_lut_eval_haar(
     (haar_ct, start.elapsed().as_secs_f64())
 }
 
+pub fn ct_lut_eval_haar_no_gen(
+    ct: RadixCiphertext,
+    nb_blocks: usize,
+    wopbs_key: &WopbsKey,
+    server_key: &ServerKey,
+    haar_lsb_lut: &IntegerWopbsLUT,
+    haar_msb_lut: &IntegerWopbsLUT,
+) -> RadixCiphertext {
+    // Truncate x
+    let x_truncated_blocks = &ct.into_blocks()[(nb_blocks >> 1)..nb_blocks];
+    let x_truncated = RadixCiphertext::from_blocks(x_truncated_blocks.to_vec());
+    let x_truncated_ks = wopbs_key.keyswitch_to_wopbs_params(server_key, &x_truncated);
+
+    let (haar_lsb, haar_msb) = rayon::join(
+        || {
+            let haar_lsb = wopbs_key.wopbs(&x_truncated_ks, &haar_lsb_lut);
+            wopbs_key.keyswitch_to_pbs_params(&haar_lsb)
+        },
+        || {
+            let haar_msb = wopbs_key.wopbs(&x_truncated_ks, &haar_msb_lut);
+            wopbs_key.keyswitch_to_pbs_params(&haar_msb)
+        },
+    );
+    let mut lsb_blocks = haar_lsb.into_blocks();
+    lsb_blocks.extend(haar_msb.into_blocks());
+    let haar_ct = RadixCiphertext::from_blocks(lsb_blocks.to_vec());
+    haar_ct
+}
+
 pub fn ct_lut_eval_haar_bounded(
     ct: RadixCiphertext,
     precision: u8,
@@ -428,6 +483,68 @@ pub fn ct_lut_eval_haar_bounded(
     (haar_ct, start.elapsed().as_secs_f64())
 }
 
+pub fn ct_lut_eval_haar_bounded_no_gen(
+    ct: RadixCiphertext,
+    precision: u8,
+    bit_width: usize,
+    integer_size: u32,
+    nb_blocks: usize,
+    wopbs_key: &WopbsKey,
+    server_key: &ServerKey,
+    is_symmetric: bool,
+    haar_lsb_lut: &IntegerWopbsLUT,
+    haar_msb_lut: &IntegerWopbsLUT,
+) -> RadixCiphertext {
+    let ltz = server_key.scalar_right_shift_parallelized(&ct, bit_width - 1);
+    let sign = server_key.sub_parallelized(
+        &server_key.create_trivial_radix(1, nb_blocks),
+        &server_key.scalar_left_shift_parallelized(&ltz, 1),
+    );
+    let abs = server_key.mul_parallelized(&sign, &ct);
+
+    // Truncate x
+    let tmp = (precision as usize) + (integer_size as usize);
+    let x_truncated_blocks = &abs.clone().into_blocks()[(tmp - (bit_width >> 1)) >> 1..tmp >> 1];
+    let x_truncated = RadixCiphertext::from_blocks(x_truncated_blocks.to_vec());
+    let x_truncated_ks = wopbs_key.keyswitch_to_wopbs_params(server_key, &x_truncated);
+
+    let (haar_lsb, haar_msb) = rayon::join(
+        || {
+            let haar_lsb = wopbs_key.wopbs(&x_truncated_ks, &haar_lsb_lut);
+            wopbs_key.keyswitch_to_pbs_params(&haar_lsb)
+        },
+        || {
+            let haar_msb = wopbs_key.wopbs(&x_truncated_ks, &haar_msb_lut);
+            wopbs_key.keyswitch_to_pbs_params(&haar_msb)
+        },
+    );
+    let mut lsb_blocks = haar_lsb.into_blocks();
+    lsb_blocks.extend(haar_msb.into_blocks());
+    let mut haar_ct = RadixCiphertext::from_blocks(lsb_blocks.to_vec());
+
+    // For non-symmetric (around zero) functions like Sigmoid.
+    if !is_symmetric {
+        // ltz = (msb == 1)
+        let precision_encoded =
+            server_key.create_trivial_radix(2_u64.pow(precision as u32), nb_blocks);
+        let ltz = server_key.mul_parallelized(&precision_encoded, &ltz);
+
+        // eval = sign * eval + ltz
+        let eval = server_key.add_parallelized(&server_key.mul_parallelized(&haar_ct, &sign), &ltz);
+        let check_value = 2_u64.pow(precision as u32 + integer_size);
+        let check = server_key.scalar_lt_parallelized(&abs, check_value); // abs < 2^{integer_size + precision}
+        let check = check.into_radix(nb_blocks, server_key);
+        // limit = 1 - ltz
+        let limit = server_key.sub_parallelized(&precision_encoded, &ltz);
+        // return limit + check * (eval - limit)
+        haar_ct = server_key.add_parallelized(
+            &limit,
+            &server_key.mul_parallelized(&check, &server_key.sub_parallelized(&eval, &limit)),
+        );
+    }
+    haar_ct
+}
+
 pub fn ct_lut_eval_bior(
     ct: RadixCiphertext,
     bit_width: usize,
@@ -535,3 +652,97 @@ pub fn ct_lut_eval_bior(
     let probability = server_key.add_parallelized(&output_1, &output_2);
     (probability, start.elapsed().as_secs_f64())
 }
+
+pub fn ct_lut_eval_bior_no_gen(
+    ct: RadixCiphertext,
+    bit_width: usize,
+    nb_blocks: usize,
+    wave_depth: usize,
+    wopbs_key: &WopbsKey,
+    offset: i32,
+    server_key: &ServerKey,
+    encoded_luts: &Vec<IntegerWopbsLUT>,
+) -> RadixCiphertext {
+    let nb_blocks_lsb = (bit_width - wave_depth) >> 1;
+    // Split into wave_depth MSBs and n - wave_depth LSBs
+    let ct_blocks = &ct.into_blocks();
+    let (lsb, msb) = rayon::join(
+        || {
+            let prediction_blocks_lsb = &ct_blocks[0..nb_blocks_lsb];
+            RadixCiphertext::from_blocks(prediction_blocks_lsb.to_vec())
+        },
+        || {
+            let prediction_blocks_msb = &ct_blocks[nb_blocks_lsb..nb_blocks];
+            let prediction_msb = RadixCiphertext::from_blocks(prediction_blocks_msb.to_vec());
+            wopbs_key.keyswitch_to_wopbs_params(server_key, &prediction_msb)
+        },
+    );
+    let (output_1, output_2) = rayon::join(
+        || {
+            // Eval LUT over MSBs
+            let lut_lsb = wopbs_key.wopbs(&msb, &encoded_luts[0]);
+            let mut lut_lsb_blocks = wopbs_key.keyswitch_to_pbs_params(&lut_lsb).into_blocks();
+            // Eval additional LUT if output bit-width is greater than
+            // wave_depth bits
+            if encoded_luts.len() > 2 {
+                let lut_msb = wopbs_key.wopbs(&msb, &encoded_luts[2]);
+                let lut_msb_blocks = wopbs_key.keyswitch_to_pbs_params(&lut_msb).into_blocks();
+                lut_lsb_blocks.extend(lut_msb_blocks);
+            }
+            // Pad LUT output and LSB by 6 bits to avoid overflows
+            let padding_ct_block = server_key
+                .create_trivial_zero_radix::<RadixCiphertext>(3)
+                .into_blocks();
+            lut_lsb_blocks.extend(padding_ct_block.clone());
+            let mut lsb_blocks = lsb.clone().into_blocks();
+            lsb_blocks.extend(padding_ct_block);
+            let mut lut_combined = RadixCiphertext::from_blocks(lut_lsb_blocks);
+            let lsb_extended = RadixCiphertext::from_blocks(lsb_blocks);
+
+            // subtract offset (if necessary)
+            if (offset.abs() as u64) > 0 {
+                lut_combined =
+                    server_key.scalar_sub_parallelized(&lut_combined, offset.abs() as u64);
+            }
+
+            // l1 = 2^J - lsb
+            let scalar_l1: RadixCiphertext =
+                server_key.create_trivial_radix(2u64.pow(wave_depth as u32), nb_blocks_lsb + 1);
+            let scalar_l1 = server_key.sub_parallelized(&scalar_l1, &lsb_extended);
+
+            // Multiply l1 by LUT output
+            server_key.mul_parallelized(&lut_combined, &scalar_l1)
+        },
+        || {
+            // Eval LUT over MSBs
+            let lut_lsb = wopbs_key.wopbs(&msb, &encoded_luts[1]);
+            let mut lut_lsb_blocks = wopbs_key.keyswitch_to_pbs_params(&lut_lsb).into_blocks();
+            // Eval additional LUT if output bit-width is greater than
+            // wave_depth bits
+            if encoded_luts.len() > 2 {
+                let lut_msb = wopbs_key.wopbs(&msb, &encoded_luts[3]);
+                let lut_msb_blocks = wopbs_key.keyswitch_to_pbs_params(&lut_msb).into_blocks();
+                lut_lsb_blocks.extend(lut_msb_blocks);
+            }
+            // Pad LUT output and LSB by 6 bits to avoid overflows
+            let padding_ct_block = server_key
+                .create_trivial_zero_radix::<RadixCiphertext>(3)
+                .into_blocks();
+            lut_lsb_blocks.extend(padding_ct_block.clone());
+            let mut lsb_blocks = lsb.clone().into_blocks();
+            lsb_blocks.extend(padding_ct_block);
+            let mut lut_combined = RadixCiphertext::from_blocks(lut_lsb_blocks);
+            let lsb_extended = RadixCiphertext::from_blocks(lsb_blocks);
+
+            // subtract offset (if necessary)
+            if (offset.abs() as u64) > 0 {
+                lut_combined =
+                    server_key.scalar_sub_parallelized(&lut_combined, offset.abs() as u64);
+            }
+            // l2 = lsb
+            // Multiply MSBs and LSBs
+            server_key.mul_parallelized(&lut_combined, &lsb_extended)
+        },
+    );
+    server_key.add_parallelized(&output_1, &output_2)
+}
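
Usage sketch (illustration only, not part of the patch): the *_no_gen variants assume the IntegerWopbsLUT has already been generated once, so the LUT-generation cost is paid a single time and amortized over every subsequent ciphertext. A minimal caller might look like the sketch below; the module path crate::common, the function name eval_many, and the pre-built square_lut parameter are assumptions for this illustration, not code from the repository.

    use tfhe::integer::wopbs::{IntegerWopbsLUT, WopbsKey};
    use tfhe::integer::{RadixCiphertext, ServerKey};

    use crate::common::ct_lut_eval_no_gen;

    /// Applies one pre-generated LUT to a batch of ciphertexts, reusing the
    /// same IntegerWopbsLUT for every input instead of regenerating it.
    fn eval_many(
        cts: Vec<RadixCiphertext>,
        square_lut: &IntegerWopbsLUT, // built once, e.g. by the existing ct_lut_eval path
        wopbs_key: &WopbsKey,
        server_key: &ServerKey,
    ) -> Vec<RadixCiphertext> {
        cts.into_iter()
            .map(|ct| ct_lut_eval_no_gen(ct, wopbs_key, server_key, square_lut))
            .collect()
    }

The same pattern applies to the quantized, Haar, and biorthogonal variants in this patch, which additionally take the pre-built LUT (or LUT pair/vector) together with the block-count and bit-width parameters that were used when those LUTs were encoded.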