Skip to content

Commit

Permalink
Add proofs for portable compress module
Browse files Browse the repository at this point in the history
  • Loading branch information
mamonet committed Oct 16, 2024
1 parent 5ef4ee0 commit 1d5b291
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,43 @@ let compress_ciphertext_coefficient (coefficient_bits: u8) (fe: u16) =

let compress_message_coefficient (fe: u16) =
let (shifted: i16):i16 = Rust_primitives.mk_i16 1664 -! (cast (fe <: u16) <: i16) in
let _:Prims.unit = assert (v shifted == 1664 - v fe) in
let mask:i16 = shifted >>! Rust_primitives.mk_i32 15 in
let _:Prims.unit =
assert (v mask = v shifted / pow2 15);
assert (if v shifted < 0 then mask = ones else mask = zero)
in
let shifted_to_positive:i16 = mask ^. shifted in
let _:Prims.unit =
logxor_lemma shifted mask;
assert (v shifted < 0 ==> v shifted_to_positive = v (lognot shifted));
neg_equiv_lemma shifted;
assert (v (lognot shifted) = - (v shifted) - 1);
assert (v shifted >= 0 ==> v shifted_to_positive = v (mask `logxor` shifted));
assert (v shifted >= 0 ==> mask = zero);
assert (v shifted >= 0 ==> mask ^. shifted = shifted);
assert (v shifted >= 0 ==> v shifted_to_positive = v shifted);
assert (shifted_to_positive >=. mk_i16 0)
in
let shifted_positive_in_range:i16 = shifted_to_positive -! Rust_primitives.mk_i16 832 in
cast ((shifted_positive_in_range >>! Rust_primitives.mk_i32 15 <: i16) &. Rust_primitives.mk_i16 1
<:
i16)
<:
u8
let _:Prims.unit =
assert (1664 - v fe >= 0 ==> v shifted_positive_in_range == 832 - v fe);
assert (1664 - v fe < 0 ==> v shifted_positive_in_range == - 2497 + v fe)
in
let r0:i16 = shifted_positive_in_range >>! Rust_primitives.mk_i32 15 in
let (r1: i16):i16 = r0 &. Rust_primitives.mk_i16 1 in
let res:u8 = cast (r1 <: i16) <: u8 in
let _:Prims.unit =
assert (v r0 = v shifted_positive_in_range / pow2 15);
assert (if v shifted_positive_in_range < 0 then r0 = ones else r0 = zero);
logand_lemma (mk_i16 1) r0;
assert (if v shifted_positive_in_range < 0 then r1 = mk_i16 1 else r1 = mk_i16 0);
assert ((v fe >= 833 && v fe <= 2496) ==> r1 = mk_i16 1);
assert (v fe < 833 ==> r1 = mk_i16 0);
assert (v fe > 2496 ==> r1 = mk_i16 0);
assert (v res = v r1)
in
res

#push-options "--fuel 0 --ifuel 0 --z3rlimit 2000"

Expand Down Expand Up @@ -167,45 +196,88 @@ let compress_1_ (a: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector)

#pop-options

#push-options "--z3rlimit 300 --ext context_pruning"

let decompress_ciphertext_coefficient
(v_COEFFICIENT_BITS: i32)
(v: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector)
(a: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector)
=
let v:Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector =
let _:Prims.unit =
assert_norm (pow2 1 == 2);
assert_norm (pow2 4 == 16);
assert_norm (pow2 5 == 32);
assert_norm (pow2 10 == 1024);
assert_norm (pow2 11 == 2048)
in
let a:Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector =
Rust_primitives.Hax.Folds.fold_range (Rust_primitives.mk_usize 0)
Libcrux_ml_kem.Vector.Traits.v_FIELD_ELEMENTS_IN_VECTOR
(fun v temp_1_ ->
let v:Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector = v in
let _:usize = temp_1_ in
true)
v
(fun v i ->
let v:Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector = v in
(fun a i ->
let a:Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector = a in
let i:usize = i in
(v i < 16 ==>
(forall (j: nat).
(j >= v i /\ j < 16) ==>
v (Seq.index a.f_elements j) >= 0 /\
v (Seq.index a.f_elements j) < pow2 (v v_COEFFICIENT_BITS))) /\
(forall (j: nat).
j < v i ==>
v (Seq.index a.f_elements j) < v Libcrux_ml_kem.Vector.Traits.v_FIELD_MODULUS))
a
(fun a i ->
let a:Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector = a in
let i:usize = i in
let _:Prims.unit =
assert (v (a.f_elements.[ i ] <: i16) < pow2 11);
assert (v (a.f_elements.[ i ] <: i16) == v (cast (a.f_elements.[ i ] <: i16) <: i32));
assert (v (Libcrux_ml_kem.Vector.Traits.v_FIELD_MODULUS <: i16) ==
v (cast (Libcrux_ml_kem.Vector.Traits.v_FIELD_MODULUS <: i16) <: i32));
assert (v ((cast (a.f_elements.[ i ] <: i16) <: i32) *!
(cast (Libcrux_ml_kem.Vector.Traits.v_FIELD_MODULUS <: i16) <: i32)) ==
v (cast (a.f_elements.[ i ] <: i16) <: i32) *
v (cast (Libcrux_ml_kem.Vector.Traits.v_FIELD_MODULUS <: i16) <: i32))
in
let decompressed:i32 =
(cast (v.Libcrux_ml_kem.Vector.Portable.Vector_type.f_elements.[ i ] <: i16) <: i32) *!
(cast (a.Libcrux_ml_kem.Vector.Portable.Vector_type.f_elements.[ i ] <: i16) <: i32) *!
(cast (Libcrux_ml_kem.Vector.Traits.v_FIELD_MODULUS <: i16) <: i32)
in
let _:Prims.unit =
assert (v (decompressed <<! mk_i32 1) == v decompressed * 2);
assert (v (mk_i32 1 <<! v_COEFFICIENT_BITS) == pow2 (v v_COEFFICIENT_BITS));
assert (v ((decompressed <<! mk_i32 1) +! (mk_i32 1 <<! v_COEFFICIENT_BITS)) ==
v (decompressed <<! mk_i32 1) + v (mk_i32 1 <<! v_COEFFICIENT_BITS))
in
let decompressed:i32 =
(decompressed <<! Rust_primitives.mk_i32 1 <: i32) +!
(Rust_primitives.mk_i32 1 <<! v_COEFFICIENT_BITS <: i32)
in
let _:Prims.unit =
assert (v (v_COEFFICIENT_BITS +! mk_i32 1) == v v_COEFFICIENT_BITS + 1);
assert (v (decompressed >>! (v_COEFFICIENT_BITS +! mk_i32 1 <: i32)) ==
v decompressed / pow2 (v v_COEFFICIENT_BITS + 1))
in
let decompressed:i32 =
decompressed >>! (v_COEFFICIENT_BITS +! Rust_primitives.mk_i32 1 <: i32)
in
let v:Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector =
let _:Prims.unit =
assert (v decompressed < v Libcrux_ml_kem.Vector.Traits.v_FIELD_MODULUS);
assert (v (cast decompressed <: i16) < v Libcrux_ml_kem.Vector.Traits.v_FIELD_MODULUS)
in
let a:Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector =
{
v with
a with
Libcrux_ml_kem.Vector.Portable.Vector_type.f_elements
=
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize v
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize a
.Libcrux_ml_kem.Vector.Portable.Vector_type.f_elements
i
(cast (decompressed <: i32) <: i16)
}
<:
Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector
in
v)
a)
in
v
a

#pop-options
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ val compress_message_coefficient (fe: u16)
fun result ->
let result:u8 = result in
Hax_lib.implies ((Rust_primitives.mk_u16 833 <=. fe <: bool) &&
(fe <=. Rust_primitives.mk_u16 2596 <: bool))
(fe <=. Rust_primitives.mk_u16 2496 <: bool))
(fun temp_0_ ->
let _:Prims.unit = temp_0_ in
result =. Rust_primitives.mk_u8 1 <: bool) &&
Hax_lib.implies (~.((Rust_primitives.mk_u16 833 <=. fe <: bool) &&
(fe <=. Rust_primitives.mk_u16 2596 <: bool))
(fe <=. Rust_primitives.mk_u16 2496 <: bool))
<:
bool)
(fun temp_0_ ->
Expand Down Expand Up @@ -84,7 +84,18 @@ val compress_1_ (a: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector)

val decompress_ciphertext_coefficient
(v_COEFFICIENT_BITS: i32)
(v: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector)
(a: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector)
: Prims.Pure Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector
Prims.l_True
(fun _ -> Prims.l_True)
(requires
(v v_COEFFICIENT_BITS == 4 \/ v v_COEFFICIENT_BITS == 5 \/ v v_COEFFICIENT_BITS == 10 \/
v v_COEFFICIENT_BITS == 11) /\
(forall (i: nat).
i < 16 ==>
v (Seq.index a.f_elements i) >= 0 /\
v (Seq.index a.f_elements i) < pow2 (v v_COEFFICIENT_BITS)))
(ensures
fun result ->
let result:Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector = result in
forall (i: nat).
i < 16 ==>
v (Seq.index result.f_elements i) < Libcrux_ml_kem.Vector.Traits.v_FIELD_MODULUS)
81 changes: 67 additions & 14 deletions libcrux-ml-kem/src/vector/portable/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ use crate::vector::FIELD_MODULUS;
/// <https://csrc.nist.gov/pubs/fips/203/ipd>.
#[cfg_attr(hax, hax_lib::requires(fe < (FIELD_MODULUS as u16)))]
#[cfg_attr(hax, hax_lib::ensures(|result|
hax_lib::implies(833 <= fe && fe <= 2596, || result == 1) &&
hax_lib::implies(!(833 <= fe && fe <= 2596), || result == 0)
hax_lib::implies(833 <= fe && fe <= 2496, || result == 1) &&
hax_lib::implies(!(833 <= fe && fe <= 2496), || result == 0)
))]
pub(crate) fn compress_message_coefficient(fe: u16) -> u8 {
// The approach used here is inspired by:
Expand All @@ -35,6 +35,7 @@ pub(crate) fn compress_message_coefficient(fe: u16) -> u8 {
// If 833 <= fe <= 2496,
// then -832 <= shifted <= 831
let shifted: i16 = 1664 - (fe as i16);
hax_lib::fstar!("assert (v $shifted == 1664 - v $fe)");

// If shifted < 0, then
// (shifted >> 15) ^ shifted = flip_bits(shifted) = -shifted - 1, and so
Expand All @@ -44,13 +45,37 @@ pub(crate) fn compress_message_coefficient(fe: u16) -> u8 {
// (shifted >> 15) ^ shifted = shifted, and so
// if 0 <= shifted <= 831 then 0 <= shifted_positive <= 831
let mask = shifted >> 15;
hax_lib::fstar!("assert (v $mask = v $shifted / pow2 15);
assert (if v $shifted < 0 then $mask = ones else $mask = zero)");
let shifted_to_positive = mask ^ shifted;
hax_lib::fstar!("logxor_lemma $shifted $mask;
assert (v $shifted < 0 ==> v $shifted_to_positive = v (lognot $shifted));
neg_equiv_lemma $shifted;
assert (v (lognot $shifted) = -(v $shifted) -1);
assert (v $shifted >= 0 ==> v $shifted_to_positive = v ($mask `logxor` $shifted));
assert (v $shifted >= 0 ==> $mask = zero);
assert (v $shifted >= 0 ==> $mask ^. $shifted = $shifted);
assert (v $shifted >= 0 ==> v $shifted_to_positive = v $shifted);
assert ($shifted_to_positive >=. mk_i16 0)");

let shifted_positive_in_range = shifted_to_positive - 832;
hax_lib::fstar!("assert (1664 - v $fe >= 0 ==> v $shifted_positive_in_range == 832 - v $fe);
assert (1664 - v $fe < 0 ==> v $shifted_positive_in_range == -2497 + v $fe)");

// If x <= 831, then x - 832 <= -1, and so x - 832 < 0, which means
// the most significant bit of shifted_positive_in_range will be 1.
((shifted_positive_in_range >> 15) & 1) as u8
let r0 = shifted_positive_in_range >> 15;
let r1: i16 = r0 & 1;
let res = r1 as u8;
hax_lib::fstar!("assert (v $r0 = v $shifted_positive_in_range / pow2 15);
assert (if v $shifted_positive_in_range < 0 then $r0 = ones else $r0 = zero);
logand_lemma (mk_i16 1) $r0;
assert (if v $shifted_positive_in_range < 0 then $r1 = mk_i16 1 else $r1 = mk_i16 0);
assert ((v $fe >= 833 && v $fe <= 2496) ==> $r1 = mk_i16 1);
assert (v $fe < 833 ==> $r1 = mk_i16 0);
assert (v $fe > 2496 ==> $r1 = mk_i16 0);
assert (v $res = v $r1)");
res
}

#[cfg_attr(hax,
Expand Down Expand Up @@ -147,23 +172,51 @@ pub(crate) fn compress<const COEFFICIENT_BITS: i32>(mut a: PortableVector) -> Po
}

#[inline(always)]
#[hax_lib::fstar::options("--z3rlimit 300 --ext context_pruning")]
#[hax_lib::requires(fstar!("(v $COEFFICIENT_BITS == 4 \\/
v $COEFFICIENT_BITS == 5 \\/
v $COEFFICIENT_BITS == 10 \\/
v $COEFFICIENT_BITS == 11) /\\
(forall (i:nat). i < 16 ==> v (Seq.index ${a}.f_elements i) >= 0 /\\
v (Seq.index ${a}.f_elements i) < pow2 (v $COEFFICIENT_BITS))"))]
#[hax_lib::ensures(|result| fstar!("forall (i:nat). i < 16 ==> v (Seq.index ${result}.f_elements i) < $FIELD_MODULUS"))]
pub(crate) fn decompress_ciphertext_coefficient<const COEFFICIENT_BITS: i32>(
mut v: PortableVector,
mut a: PortableVector,
) -> PortableVector {
// debug_assert!(to_i16_array(v)
// .into_iter()
// .all(|coefficient| coefficient.abs() < 1 << COEFFICIENT_BITS));
hax_lib::fstar!("assert_norm (pow2 1 == 2);
assert_norm (pow2 4 == 16);
assert_norm (pow2 5 == 32);
assert_norm (pow2 10 == 1024);
assert_norm (pow2 11 == 2048)");

for i in 0..FIELD_ELEMENTS_IN_VECTOR {
let mut decompressed = v.elements[i] as i32 * FIELD_MODULUS as i32;
hax_lib::loop_invariant!(|i: usize| { fstar!("(v $i < 16 ==> (forall (j:nat). (j >= v $i /\\ j < 16) ==>
v (Seq.index ${a}.f_elements j) >= 0 /\\ v (Seq.index ${a}.f_elements j) < pow2 (v $COEFFICIENT_BITS))) /\\
(forall (j:nat). j < v $i ==>
v (Seq.index ${a}.f_elements j) < v $FIELD_MODULUS)") });
hax_lib::fstar!("assert (v (${a}.f_elements.[ $i ] <: i16) < pow2 11);
assert (v (${a}.f_elements.[ $i ] <: i16) ==
v (cast (${a}.f_elements.[ $i ] <: i16) <: i32));
assert (v ($FIELD_MODULUS <: i16) ==
v (cast ($FIELD_MODULUS <: i16) <: i32));
assert (v ((cast (${a}.f_elements.[ $i ] <: i16) <: i32) *!
(cast ($FIELD_MODULUS <: i16) <: i32)) ==
v (cast (${a}.f_elements.[ $i ] <: i16) <: i32) *
v (cast ($FIELD_MODULUS <: i16) <: i32))");
let mut decompressed = a.elements[i] as i32 * FIELD_MODULUS as i32;
hax_lib::fstar!("assert (v ($decompressed <<! mk_i32 1) == v $decompressed * 2);
assert (v (mk_i32 1 <<! $COEFFICIENT_BITS) == pow2 (v $COEFFICIENT_BITS));
assert (v (($decompressed <<! mk_i32 1) +! (mk_i32 1 <<! $COEFFICIENT_BITS)) ==
v ($decompressed <<! mk_i32 1) + v (mk_i32 1 <<! $COEFFICIENT_BITS))");
decompressed = (decompressed << 1) + (1i32 << COEFFICIENT_BITS);
hax_lib::fstar!("assert (v ($COEFFICIENT_BITS +! mk_i32 1) == v $COEFFICIENT_BITS + 1);
assert (v ($decompressed >>! ($COEFFICIENT_BITS +! mk_i32 1 <: i32)) ==
v $decompressed / pow2 (v $COEFFICIENT_BITS + 1))");
decompressed = decompressed >> (COEFFICIENT_BITS + 1);
v.elements[i] = decompressed as i16;
hax_lib::fstar!("assert (v $decompressed < v $FIELD_MODULUS);
assert (v (cast $decompressed <: i16) < v $FIELD_MODULUS)");
a.elements[i] = decompressed as i16;
}

// debug_assert!(to_i16_array(v)
// .into_iter()
// .all(|coefficient| coefficient.abs() as u16 <= 1 << 12));

v
a
}

0 comments on commit 1d5b291

Please sign in to comment.