diff --git a/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Types.Unpacked.fsti b/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Types.Unpacked.fsti deleted file mode 100644 index 1910c0b08..000000000 --- a/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Types.Unpacked.fsti +++ /dev/null @@ -1,48 +0,0 @@ -module Libcrux_ml_kem.Types.Unpacked -#set-options "--fuel 0 --ifuel 1 --z3rlimit 15" -open Core -open FStar.Mul - -let _ = - (* This module has implicit dependencies, here we make them explicit. *) - (* The implicit dependencies arise from typeclasses instances. *) - let open Libcrux_ml_kem.Vector.Traits in - () - -/// An unpacked ML-KEM IND-CPA Private Key -type t_IndCpaPrivateKeyUnpacked - (v_K: usize) (v_Vector: Type0) {| i1: Libcrux_ml_kem.Vector.Traits.t_Operations v_Vector |} - = { f_secret_as_ntt:t_Array (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector) v_K } - -/// An unpacked ML-KEM IND-CPA Private Key -type t_IndCpaPublicKeyUnpacked - (v_K: usize) (v_Vector: Type0) {| i1: Libcrux_ml_kem.Vector.Traits.t_Operations v_Vector |} - = { - f_t_as_ntt:t_Array (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector) v_K; - f_seed_for_A:t_Array u8 (sz 32); - f_A:t_Array (t_Array (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector) v_K) v_K -} - -/// An unpacked ML-KEM IND-CCA Private Key -type t_MlKemPrivateKeyUnpacked - (v_K: usize) (v_Vector: Type0) {| i1: Libcrux_ml_kem.Vector.Traits.t_Operations v_Vector |} - = { - f_ind_cpa_private_key:t_IndCpaPrivateKeyUnpacked v_K v_Vector; - f_implicit_rejection_value:t_Array u8 (sz 32) -} - -/// An unpacked ML-KEM IND-CCA Private Key -type t_MlKemPublicKeyUnpacked - (v_K: usize) (v_Vector: Type0) {| i1: Libcrux_ml_kem.Vector.Traits.t_Operations v_Vector |} - = { - f_ind_cpa_public_key:t_IndCpaPublicKeyUnpacked v_K v_Vector; - f_public_key_hash:t_Array u8 (sz 32) -} - -/// An unpacked ML-KEM KeyPair -type t_MlKemKeyPairUnpacked - (v_K: usize) (v_Vector: Type0) {| i1: Libcrux_ml_kem.Vector.Traits.t_Operations v_Vector |} - = { - f_private_key:t_MlKemPrivateKeyUnpacked v_K v_Vector; - f_public_key:t_MlKemPublicKeyUnpacked v_K v_Vector -} diff --git a/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Vector.Avx2.Portable.fsti b/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Vector.Avx2.Portable.fsti deleted file mode 100644 index fe64003c4..000000000 --- a/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Vector.Avx2.Portable.fsti +++ /dev/null @@ -1,30 +0,0 @@ -module Libcrux_ml_kem.Vector.Avx2.Portable -#set-options "--fuel 0 --ifuel 1 --z3rlimit 15" -open Core -open FStar.Mul - -val deserialize_11_int (bytes: t_Slice u8) - : Prims.Pure (i16 & i16 & i16 & i16 & i16 & i16 & i16 & i16) - Prims.l_True - (fun _ -> Prims.l_True) - -val serialize_11_int (v: t_Slice i16) - : Prims.Pure (u8 & u8 & u8 & u8 & u8 & u8 & u8 & u8 & u8 & u8 & u8) - Prims.l_True - (fun _ -> Prims.l_True) - -type t_PortableVector = { f_elements:t_Array i16 (sz 16) } - -val from_i16_array (array: t_Array i16 (sz 16)) - : Prims.Pure t_PortableVector Prims.l_True (fun _ -> Prims.l_True) - -val serialize_11_ (v: t_PortableVector) - : Prims.Pure (t_Array u8 (sz 22)) Prims.l_True (fun _ -> Prims.l_True) - -val to_i16_array (v: t_PortableVector) - : Prims.Pure (t_Array i16 (sz 16)) Prims.l_True (fun _ -> Prims.l_True) - -val zero: Prims.unit -> Prims.Pure t_PortableVector Prims.l_True (fun _ -> Prims.l_True) - -val deserialize_11_ (bytes: t_Slice u8) - : Prims.Pure t_PortableVector Prims.l_True (fun _ -> Prims.l_True) diff --git a/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Vector.Portable.Serialize.Edited.fsti b/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Vector.Portable.Serialize.Edited.fsti deleted file mode 100644 index 4ed69770d..000000000 --- a/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Vector.Portable.Serialize.Edited.fsti +++ /dev/null @@ -1,100 +0,0 @@ -module Libcrux_ml_kem.Vector.Portable.Serialize.Edited -// #set-options "--fuel 0 --ifuel 1 --z3rlimit 15" -// open Core -// open FStar.Mul - -// val deserialize_10_int (bytes: t_Slice u8) -// : Prims.Pure (i16 & i16 & i16 & i16 & i16 & i16 & i16 & i16) -// Prims.l_True -// (fun _ -> Prims.l_True) - -// val deserialize_11_int (bytes: t_Slice u8) -// : Prims.Pure (i16 & i16 & i16 & i16 & i16 & i16 & i16 & i16) -// Prims.l_True -// (fun _ -> Prims.l_True) - -// val deserialize_12_int (bytes: t_Slice u8) -// : Prims.Pure (i16 & i16) Prims.l_True (fun _ -> Prims.l_True) - -// val deserialize_4_int (bytes: t_Slice u8) -// : Prims.Pure (i16 & i16 & i16 & i16 & i16 & i16 & i16 & i16) -// Prims.l_True -// (fun _ -> Prims.l_True) - -// val deserialize_5_int (bytes: t_Slice u8) -// : Prims.Pure (i16 & i16 & i16 & i16 & i16 & i16 & i16 & i16) -// Prims.l_True -// (fun _ -> Prims.l_True) - -// val serialize_10_int (v: t_Slice i16) -// : Prims.Pure (u8 & u8 & u8 & u8 & u8) -// (requires (Core.Slice.impl__len #i16 v <: usize) =. sz 4) -// (ensures -// fun tuple -> -// let tuple:(u8 & u8 & u8 & u8 & u8) = tuple in -// BitVecEq.int_t_array_bitwise_eq' (v <: t_Array i16 (sz 4)) 10 (MkSeq.create5 tuple) 8) - -// val serialize_11_int (v: t_Slice i16) -// : Prims.Pure (u8 & u8 & u8 & u8 & u8 & u8 & u8 & u8 & u8 & u8 & u8) -// (requires Seq.length v == 8 /\ (forall i. Rust_primitives.bounded (Seq.index v i) 11)) -// (ensures -// fun tuple -> -// let tuple:(u8 & u8 & u8 & u8 & u8 & u8 & u8 & u8 & u8 & u8 & u8) = tuple in -// BitVecEq.int_t_array_bitwise_eq' (v <: t_Array i16 (sz 8)) 11 (MkSeq.create11 tuple) 8) - -// val serialize_12_int (v: t_Slice i16) -// : Prims.Pure (u8 & u8 & u8) Prims.l_True (fun _ -> Prims.l_True) - -// val serialize_4_int (v: t_Slice i16) -// : Prims.Pure (u8 & u8 & u8 & u8) Prims.l_True (fun _ -> Prims.l_True) - -// val serialize_5_int (v: t_Slice i16) -// : Prims.Pure (u8 & u8 & u8 & u8 & u8) Prims.l_True (fun _ -> Prims.l_True) - -// val serialize_1_ (v: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector) -// : Prims.Pure (t_Array u8 (sz 2)) Prims.l_True (fun _ -> Prims.l_True) - -// val serialize_10_ (v: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector) -// : Prims.Pure (t_Array u8 (sz 20)) Prims.l_True (fun _ -> Prims.l_True) - -// val serialize_11_ (v: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector) -// : Prims.Pure (t_Array u8 (sz 22)) Prims.l_True (fun _ -> Prims.l_True) - -// val serialize_12_ (v: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector) -// : Prims.Pure (t_Array u8 (sz 24)) Prims.l_True (fun _ -> Prims.l_True) - -// val serialize_4_ (v: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector) -// : Prims.Pure (t_Array u8 (sz 8)) Prims.l_True (fun _ -> Prims.l_True) - -// val serialize_5_ (v: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector) -// : Prims.Pure (t_Array u8 (sz 10)) Prims.l_True (fun _ -> Prims.l_True) - -// val deserialize_1_ (v: t_Slice u8) -// : Prims.Pure Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector -// Prims.l_True -// (fun _ -> Prims.l_True) - -// val deserialize_10_ (bytes: t_Slice u8) -// : Prims.Pure Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector -// Prims.l_True -// (fun _ -> Prims.l_True) - -// val deserialize_11_ (bytes: t_Slice u8) -// : Prims.Pure Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector -// Prims.l_True -// (fun _ -> Prims.l_True) - -// val deserialize_12_ (bytes: t_Slice u8) -// : Prims.Pure Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector -// Prims.l_True -// (fun _ -> Prims.l_True) - -// val deserialize_4_ (bytes: t_Slice u8) -// : Prims.Pure Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector -// Prims.l_True -// (fun _ -> Prims.l_True) - -// val deserialize_5_ (bytes: t_Slice u8) -// : Prims.Pure Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector -// Prims.l_True -// (fun _ -> Prims.l_True) diff --git a/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Vector.Portable.fst b/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Vector.Portable.fst deleted file mode 100644 index 0ca12f7ff..000000000 --- a/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Vector.Portable.fst +++ /dev/null @@ -1,59 +0,0 @@ -module Libcrux_ml_kem.Vector.Portable -#set-options "--fuel 0 --ifuel 1 --z3rlimit 100" -open Core -open FStar.Mul - -let _ = - (* This module has implicit dependencies, here we make them explicit. *) - (* The implicit dependencies arise from typeclasses instances. *) - let open Libcrux_ml_kem.Vector.Portable.Vector_type in - let open Libcrux_ml_kem.Vector.Traits in - () - -let deserialize_11_ (a: t_Slice u8) = Libcrux_ml_kem.Vector.Portable.Serialize.deserialize_11_ a - -let deserialize_5_ (a: t_Slice u8) = Libcrux_ml_kem.Vector.Portable.Serialize.deserialize_5_ a - -let serialize_11_ (a: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector) = - Libcrux_ml_kem.Vector.Portable.Serialize.serialize_11_ a - -let serialize_5_ (a: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector) = - Libcrux_ml_kem.Vector.Portable.Serialize.serialize_5_ a - -let deserialize_1_ (a: t_Slice u8) = - let _:Prims.unit = Libcrux_ml_kem.Vector.Portable.Serialize.deserialize_1_lemma a in - let _:Prims.unit = Libcrux_ml_kem.Vector.Portable.Serialize.deserialize_1_bounded_lemma a in - Libcrux_ml_kem.Vector.Portable.Serialize.deserialize_1_ a - -let deserialize_10_ (a: t_Slice u8) = - let _:Prims.unit = Libcrux_ml_kem.Vector.Portable.Serialize.deserialize_10_lemma a in - let _:Prims.unit = Libcrux_ml_kem.Vector.Portable.Serialize.deserialize_10_bounded_lemma a in - Libcrux_ml_kem.Vector.Portable.Serialize.deserialize_10_ a - -let deserialize_12_ (a: t_Slice u8) = - let _:Prims.unit = Libcrux_ml_kem.Vector.Portable.Serialize.deserialize_12_lemma a in - let _:Prims.unit = Libcrux_ml_kem.Vector.Portable.Serialize.deserialize_12_bounded_lemma a in - Libcrux_ml_kem.Vector.Portable.Serialize.deserialize_12_ a - -let deserialize_4_ (a: t_Slice u8) = - let _:Prims.unit = Libcrux_ml_kem.Vector.Portable.Serialize.deserialize_4_lemma a in - let _:Prims.unit = Libcrux_ml_kem.Vector.Portable.Serialize.deserialize_4_bounded_lemma a in - Libcrux_ml_kem.Vector.Portable.Serialize.deserialize_4_ a - -let serialize_1_ (a: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector) = - let _:Prims.unit = assert (forall i. Rust_primitives.bounded (Seq.index a.f_elements i) 1) in - let _:Prims.unit = Libcrux_ml_kem.Vector.Portable.Serialize.serialize_1_lemma a in - Libcrux_ml_kem.Vector.Portable.Serialize.serialize_1_ a - -let serialize_10_ (a: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector) = - let _:Prims.unit = Libcrux_ml_kem.Vector.Portable.Serialize.serialize_10_lemma a in - Libcrux_ml_kem.Vector.Portable.Serialize.serialize_10_ a - -let serialize_12_ (a: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector) = - let _:Prims.unit = Libcrux_ml_kem.Vector.Portable.Serialize.serialize_12_lemma a in - Libcrux_ml_kem.Vector.Portable.Serialize.serialize_12_ a - -let serialize_4_ (a: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector) = - let _:Prims.unit = assert (forall i. Rust_primitives.bounded (Seq.index a.f_elements i) 4) in - let _:Prims.unit = Libcrux_ml_kem.Vector.Portable.Serialize.serialize_4_lemma a in - Libcrux_ml_kem.Vector.Portable.Serialize.serialize_4_ a diff --git a/libcrux-ml-kem/proofs/fstar/spec/Makefile b/libcrux-ml-kem/proofs/fstar/spec/Makefile deleted file mode 100644 index b4ce70a38..000000000 --- a/libcrux-ml-kem/proofs/fstar/spec/Makefile +++ /dev/null @@ -1 +0,0 @@ -include $(shell git rev-parse --show-toplevel)/fstar-helpers/Makefile.base diff --git a/libcrux-ml-kem/proofs/fstar/spec/Spec.MLKEM.Instances.fst b/libcrux-ml-kem/proofs/fstar/spec/Spec.MLKEM.Instances.fst deleted file mode 100644 index f598ee0ff..000000000 --- a/libcrux-ml-kem/proofs/fstar/spec/Spec.MLKEM.Instances.fst +++ /dev/null @@ -1,64 +0,0 @@ -module Spec.MLKEM.Instances -#set-options "--fuel 0 --ifuel 1 --z3rlimit 30" -open FStar.Mul -open Core -open Spec.Utils -open Spec.MLKEM.Math -open Spec.MLKEM - - -(** MLKEM-768 Instantiation *) - -let mlkem768_rank : rank = sz 3 - -#push-options "--z3rlimit 300" -let mlkem768_generate_keypair (randomness:t_Array u8 (sz 64)): - (t_Array u8 (sz 2400) & t_Array u8 (sz 1184)) & bool = - ind_cca_generate_keypair mlkem768_rank randomness - -let mlkem768_encapsulate (public_key: t_Array u8 (sz 1184)) (randomness: t_Array u8 (sz 32)): - (t_Array u8 (sz 1088) & t_Array u8 (sz 32)) & bool = - ind_cca_encapsulate mlkem768_rank public_key randomness - -let mlkem768_decapsulate (secret_key: t_Array u8 (sz 2400)) (ciphertext: t_Array u8 (sz 1088)): - t_Array u8 (sz 32) & bool = - ind_cca_decapsulate mlkem768_rank secret_key ciphertext - -(** MLKEM-1024 Instantiation *) - -let mlkem1024_rank = sz 4 - -let mlkem1024_generate_keypair (randomness:t_Array u8 (sz 64)): - (t_Array u8 (sz 3168) & t_Array u8 (sz 1568)) & bool = - ind_cca_generate_keypair mlkem1024_rank randomness - -#set-options "--z3rlimit 100" -let mlkem1024_encapsulate (public_key: t_Array u8 (sz 1568)) (randomness: t_Array u8 (sz 32)): - (t_Array u8 (sz 1568) & t_Array u8 (sz 32)) & bool = - assert (v_CPA_CIPHERTEXT_SIZE mlkem1024_rank == sz 1568); - ind_cca_encapsulate mlkem1024_rank public_key randomness - -let mlkem1024_decapsulate (secret_key: t_Array u8 (sz 3168)) (ciphertext: t_Array u8 (sz 1568)): - t_Array u8 (sz 32) & bool = - ind_cca_decapsulate mlkem1024_rank secret_key ciphertext - -(** MLKEM-512 Instantiation *) - -let mlkem512_rank : rank = sz 2 - -let mlkem512_generate_keypair (randomness:t_Array u8 (sz 64)): - (t_Array u8 (sz 1632) & t_Array u8 (sz 800)) & bool = - ind_cca_generate_keypair mlkem512_rank randomness - -let mlkem512_encapsulate (public_key: t_Array u8 (sz 800)) (randomness: t_Array u8 (sz 32)): - (t_Array u8 (sz 768) & t_Array u8 (sz 32)) & bool = - assert (v_CPA_CIPHERTEXT_SIZE mlkem512_rank == sz 768); - ind_cca_encapsulate mlkem512_rank public_key randomness - - -let mlkem512_decapsulate (secret_key: t_Array u8 (sz 1632)) (ciphertext: t_Array u8 (sz 768)): - t_Array u8 (sz 32) & bool = - ind_cca_decapsulate mlkem512_rank secret_key ciphertext - - - diff --git a/libcrux-ml-kem/proofs/fstar/spec/Spec.MLKEM.Math.fst b/libcrux-ml-kem/proofs/fstar/spec/Spec.MLKEM.Math.fst deleted file mode 100644 index 571e879fb..000000000 --- a/libcrux-ml-kem/proofs/fstar/spec/Spec.MLKEM.Math.fst +++ /dev/null @@ -1,293 +0,0 @@ -module Spec.MLKEM.Math -#set-options "--fuel 0 --ifuel 1 --z3rlimit 80" - -open FStar.Mul -open Core -open Spec.Utils - -let v_FIELD_MODULUS: i32 = 3329l -let is_rank (r:usize) = v r == 2 \/ v r == 3 \/ v r == 4 - -type rank = r:usize{is_rank r} - -(** MLKEM Math and Sampling *) - -type field_element = n:nat{n < v v_FIELD_MODULUS} -type polynomial = t_Array field_element (sz 256) -type vector (r:rank) = t_Array polynomial r -type matrix (r:rank) = t_Array (vector r) r - -val field_add: field_element -> field_element -> field_element -let field_add a b = (a + b) % v v_FIELD_MODULUS - -val field_sub: field_element -> field_element -> field_element -let field_sub a b = (a - b) % v v_FIELD_MODULUS - -val field_neg: field_element -> field_element -let field_neg a = (0 - a) % v v_FIELD_MODULUS - -val field_mul: field_element -> field_element -> field_element -let field_mul a b = (a * b) % v v_FIELD_MODULUS - -val poly_add: polynomial -> polynomial -> polynomial -let poly_add a b = map2 field_add a b - -val poly_sub: polynomial -> polynomial -> polynomial -let poly_sub a b = map2 field_sub a b - -let int_to_spec_fe (m:int) : field_element = - let m_v = m % v v_FIELD_MODULUS in - assert (m_v > - v v_FIELD_MODULUS); - if m_v < 0 then - m_v + v v_FIELD_MODULUS - else m_v - -(* Convert concrete code types to spec types *) - -let to_spec_fe (m:i16) : field_element = - int_to_spec_fe (v m) - -let to_spec_array #len (m:t_Array i16 len) : t_Array field_element len = - createi #field_element len (fun i -> to_spec_fe (m.[i])) - -let to_spec_poly (m:t_Array i16 (sz 256)) : polynomial = - to_spec_array m - -let to_spec_vector (#r:rank) - (m:t_Array (t_Array i16 (sz 256)) r) - : (vector r) = - createi r (fun i -> to_spec_poly (m.[i])) - -let to_spec_matrix (#r:rank) - (m:t_Array (t_Array (t_Array i16 (sz 256)) r) r) - : (matrix r) = - createi r (fun i -> to_spec_vector (m.[i])) - -(* Specifying NTT: -bitrev7 = [int('{:07b}'.format(x)[::-1], 2) for x in range(0,128)] -zetas = [pow(17,x) % 3329 for x in bitrev7] -zetas_mont = [pow(2,16) * x % 3329 for x in zetas] -zetas_mont_r = [(x - 3329 if x > 1664 else x) for x in zetas_mont] - -bitrev7 is -[0, 64, 32, 96, 16, 80, 48, 112, 8, 72, 40, 104, 24, 88, 56, 120, 4, 68, 36, 100, 20, 84, 52, 116, 12, 76, 44, 108, 28, 92, 60, 124, 2, 66, 34, 98, 18, 82, 50, 114, 10, 74, 42, 106, 26, 90, 58, 122, 6, 70, 38, 102, 22, 86, 54, 118, 14, 78, 46, 110, 30, 94, 62, 126, 1, 65, 33, 97, 17, 81, 49, 113, 9, 73, 41, 105, 25, 89, 57, 121, 5, 69, 37, 101, 21, 85, 53, 117, 13, 77, 45, 109, 29, 93, 61, 125, 3, 67, 35, 99, 19, 83, 51, 115, 11, 75, 43, 107, 27, 91, 59, 123, 7, 71, 39, 103, 23, 87, 55, 119, 15, 79, 47, 111, 31, 95, 63, 127] - -zetas = 17^bitrev7 is -[1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797, 2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333, 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756, 1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915, 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648, 2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100, 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789, 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641, 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757, 2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886, 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154] - -zetas_mont = zetas * 2^16 is -[2285, 2571, 2970, 1812, 1493, 1422, 287, 202, 3158, 622, 1577, 182, 962, 2127, 1855, 1468, 573, 2004, 264, 383, 2500, 1458, 1727, 3199, 2648, 1017, 732, 608, 1787, 411, 3124, 1758, 1223, 652, 2777, 1015, 2036, 1491, 3047, 1785, 516, 3321, 3009, 2663, 1711, 2167, 126, 1469, 2476, 3239, 3058, 830, 107, 1908, 3082, 2378, 2931, 961, 1821, 2604, 448, 2264, 677, 2054, 2226, 430, 555, 843, 2078, 871, 1550, 105, 422, 587, 177, 3094, 3038, 2869, 1574, 1653, 3083, 778, 1159, 3182, 2552, 1483, 2727, 1119, 1739, 644, 2457, 349, 418, 329, 3173, 3254, 817, 1097, 603, 610, 1322, 2044, 1864, 384, 2114, 3193, 1218, 1994, 2455, 220, 2142, 1670, 2144, 1799, 2051, 794, 1819, 2475, 2459, 478, 3221, 3021, 996, 991, 958, 1869, 1522, 1628] - -zetas_mont_r = zetas_mont - 3329 if zetas_mont > 1664 else zetas_mont is -[-1044, -758, -359, -1517, 1493, 1422, 287, 202, -171, 622, 1577, 182, 962, -1202, -1474, 1468, 573, -1325, 264, 383, -829, 1458, -1602, -130, -681, 1017, 732, 608, -1542, 411, -205, -1571, 1223, 652, -552, 1015, -1293, 1491, -282, -1544, 516, -8, -320, -666, -1618, -1162, 126, 1469, -853, -90, -271, 830, 107, -1421, -247, -951, -398, 961, -1508, -725, 448, -1065, 677, -1275, -1103, 430, 555, 843, -1251, 871, 1550, 105, 422, 587, 177, -235, -291, -460, 1574, 1653, -246, 778, 1159, -147, -777, 1483, -602, 1119, -1590, 644, -872, 349, 418, 329, -156, -75, 817, 1097, 603, 610, 1322, -1285, -1465, 384, -1215, -136, 1218, -1335, -874, 220, -1187, -1659, -1185, -1530, -1278, 794, -1510, -854, -870, 478, -108, -308, 996, 991, 958, -1460, 1522, 1628] -*) - -let zetas_list : list field_element = [1; 1729; 2580; 3289; 2642; 630; 1897; 848; 1062; 1919; 193; 797; 2786; 3260; 569; 1746; 296; 2447; 1339; 1476; 3046; 56; 2240; 1333; 1426; 2094; 535; 2882; 2393; 2879; 1974; 821; 289; 331; 3253; 1756; 1197; 2304; 2277; 2055; 650; 1977; 2513; 632; 2865; 33; 1320; 1915; 2319; 1435; 807; 452; 1438; 2868; 1534; 2402; 2647; 2617; 1481; 648; 2474; 3110; 1227; 910; 17; 2761; 583; 2649; 1637; 723; 2288; 1100; 1409; 2662; 3281; 233; 756; 2156; 3015; 3050; 1703; 1651; 2789; 1789; 1847; 952; 1461; 2687; 939; 2308; 2437; 2388; 733; 2337; 268; 641; 1584; 2298; 2037; 3220; 375; 2549; 2090; 1645; 1063; 319; 2773; 757; 2099; 561; 2466; 2594; 2804; 1092; 403; 1026; 1143; 2150; 2775; 886; 1722; 1212; 1874; 1029; 2110; 2935; 885; 2154] - -let zetas : t_Array field_element (sz 128) = - assert_norm(List.Tot.length zetas_list == 128); - Rust_primitives.Arrays.of_list zetas_list - -let poly_ntt_step (a:field_element) (b:field_element) (i:nat{i < 128}) = - let t = field_mul b zetas.[sz i] in - let b = field_sub a t in - let a = field_add a t in - (a,b) - -#push-options "--split_queries always" -let poly_ntt_layer (p:polynomial) (l:nat{l > 0 /\ l < 8}) : polynomial = - let len = pow2 l in - let k = (128 / len) - 1 in - Rust_primitives.Arrays.createi (sz 256) (fun i -> - let round = v i / (2 * len) in - let idx = v i % (2 * len) in - let (idx0, idx1) = if idx < len then (idx, idx+len) else (idx-len,idx) in - let (a_ntt, b_ntt) = poly_ntt_step p.[sz idx0] p.[sz idx1] (round + k) in - if idx < len then a_ntt else b_ntt) -#pop-options - -val poly_ntt: polynomial -> polynomial -let poly_ntt p = - let p = poly_ntt_layer p 7 in - let p = poly_ntt_layer p 6 in - let p = poly_ntt_layer p 5 in - let p = poly_ntt_layer p 4 in - let p = poly_ntt_layer p 3 in - let p = poly_ntt_layer p 2 in - let p = poly_ntt_layer p 1 in - p - -let poly_inv_ntt_step (a:field_element) (b:field_element) (i:nat{i < 128}) = - let b_minus_a = field_sub b a in - let a = field_add a b in - let b = field_mul b_minus_a zetas.[sz i] in - (a,b) - -#push-options "--z3rlimit 150" -let poly_inv_ntt_layer (p:polynomial) (l:nat{l > 0 /\ l < 8}) : polynomial = - let len = pow2 l in - let k = (256 / len) - 1 in - Rust_primitives.Arrays.createi (sz 256) (fun i -> - let round = v i / (2 * len) in - let idx = v i % (2 * len) in - let (idx0, idx1) = if idx < len then (idx, idx+len) else (idx-len,idx) in - let (a_ntt, b_ntt) = poly_inv_ntt_step p.[sz idx0] p.[sz idx1] (k - round) in - if idx < len then a_ntt else b_ntt) -#pop-options - -val poly_inv_ntt: polynomial -> polynomial -let poly_inv_ntt p = - let p = poly_inv_ntt_layer p 1 in - let p = poly_inv_ntt_layer p 2 in - let p = poly_inv_ntt_layer p 3 in - let p = poly_inv_ntt_layer p 4 in - let p = poly_inv_ntt_layer p 5 in - let p = poly_inv_ntt_layer p 6 in - let p = poly_inv_ntt_layer p 7 in - p - -let poly_base_case_multiply (a0 a1 b0 b1 zeta:field_element) = - let c0 = field_add (field_mul a0 b0) (field_mul (field_mul a1 b1) zeta) in - let c1 = field_add (field_mul a0 b1) (field_mul a1 b0) in - (c0,c1) - -val poly_mul_ntt: polynomial -> polynomial -> polynomial -let poly_mul_ntt a b = - Rust_primitives.Arrays.createi (sz 256) (fun i -> - let a0 = a.[sz (2 * (v i / 2))] in - let a1 = a.[sz (2 * (v i / 2) + 1)] in - let b0 = b.[sz (2 * (v i / 2))] in - let b1 = b.[sz (2 * (v i / 2) + 1)] in - let zeta_4 = zetas.[sz (64 + (v i/4))] in - let zeta = if v i % 4 < 2 then zeta_4 else field_neg zeta_4 in - let (c0,c1) = poly_base_case_multiply a0 a1 b0 b1 zeta in - if v i % 2 = 0 then c0 else c1) - - -val vector_add: #r:rank -> vector r -> vector r -> vector r -let vector_add #p a b = map2 poly_add a b - -val vector_ntt: #r:rank -> vector r -> vector r -let vector_ntt #p v = map_array poly_ntt v - -val vector_inv_ntt: #r:rank -> vector r -> vector r -let vector_inv_ntt #p v = map_array poly_inv_ntt v - -val vector_mul_ntt: #r:rank -> vector r -> vector r -> vector r -let vector_mul_ntt #p a b = map2 poly_mul_ntt a b - -val vector_sum: #r:rank -> vector r -> polynomial -let vector_sum #r a = repeati (r -! sz 1) - (fun i x -> assert (v i < v r - 1); poly_add x (a.[i +! sz 1])) a.[sz 0] - -val vector_dot_product_ntt: #r:rank -> vector r -> vector r -> polynomial -let vector_dot_product_ntt a b = vector_sum (vector_mul_ntt a b) - -val matrix_transpose: #r:rank -> matrix r -> matrix r -let matrix_transpose #r m = - createi r (fun i -> - createi r (fun j -> - m.[j].[i])) - -val matrix_vector_mul_ntt: #r:rank -> matrix r -> vector r -> vector r -let matrix_vector_mul_ntt #r m v = - createi r (fun i -> vector_dot_product_ntt m.[i] v) - -val compute_As_plus_e_ntt: #r:rank -> a:matrix r -> s:vector r -> e:vector r -> vector r -let compute_As_plus_e_ntt #p a s e = vector_add (matrix_vector_mul_ntt a s) e - - - -type dT = d: nat {d = 1 \/ d = 4 \/ d = 5 \/ d = 10 \/ d = 11 \/ d = 12} -let max_d (d:dT) = if d < 12 then pow2 d else v v_FIELD_MODULUS -type field_element_d (d:dT) = n:nat{n < max_d d} -type polynomial_d (d:dT) = t_Array (field_element_d d) (sz 256) -type vector_d (r:rank) (d:dT) = t_Array (polynomial_d d) r - -let bits_to_bytes (#bytes: usize) (bv: bit_vec (v bytes * 8)) - : Pure (t_Array u8 bytes) - (requires True) - (ensures fun r -> (forall i. bit_vec_of_int_t_array r 8 i == bv i)) - = bit_vec_to_int_t_array 8 bv - -let bytes_to_bits (#bytes: usize) (r: t_Array u8 bytes) - : Pure (i: bit_vec (v bytes * 8)) - (requires True) - (ensures fun f -> (forall i. bit_vec_of_int_t_array r 8 i == f i)) - = bit_vec_of_int_t_array r 8 - -unfold let retype_bit_vector #a #b (#_:unit{a == b}) (x: a): b = x - - -let compress_d (d: dT {d <> 12}) (x: field_element): field_element_d d - = let r = (pow2 d * x + 1664) / v v_FIELD_MODULUS in - assert (r * v v_FIELD_MODULUS <= pow2 d * x + 1664); - assert (r * v v_FIELD_MODULUS <= pow2 d * (v v_FIELD_MODULUS - 1) + 1664); - Math.Lemmas.lemma_div_le (r * v v_FIELD_MODULUS) (pow2 d * (v v_FIELD_MODULUS - 1) + 1664) (v v_FIELD_MODULUS); - Math.Lemmas.cancel_mul_div r (v v_FIELD_MODULUS); - assert (r <= (pow2 d * (v v_FIELD_MODULUS - 1) + 1664) / v v_FIELD_MODULUS); - Math.Lemmas.lemma_div_mod_plus (1664 - pow2 d) (pow2 d) (v v_FIELD_MODULUS); - assert (r <= pow2 d + (1664 - pow2 d) / v v_FIELD_MODULUS); - assert (r <= pow2 d); - if r = pow2 d then 0 else r - -let decompress_d (d: dT {d <> 12}) (x: field_element_d d): field_element - = let r = (x * v v_FIELD_MODULUS + 1664) / pow2 d in - r - - -let byte_encode (d: dT) (coefficients: polynomial_d d): t_Array u8 (sz (32 * d)) - = let coefficients' : t_Array nat (sz 256) = map_array #(field_element_d d) (fun x -> x <: nat) coefficients in - bits_to_bytes #(sz (32 * d)) - (retype_bit_vector (bit_vec_of_nat_array coefficients' d)) - -let byte_decode (d: dT) (coefficients: t_Array u8 (sz (32 * d))): polynomial_d d - = let bv = bytes_to_bits coefficients in - let arr: t_Array nat (sz 256) = bit_vec_to_nat_array d (retype_bit_vector bv) in - let p: polynomial_d d = - createi (sz 256) (fun i -> - let x_f : field_element = arr.[i] % v v_FIELD_MODULUS in - assert (d < 12 ==> arr.[i] < pow2 d); - let x_m : field_element_d d = x_f in - x_m) - in - p - -let coerce_polynomial_12 (p:polynomial): polynomial_d 12 = p -let coerce_vector_12 (#r:rank) (v:vector r): vector_d r 12 = v - -let compress_then_byte_encode (d: dT {d <> 12}) (coefficients: polynomial): t_Array u8 (sz (32 * d)) - = let coefs: t_Array (field_element_d d) (sz 256) = map_array (compress_d d) coefficients - in - byte_encode d coefs - -let byte_decode_then_decompress (d: dT {d <> 12}) (b:t_Array u8 (sz (32 * d))): polynomial - = map_array (decompress_d d) (byte_decode d b) - - -(**** Definitions to move or to rework *) -let serialize_pre - (d1: dT) - (coefficients: t_Array i16 (sz 16)) - = forall i. i < 16 ==> bounded (Seq.index coefficients i) d1 - -// TODO: this is an alternative version of byte_encode -// rename to encoded bytes -#push-options "--z3rlimit 80 --split_queries always" -let serialize_post - (d1: dT) - (coefficients: t_Array i16 (sz 16) { serialize_pre d1 coefficients }) - (output: t_Array u8 (sz (d1 * 2))) - = BitVecEq.int_t_array_bitwise_eq coefficients d1 - output 8 - -// TODO: this is an alternative version of byte_decode -// rename to decoded bytes -let deserialize_post - (d1: dT) - (bytes: t_Array u8 (sz (d1 * 2))) - (output: t_Array i16 (sz 16)) - = BitVecEq.int_t_array_bitwise_eq bytes 8 - output d1 /\ - forall (i:nat). i < 16 ==> bounded (Seq.index output i) d1 -#pop-options diff --git a/libcrux-ml-kem/proofs/fstar/spec/Spec.MLKEM.fst b/libcrux-ml-kem/proofs/fstar/spec/Spec.MLKEM.fst deleted file mode 100644 index 07c9216ae..000000000 --- a/libcrux-ml-kem/proofs/fstar/spec/Spec.MLKEM.fst +++ /dev/null @@ -1,343 +0,0 @@ -module Spec.MLKEM -#set-options "--fuel 0 --ifuel 1 --z3rlimit 80" -open FStar.Mul -open Core - -include Spec.Utils -include Spec.MLKEM.Math - -(** ML-KEM Constants *) -let v_BITS_PER_COEFFICIENT: usize = sz 12 - -let v_COEFFICIENTS_IN_RING_ELEMENT: usize = sz 256 - -let v_BITS_PER_RING_ELEMENT: usize = sz 3072 // v_COEFFICIENTS_IN_RING_ELEMENT *! sz 12 - -let v_BYTES_PER_RING_ELEMENT: usize = sz 384 // v_BITS_PER_RING_ELEMENT /! sz 8 - -let v_CPA_KEY_GENERATION_SEED_SIZE: usize = sz 32 - -let v_H_DIGEST_SIZE: usize = sz 32 -// same as Libcrux.Digest.digest_size (Libcrux.Digest.Algorithm_Sha3_256_ <: Libcrux.Digest.t_Algorithm) - -let v_REJECTION_SAMPLING_SEED_SIZE: usize = sz 840 // sz 168 *! sz 5 - -let v_SHARED_SECRET_SIZE: usize = v_H_DIGEST_SIZE - -val v_ETA1 (r:rank) : u:usize{u == sz 3 \/ u == sz 2} -let v_ETA1 (r:rank) : usize = - if r = sz 2 then sz 3 else - if r = sz 3 then sz 2 else - if r = sz 4 then sz 2 - -let v_ETA2 (r:rank) : usize = sz 2 - -val v_VECTOR_U_COMPRESSION_FACTOR (r:rank) : u:usize{u == sz 10 \/ u == sz 11} -let v_VECTOR_U_COMPRESSION_FACTOR (r:rank) : usize = - if r = sz 2 then sz 10 else - if r = sz 3 then sz 10 else - if r = sz 4 then sz 11 - -val v_VECTOR_V_COMPRESSION_FACTOR (r:rank) : u:usize{u == sz 4 \/ u == sz 5} -let v_VECTOR_V_COMPRESSION_FACTOR (r:rank) : usize = - if r = sz 2 then sz 4 else - if r = sz 3 then sz 4 else - if r = sz 4 then sz 5 - -val v_ETA1_RANDOMNESS_SIZE (r:rank) : u:usize{u == sz 128 \/ u == sz 192} -let v_ETA1_RANDOMNESS_SIZE (r:rank) = v_ETA1 r *! sz 64 - -val v_ETA2_RANDOMNESS_SIZE (r:rank) : u:usize{u == sz 128} -let v_ETA2_RANDOMNESS_SIZE (r:rank) = v_ETA2 r *! sz 64 - -val v_RANKED_BYTES_PER_RING_ELEMENT (r:rank) : u:usize{u = sz 768 \/ u = sz 1152 \/ u = sz 1536} -let v_RANKED_BYTES_PER_RING_ELEMENT (r:rank) = r *! v_BYTES_PER_RING_ELEMENT - -let v_T_AS_NTT_ENCODED_SIZE (r:rank) = v_RANKED_BYTES_PER_RING_ELEMENT r -let v_CPA_PRIVATE_KEY_SIZE (r:rank) = v_RANKED_BYTES_PER_RING_ELEMENT r - -val v_CPA_PUBLIC_KEY_SIZE (r:rank) : u:usize{u = sz 800 \/ u = sz 1184 \/ u = sz 1568} -let v_CPA_PUBLIC_KEY_SIZE (r:rank) = v_RANKED_BYTES_PER_RING_ELEMENT r +! sz 32 - -val v_CCA_PRIVATE_KEY_SIZE (r:rank) : u:usize{u = sz 1632 \/ u = sz 2400 \/ u = sz 3168} -let v_CCA_PRIVATE_KEY_SIZE (r:rank) = - (v_CPA_PRIVATE_KEY_SIZE r +! v_CPA_PUBLIC_KEY_SIZE r +! v_H_DIGEST_SIZE +! v_SHARED_SECRET_SIZE) - -let v_CCA_PUBLIC_KEY_SIZE (r:rank) = v_CPA_PUBLIC_KEY_SIZE r - -val v_C1_BLOCK_SIZE (r:rank): u:usize{(u = sz 320 \/ u = sz 352) /\ v u == 32 * v (v_VECTOR_U_COMPRESSION_FACTOR r)} -let v_C1_BLOCK_SIZE (r:rank) = sz 32 *! v_VECTOR_U_COMPRESSION_FACTOR r - -val v_C1_SIZE (r:rank) : u:usize{(u >=. sz 640 /\ u <=. sz 1448) /\ - v u == v (v_C1_BLOCK_SIZE r) * v r} -let v_C1_SIZE (r:rank) = v_C1_BLOCK_SIZE r *! r - -val v_C2_SIZE (r:rank) : u:usize{(u = sz 128 \/ u = sz 160) /\ v u == 32 * v (v_VECTOR_V_COMPRESSION_FACTOR r)} -let v_C2_SIZE (r:rank) = sz 32 *! v_VECTOR_V_COMPRESSION_FACTOR r - -val v_CPA_CIPHERTEXT_SIZE (r:rank) : u:usize {v u = v (v_C1_SIZE r) + v (v_C2_SIZE r)} -let v_CPA_CIPHERTEXT_SIZE (r:rank) = v_C1_SIZE r +! v_C2_SIZE r - -let v_CCA_CIPHERTEXT_SIZE (r:rank) = v_CPA_CIPHERTEXT_SIZE r - -val v_IMPLICIT_REJECTION_HASH_INPUT_SIZE (r:rank): u:usize{v u == v v_SHARED_SECRET_SIZE + - v (v_CPA_CIPHERTEXT_SIZE r)} -let v_IMPLICIT_REJECTION_HASH_INPUT_SIZE (r:rank) = - v_SHARED_SECRET_SIZE +! v_CPA_CIPHERTEXT_SIZE r - -val v_KEY_GENERATION_SEED_SIZE: u:usize{u = sz 64} -let v_KEY_GENERATION_SEED_SIZE: usize = - v_CPA_KEY_GENERATION_SEED_SIZE +! - v_SHARED_SECRET_SIZE - - -(** ML-KEM Types *) - -type t_MLKEMPublicKey (r:rank) = t_Array u8 (v_CPA_PUBLIC_KEY_SIZE r) -type t_MLKEMPrivateKey (r:rank) = t_Array u8 (v_CCA_PRIVATE_KEY_SIZE r) -type t_MLKEMKeyPair (r:rank) = t_MLKEMPrivateKey r & t_MLKEMPublicKey r - -type t_MLKEMCPAPrivateKey (r:rank) = t_Array u8 (v_CPA_PRIVATE_KEY_SIZE r) -type t_MLKEMCPAKeyPair (r:rank) = t_MLKEMCPAPrivateKey r & t_MLKEMPublicKey r - -type t_MLKEMCiphertext (r:rank) = t_Array u8 (v_CPA_CIPHERTEXT_SIZE r) -type t_MLKEMSharedSecret = t_Array u8 (v_SHARED_SECRET_SIZE) - - -assume val sample_max: n:usize{v n < pow2 32 /\ v n >= 128 * 3 /\ v n % 3 = 0} - -val sample_polynomial_ntt: seed:t_Array u8 (sz 34) -> (polynomial & bool) -let sample_polynomial_ntt seed = - let randomness = v_XOF sample_max seed in - let bv = bytes_to_bits randomness in - assert (v sample_max * 8 == (((v sample_max / 3) * 2) * 12)); - let bv: bit_vec ((v (sz ((v sample_max / 3) * 2))) * 12) = retype_bit_vector bv in - let i16s = bit_vec_to_nat_array #(sz ((v sample_max / 3) * 2)) 12 bv in - assert ((v sample_max / 3) * 2 >= 256); - let poly0: polynomial = Seq.create 256 0 in - let index_t = n:nat{n <= 256} in - let (sampled, poly1) = - repeati #(index_t & polynomial) (sz ((v sample_max / 3) * 2)) - (fun i (sampled,acc) -> - if sampled < 256 then - let sample = Seq.index i16s (v i) in - if sample < 3329 then - (sampled+1, Rust_primitives.Hax.update_at acc (sz sampled) sample) - else (sampled, acc) - else (sampled, acc)) - (0,poly0) in - if sampled < 256 then poly0, false else poly1, true - -let sample_polynomial_ntt_at_index (seed:t_Array u8 (sz 32)) (i j: (x:usize{v x <= 4})) : polynomial & bool = - let seed34 = Seq.append seed (Seq.create 2 0uy) in - let seed34 = Rust_primitives.Hax.update_at seed34 (sz 32) (mk_int #u8_inttype (v i)) in - let seed34 = Rust_primitives.Hax.update_at seed34 (sz 33) (mk_int #u8_inttype (v j)) in - sample_polynomial_ntt seed34 - -val sample_matrix_A_ntt: #r:rank -> seed:t_Array u8 (sz 32) -> (matrix r & bool) -let sample_matrix_A_ntt #r seed = - let m = - createi r (fun i -> - createi r (fun j -> - let (p,b) = sample_polynomial_ntt_at_index seed i j in - p)) - in - let sufficient_randomness = - repeati r (fun i b -> - repeati r (fun j b -> - let (p,v) = sample_polynomial_ntt_at_index seed i j in - b && v) b) true in - (m, sufficient_randomness) - -assume val sample_poly_cbd: v_ETA:usize{v v_ETA == 2 \/ v v_ETA == 3} -> t_Array u8 (v_ETA *! sz 64) -> polynomial - -open Rust_primitives.Integers - -val sample_poly_cbd2: #r:rank -> seed:t_Array u8 (sz 32) -> domain_sep:usize{v domain_sep < 256} -> polynomial -let sample_poly_cbd2 #r seed domain_sep = - let prf_input = Seq.append seed (Seq.create 1 (mk_int #u8_inttype (v domain_sep))) in - let prf_output = v_PRF (v_ETA2_RANDOMNESS_SIZE r) prf_input in - sample_poly_cbd (v_ETA2 r) prf_output - -val sample_poly_cbd1: #r:rank -> seed:t_Array u8 (sz 32) -> domain_sep:usize{v domain_sep < 256} -> polynomial -let sample_poly_cbd1 #r seed domain_sep = - let prf_input = Seq.append seed (Seq.create 1 (mk_int #u8_inttype (v domain_sep))) in - let prf_output = v_PRF (v_ETA1_RANDOMNESS_SIZE r) prf_input in - sample_poly_cbd (v_ETA1 r) prf_output - -let sample_vector_cbd1 (#r:rank) (seed:t_Array u8 (sz 32)) (domain_sep:usize{v domain_sep < 2 * v r}) : vector r = - createi r (fun i -> sample_poly_cbd1 #r seed (domain_sep +! i)) - -let sample_vector_cbd2 (#r:rank) (seed:t_Array u8 (sz 32)) (domain_sep:usize{v domain_sep < 2 * v r}) : vector r = - createi r (fun i -> sample_poly_cbd2 #r seed (domain_sep +! i)) - -let sample_vector_cbd_then_ntt (#r:rank) (seed:t_Array u8 (sz 32)) (domain_sep:usize{v domain_sep < 2 * v r}) : vector r = - vector_ntt (sample_vector_cbd1 #r seed domain_sep) - -let vector_encode_12 (#r:rank) (v: vector r) : t_Array u8 (v_T_AS_NTT_ENCODED_SIZE r) - = let s: t_Array (t_Array _ (sz 384)) r = map_array (byte_encode 12) (coerce_vector_12 v) in - flatten s - -let vector_decode_12 (#r:rank) (arr: t_Array u8 (v_T_AS_NTT_ENCODED_SIZE r)): vector r - = createi r (fun block -> - let block_size = (sz (32 * 12)) in - let slice = Seq.slice arr (v block * v block_size) - (v block * v block_size + v block_size) in - byte_decode 12 slice - ) - -let compress_then_encode_message (p:polynomial) : t_Array u8 v_SHARED_SECRET_SIZE - = compress_then_byte_encode 1 p - -let decode_then_decompress_message (b:t_Array u8 v_SHARED_SECRET_SIZE): polynomial - = byte_decode_then_decompress 1 b - -let compress_then_encode_u (#r:rank) (vec: vector r): t_Array u8 (v_C1_SIZE r) - = let d = v (v_VECTOR_U_COMPRESSION_FACTOR r) in - flatten (map_array (compress_then_byte_encode d) vec) - -let decode_then_decompress_u (#r:rank) (arr: t_Array u8 (v_C1_SIZE r)): vector r - = let d = v_VECTOR_U_COMPRESSION_FACTOR r in - createi r (fun block -> - let block_size = v_C1_BLOCK_SIZE r in - let slice = Seq.slice arr (v block * v block_size) - (v block * v block_size + v block_size) in - byte_decode_then_decompress (v d) slice - ) - -let compress_then_encode_v (#r:rank): polynomial -> t_Array u8 (v_C2_SIZE r) - = compress_then_byte_encode (v (v_VECTOR_V_COMPRESSION_FACTOR r)) - -let decode_then_decompress_v (#r:rank): t_Array u8 (v_C2_SIZE r) -> polynomial - = byte_decode_then_decompress (v (v_VECTOR_V_COMPRESSION_FACTOR r)) - -(** IND-CPA Functions *) - -/// This function implements most of Algorithm 12 of the -/// NIST FIPS 203 specification; this is the MLKEM CPA-PKE key generation algorithm. -/// -/// We say "most of" since Algorithm 12 samples the required randomness within -/// the function itself, whereas this implementation expects it to be provided -/// through the `key_generation_seed` parameter. - -val ind_cpa_generate_keypair (r:rank) (randomness:t_Array u8 v_CPA_KEY_GENERATION_SEED_SIZE) : - (t_MLKEMCPAKeyPair r & bool) -let ind_cpa_generate_keypair r randomness = - let hashed = v_G randomness in - let (seed_for_A, seed_for_secret_and_error) = split hashed (sz 32) in - let (matrix_A_as_ntt, sufficient_randomness) = sample_matrix_A_ntt #r seed_for_A in - let secret_as_ntt = sample_vector_cbd_then_ntt #r seed_for_secret_and_error (sz 0) in - let error_as_ntt = sample_vector_cbd_then_ntt #r seed_for_secret_and_error r in - let t_as_ntt = compute_As_plus_e_ntt #r matrix_A_as_ntt secret_as_ntt error_as_ntt in - let public_key_serialized = Seq.append (vector_encode_12 #r t_as_ntt) seed_for_A in - let secret_key_serialized = vector_encode_12 #r secret_as_ntt in - ((secret_key_serialized,public_key_serialized), sufficient_randomness) - -/// This function implements Algorithm 13 of the -/// NIST FIPS 203 specification; this is the MLKEM CPA-PKE encryption algorithm. - -val ind_cpa_encrypt (r:rank) (public_key: t_MLKEMPublicKey r) - (message: t_Array u8 v_SHARED_SECRET_SIZE) - (randomness:t_Array u8 v_SHARED_SECRET_SIZE) : - (t_MLKEMCiphertext r & bool) - -[@ "opaque_to_smt"] -let ind_cpa_encrypt r public_key message randomness = - let (t_as_ntt_bytes, seed_for_A) = split public_key (v_T_AS_NTT_ENCODED_SIZE r) in - let t_as_ntt = vector_decode_12 #r t_as_ntt_bytes in - let matrix_A_as_ntt, sufficient_randomness = sample_matrix_A_ntt #r seed_for_A in - let r_as_ntt = sample_vector_cbd_then_ntt #r randomness (sz 0) in - let error_1 = sample_vector_cbd2 #r randomness r in - let error_2 = sample_poly_cbd2 #r randomness (r +! r) in - let u = vector_add (vector_inv_ntt (matrix_vector_mul_ntt (matrix_transpose matrix_A_as_ntt) r_as_ntt)) error_1 in - let mu = decode_then_decompress_message message in - let v = poly_add (poly_add (vector_dot_product_ntt t_as_ntt r_as_ntt) error_2) mu in - let c1 = compress_then_encode_u #r u in - let c2 = compress_then_encode_v #r v in - (concat c1 c2, sufficient_randomness) - -/// This function implements Algorithm 14 of the -/// NIST FIPS 203 specification; this is the MLKEM CPA-PKE decryption algorithm. - -val ind_cpa_decrypt (r:rank) (secret_key: t_MLKEMCPAPrivateKey r) - (ciphertext: t_MLKEMCiphertext r): - t_MLKEMSharedSecret - -[@ "opaque_to_smt"] -let ind_cpa_decrypt r secret_key ciphertext = - let (c1,c2) = split ciphertext (v_C1_SIZE r) in - let u = decode_then_decompress_u #r c1 in - let v = decode_then_decompress_v #r c2 in - let secret_as_ntt = vector_decode_12 #r secret_key in - let w = poly_sub v (poly_inv_ntt (vector_dot_product_ntt secret_as_ntt (vector_ntt u))) in - compress_then_encode_message w - -(** IND-CCA Functions *) - - -/// This function implements most of Algorithm 15 of the -/// NIST FIPS 203 specification; this is the MLKEM CCA-KEM key generation algorithm. -/// -/// We say "most of" since Algorithm 15 samples the required randomness within -/// the function itself, whereas this implementation expects it to be provided -/// through the `randomness` parameter. -/// -/// TODO: input validation - -val ind_cca_generate_keypair (r:rank) (randomness:t_Array u8 v_KEY_GENERATION_SEED_SIZE) : - t_MLKEMKeyPair r & bool -let ind_cca_generate_keypair p randomness = - let (ind_cpa_keypair_randomness, implicit_rejection_value) = - split randomness v_CPA_KEY_GENERATION_SEED_SIZE in - - let (ind_cpa_secret_key,ind_cpa_public_key), sufficient_randomness = ind_cpa_generate_keypair p ind_cpa_keypair_randomness in - let ind_cca_secret_key = Seq.append ind_cpa_secret_key ( - Seq.append ind_cpa_public_key ( - Seq.append (v_H ind_cpa_public_key) implicit_rejection_value)) in - (ind_cca_secret_key, ind_cpa_public_key), sufficient_randomness - -/// This function implements most of Algorithm 16 of the -/// NIST FIPS 203 specification; this is the MLKEM CCA-KEM encapsulation algorithm. -/// -/// We say "most of" since Algorithm 16 samples the required randomness within -/// the function itself, whereas this implementation expects it to be provided -/// through the `randomness` parameter. -/// -/// TODO: input validation - -val ind_cca_encapsulate (r:rank) (public_key: t_MLKEMPublicKey r) - (randomness:t_Array u8 v_SHARED_SECRET_SIZE) : - (t_MLKEMCiphertext r & t_MLKEMSharedSecret) & bool -let ind_cca_encapsulate p public_key randomness = - let to_hash = concat randomness (v_H public_key) in - let hashed = v_G to_hash in - let (shared_secret, pseudorandomness) = split hashed v_SHARED_SECRET_SIZE in - let ciphertext, sufficient_randomness = ind_cpa_encrypt p public_key randomness pseudorandomness in - (ciphertext,shared_secret), sufficient_randomness - - -/// This function implements Algorithm 17 of the -/// NIST FIPS 203 specification; this is the MLKEM CCA-KEM encapsulation algorithm. - -val ind_cca_decapsulate (r:rank) (secret_key: t_MLKEMPrivateKey r) - (ciphertext: t_MLKEMCiphertext r): - t_MLKEMSharedSecret & bool -let ind_cca_decapsulate p secret_key ciphertext = - let (ind_cpa_secret_key,rest) = split secret_key (v_CPA_PRIVATE_KEY_SIZE p) in - let (ind_cpa_public_key,rest) = split rest (v_CPA_PUBLIC_KEY_SIZE p) in - let (ind_cpa_public_key_hash,implicit_rejection_value) = split rest v_H_DIGEST_SIZE in - - let decrypted = ind_cpa_decrypt p ind_cpa_secret_key ciphertext in - let to_hash = concat decrypted ind_cpa_public_key_hash in - let hashed = v_G to_hash in - let (success_shared_secret, pseudorandomness) = split hashed v_SHARED_SECRET_SIZE in - - assert (Seq.length implicit_rejection_value = 32); - let to_hash = concat implicit_rejection_value ciphertext in - let rejection_shared_secret = v_J to_hash in - - let reencrypted, sufficient_randomness = ind_cpa_encrypt p ind_cpa_public_key decrypted pseudorandomness in - if reencrypted = ciphertext - then success_shared_secret, sufficient_randomness - else rejection_shared_secret, sufficient_randomness - diff --git a/libcrux-ml-kem/proofs/fstar/spec/Spec.Utils.fst b/libcrux-ml-kem/proofs/fstar/spec/Spec.Utils.fst deleted file mode 100644 index 1c6ed14b1..000000000 --- a/libcrux-ml-kem/proofs/fstar/spec/Spec.Utils.fst +++ /dev/null @@ -1,493 +0,0 @@ -module Spec.Utils -#set-options "--fuel 0 --ifuel 1 --z3rlimit 100" -open FStar.Mul -open Core - -(** Utils *) -let map_slice #a #b - (f:(x:a -> b)) - (s: t_Slice a): t_Slice b - = createi (length s) (fun i -> f (Seq.index s (v i))) - -let map_array #a #b #len - (f:(x:a -> b)) - (s: t_Array a len): t_Array b len - = createi (length s) (fun i -> f (Seq.index s (v i))) - -let map2 #a #b #c #len - (f:a -> b -> c) - (x: t_Array a len) (y: t_Array b len): t_Array c len - = createi (length x) (fun i -> f (Seq.index x (v i)) (Seq.index y (v i))) - -let create len c = createi len (fun i -> c) - -let repeati #acc (l:usize) (f:(i:usize{v i < v l}) -> acc -> acc) acc0 : acc = Lib.LoopCombinators.repeati (v l) (fun i acc -> f (sz i) acc) acc0 - -let createL len l = Rust_primitives.Hax.array_of_list len l - -let create16 v15 v14 v13 v12 v11 v10 v9 v8 v7 v6 v5 v4 v3 v2 v1 v0 = - let l = [v15; v14; v13; v12; v11; v10; v9; v8; v7; v6; v5; v4; v3; v2; v1; v0] in - assert_norm (List.Tot.length l == 16); - createL 16 l - - -val lemma_createL_index #a len l i : - Lemma (Seq.index (createL #a len l) i == List.Tot.index l i) - [SMTPat (Seq.index (createL #a len l) i)] -let lemma_createL_index #a len l i = () - -val lemma_create16_index #a v15 v14 v13 v12 v11 v10 v9 v8 v7 v6 v5 v4 v3 v2 v1 v0 i : - Lemma (Seq.index (create16 #a v15 v14 v13 v12 v11 v10 v9 v8 v7 v6 v5 v4 v3 v2 v1 v0) i == - (if i = 0 then v15 else - if i = 1 then v14 else - if i = 2 then v13 else - if i = 3 then v12 else - if i = 4 then v11 else - if i = 5 then v10 else - if i = 6 then v9 else - if i = 7 then v8 else - if i = 8 then v7 else - if i = 9 then v6 else - if i = 10 then v5 else - if i = 11 then v4 else - if i = 12 then v3 else - if i = 13 then v2 else - if i = 14 then v1 else - if i = 15 then v0)) - [SMTPat (Seq.index (create16 #a v15 v14 v13 v12 v11 v10 v9 v8 v7 v6 v5 v4 v3 v2 v1 v0) i)] -let lemma_create16_index #a v15 v14 v13 v12 v11 v10 v9 v8 v7 v6 v5 v4 v3 v2 v1 v0 i = - let l = [v15; v14; v13; v12; v11; v10; v9; v8; v7; v6; v5; v4; v3; v2; v1; v0] in - assert_norm (List.Tot.index l 0 == v15); - assert_norm (List.Tot.index l 1 == v14); - assert_norm (List.Tot.index l 2 == v13); - assert_norm (List.Tot.index l 3 == v12); - assert_norm (List.Tot.index l 4 == v11); - assert_norm (List.Tot.index l 5 == v10); - assert_norm (List.Tot.index l 6 == v9); - assert_norm (List.Tot.index l 7 == v8); - assert_norm (List.Tot.index l 8 == v7); - assert_norm (List.Tot.index l 9 == v6); - assert_norm (List.Tot.index l 10 == v5); - assert_norm (List.Tot.index l 11 == v4); - assert_norm (List.Tot.index l 12 == v3); - assert_norm (List.Tot.index l 13 == v2); - assert_norm (List.Tot.index l 14 == v1); - assert_norm (List.Tot.index l 15 == v0) - - -val lemma_createi_index #a len f i : - Lemma (Seq.index (createi #a len f) i == f (sz i)) - [SMTPat (Seq.index (createi #a len f) i)] -let lemma_createi_index #a len f i = () - -val lemma_create_index #a len c i: - Lemma (Seq.index (create #a len c) i == c) - [SMTPat (Seq.index (create #a len c) i)] -let lemma_create_index #a len c i = () - -val lemma_map_index #a #b #len f x i: - Lemma (Seq.index (map_array #a #b #len f x) i == f (Seq.index x i)) - [SMTPat (Seq.index (map_array #a #b #len f x) i)] -let lemma_map_index #a #b #len f x i = () - -val lemma_map2_index #a #b #c #len f x y i: - Lemma (Seq.index (map2 #a #b #c #len f x y) i == f (Seq.index x i) (Seq.index y i)) - [SMTPat (Seq.index (map2 #a #b #c #len f x y) i)] -let lemma_map2_index #a #b #c #len f x y i = () - -let lemma_bitand_properties #t (x:int_t t) : - Lemma ((x &. ones) == x /\ (x &. mk_int #t 0) == mk_int #t 0 /\ (ones #t &. x) == x /\ (mk_int #t 0 &. x) == mk_int #t 0) = - logand_lemma #t x x - -#push-options "--z3rlimit 250" -let flatten #t #n - (#m: usize {range (v n * v m) usize_inttype}) - (x: t_Array (t_Array t m) n) - : t_Array t (m *! n) - = createi (m *! n) (fun i -> Seq.index (Seq.index x (v i / v m)) (v i % v m)) -#pop-options - -type t_Error = | Error_RejectionSampling : t_Error - -type t_Result a b = - | Ok: a -> t_Result a b - | Err: b -> t_Result a b - -(** Hash Function *) -open Spec.SHA3 - -val v_G (input: t_Slice u8) : t_Array u8 (sz 64) -let v_G input = map_slice Lib.RawIntTypes.u8_to_UInt8 (sha3_512 (Seq.length input) (map_slice Lib.IntTypes.secret input)) - -val v_H (input: t_Slice u8) : t_Array u8 (sz 32) -let v_H input = map_slice Lib.RawIntTypes.u8_to_UInt8 (sha3_256 (Seq.length input) (map_slice Lib.IntTypes.secret input)) - -val v_PRF (v_LEN: usize{v v_LEN < pow2 32}) (input: t_Slice u8) : t_Array u8 v_LEN -let v_PRF v_LEN input = map_slice Lib.RawIntTypes.u8_to_UInt8 ( - shake256 (Seq.length input) (map_slice Lib.IntTypes.secret input) (v v_LEN)) - -let v_J (input: t_Slice u8) : t_Array u8 (sz 32) = v_PRF (sz 32) input - -val v_XOF (v_LEN: usize{v v_LEN < pow2 32}) (input: t_Slice u8) : t_Array u8 v_LEN -let v_XOF v_LEN input = map_slice Lib.RawIntTypes.u8_to_UInt8 ( - shake128 (Seq.length input) (map_slice Lib.IntTypes.secret input) (v v_LEN)) - -let update_at_range_lemma #n - (s: t_Slice 't) - (i: Core.Ops.Range.t_Range (int_t n) {(Core.Ops.Range.impl_index_range_slice 't n).f_index_pre s i}) - (x: t_Slice 't) - : Lemma - (requires (Seq.length x == v i.f_end - v i.f_start)) - (ensures ( - let s' = Rust_primitives.Hax.Monomorphized_update_at.update_at_range s i x in - let len = v i.f_start in - forall (i: nat). i < len ==> Seq.index s i == Seq.index s' i - )) - [SMTPat (Rust_primitives.Hax.Monomorphized_update_at.update_at_range s i x)] - = let s' = Rust_primitives.Hax.Monomorphized_update_at.update_at_range s i x in - let len = v i.f_start in - introduce forall (i:nat {i < len}). Seq.index s i == Seq.index s' i - with (assert ( Seq.index (Seq.slice s 0 len) i == Seq.index s i - /\ Seq.index (Seq.slice s' 0 len) i == Seq.index s' i )) - - -/// Bounded integers - -let is_intb (l:nat) (x:int) = (x <= l) && (x >= -l) -let is_i16b (l:nat) (x:i16) = is_intb l (v x) -let is_i16b_array (l:nat) (x:t_Slice i16) = forall i. i < Seq.length x ==> is_i16b l (Seq.index x i) -let is_i16b_vector (l:nat) (r:usize) (x:t_Array (t_Array i16 (sz 256)) r) = forall i. i < v r ==> is_i16b_array l (Seq.index x i) -let is_i16b_matrix (l:nat) (r:usize) (x:t_Array (t_Array (t_Array i16 (sz 256)) r) r) = forall i. i < v r ==> is_i16b_vector l r (Seq.index x i) - -[@ "opaque_to_smt"] -let is_i16b_array_opaque (l:nat) (x:t_Slice i16) = is_i16b_array l x - -let is_i32b (l:nat) (x:i32) = is_intb l (v x) -let is_i32b_array (l:nat) (x:t_Slice i32) = forall i. i < Seq.length x ==> is_i32b l (Seq.index x i) - -let nat_div_ceil (x:nat) (y:pos) : nat = if (x % y = 0) then x/y else (x/y)+1 - -val lemma_intb_le b b' - : Lemma (requires (b <= b')) - (ensures (forall n. is_intb b n ==> is_intb b' n)) -let lemma_intb_le b b' = () - -#push-options "--z3rlimit 200" -val lemma_mul_intb (b1 b2: nat) (n1 n2: int) - : Lemma (requires (is_intb b1 n1 /\ is_intb b2 n2)) - (ensures (is_intb (b1 * b2) (n1 * n2))) -let lemma_mul_intb (b1 b2: nat) (n1 n2: int) = - if n1 = 0 || n2 = 0 - then () - else - let open FStar.Math.Lemmas in - lemma_abs_bound n1 b1; - lemma_abs_bound n2 b2; - lemma_abs_mul n1 n2; - lemma_mult_le_left (abs n1) (abs n2) b2; - lemma_mult_le_right b2 (abs n1) b1; - lemma_abs_bound (n1 * n2) (b1 * b2) -#pop-options - -#push-options "--z3rlimit 200" -val lemma_mul_i16b (b1 b2: nat) (n1 n2: i16) - : Lemma (requires (is_i16b b1 n1 /\ is_i16b b2 n2 /\ b1 * b2 < pow2 31)) - (ensures (range (v n1 * v n2) i32_inttype /\ - is_i32b (b1 * b2) ((cast n1 <: i32) *! (cast n2 <: i32)) /\ - v ((cast n1 <: i32) *! (cast n2 <: i32)) == v n1 * v n2)) - -let lemma_mul_i16b (b1 b2: nat) (n1 n2: i16) = - if v n1 = 0 || v n2 = 0 - then () - else - let open FStar.Math.Lemmas in - lemma_abs_bound (v n1) b1; - lemma_abs_bound (v n2) b2; - lemma_abs_mul (v n1) (v n2); - lemma_mult_le_left (abs (v n1)) (abs (v n2)) b2; - lemma_mult_le_right b2 (abs (v n1)) b1; - lemma_abs_bound (v n1 * v n2) (b1 * b2) -#pop-options - -val lemma_add_i16b (b1 b2:nat) (n1 n2:i16) : - Lemma (requires (is_i16b b1 n1 /\ is_i16b b2 n2 /\ b1 + b2 < pow2 15)) - (ensures (range (v n1 + v n2) i16_inttype /\ - is_i16b (b1 + b2) (n1 +! n2))) -let lemma_add_i16b (b1 b2:nat) (n1 n2:i16) = () - -#push-options "--z3rlimit 100 --split_queries always" -let lemma_range_at_percent (v:int) (p:int{p>0/\ p%2=0 /\ v < p/2 /\ v >= -p / 2}): - Lemma (v @% p == v) = - let m = v % p in - if v < 0 then ( - Math.Lemmas.lemma_mod_plus v 1 p; - assert ((v + p) % p == v % p); - assert (v + p >= 0); - assert (v + p < p); - Math.Lemmas.modulo_lemma (v+p) p; - assert (m == v + p); - assert (m >= p/2); - assert (v @% p == m - p); - assert (v @% p == v)) - else ( - assert (v >= 0 /\ v < p); - Math.Lemmas.modulo_lemma v p; - assert (v % p == v); - assert (m < p/2); - assert (v @% p == v) - ) -#pop-options - -val lemma_sub_i16b (b1 b2:nat) (n1 n2:i16) : - Lemma (requires (is_i16b b1 n1 /\ is_i16b b2 n2 /\ b1 + b2 < pow2 15)) - (ensures (range (v n1 - v n2) i16_inttype /\ - is_i16b (b1 + b2) (n1 -. n2) /\ - v (n1 -. n2) == v n1 - v n2)) -let lemma_sub_i16b (b1 b2:nat) (n1 n2:i16) = () - -let mont_mul_red_i16 (x:i16) (y:i16) : i16= - let vlow = x *. y in - let k = vlow *. (neg 3327s) in - let k_times_modulus = cast (((cast k <: i32) *. 3329l) >>! 16l) <: i16 in - let vhigh = cast (((cast x <: i32) *. (cast y <: i32)) >>! 16l) <: i16 in - vhigh -. k_times_modulus - -let mont_red_i32 (x:i32) : i16 = - let vlow = cast x <: i16 in - let k = vlow *. (neg 3327s) in - let k_times_modulus = cast (((cast k <: i32) *. 3329l) >>! 16l) <: i16 in - let vhigh = cast (x >>! 16l) <: i16 in - vhigh -. k_times_modulus - -#push-options "--z3rlimit 100" -let lemma_at_percent_mod (v:int) (p:int{p>0/\ p%2=0}): - Lemma ((v @% p) % p == v % p) = - let m = v % p in - assert (m >= 0 /\ m < p); - if m >= p/2 then ( - assert ((v @%p) % p == (m - p) %p); - Math.Lemmas.lemma_mod_plus m (-1) p; - assert ((v @%p) % p == m %p); - Math.Lemmas.lemma_mod_mod m v p; - assert ((v @%p) % p == v % p) - ) else ( - assert ((v @%p) % p == m%p); - Math.Lemmas.lemma_mod_mod m v p; - assert ((v @%p) % p == v % p) - ) -#pop-options - -let lemma_div_at_percent (v:int) (p:int{p>0/\ p%2=0 /\ (v/p) < p/2 /\ (v/p) >= -p / 2}): - Lemma ((v / p) @% p == v / p) = - lemma_range_at_percent (v/p) p - -val lemma_mont_red_i32 (x:i32): Lemma - (requires (is_i32b (3328 * pow2 16) x)) - (ensures ( - let result:i16 = mont_red_i32 x in - is_i16b (3328 + 1665) result /\ - (is_i32b (3328 * pow2 15) x ==> is_i16b 3328 result) /\ - v result % 3329 == (v x * 169) % 3329)) - -let lemma_mont_red_i32 (x:i32) = - let vlow = cast x <: i16 in - assert (v vlow == v x @% pow2 16); - let k = vlow *. (neg 3327s) in - assert (v k == ((v x @% pow2 16) * (- 3327)) @% pow2 16); - let k_times_modulus = (cast k <: i32) *. 3329l in - assert (v k_times_modulus == (v k * 3329)); - let c = cast (k_times_modulus >>! 16l) <: i16 in - assert (v c == (((v k * 3329) / pow2 16) @% pow2 16)); - lemma_div_at_percent (v k * 3329) (pow2 16); - assert (v c == (((v k * 3329) / pow2 16))); - assert (is_i16b 1665 c); - let vhigh = cast (x >>! 16l) <: i16 in - lemma_div_at_percent (v x) (pow2 16); - assert (v vhigh == v x / pow2 16); - assert (is_i16b 3328 vhigh); - let result = vhigh -. c in - lemma_sub_i16b 3328 1665 vhigh c; - assert (is_i16b (3328 + 1665) result); - assert (v result = v vhigh - v c); - assert (is_i16b (3328 + 1665) result); - assert (is_i32b (3328 * pow2 15) x ==> is_i16b 3328 result); - calc ( == ) { - v k_times_modulus % pow2 16; - ( == ) { assert (v k_times_modulus == v k * 3329) } - (v k * 3329) % pow2 16; - ( == ) { assert (v k = ((v x @% pow2 16) * (-3327)) @% pow2 16) } - ((((v x @% pow2 16) * (-3327)) @% pow2 16) * 3329) % pow2 16; - ( == ) { Math.Lemmas.lemma_mod_mul_distr_l (((v x @% pow2 16) * (-3327)) @% pow2 16) 3329 (pow2 16) } - (((((v x @% pow2 16) * (-3327)) @% pow2 16) % pow2 16) * 3329) % pow2 16; - ( == ) { lemma_at_percent_mod ((v x @% pow2 16) * (-3327)) (pow2 16)} - ((((v x @% pow2 16) * (-3327)) % pow2 16) * 3329) % pow2 16; - ( == ) { Math.Lemmas.lemma_mod_mul_distr_l ((v x @% pow2 16) * (-3327)) 3329 (pow2 16) } - (((v x @% pow2 16) * (-3327)) * 3329) % pow2 16; - ( == ) { } - ((v x @% pow2 16) * (-3327 * 3329)) % pow2 16; - ( == ) { Math.Lemmas.lemma_mod_mul_distr_r (v x @% pow2 16) (-3327 * 3329) (pow2 16) } - ((v x @% pow2 16) % pow2 16); - ( == ) { lemma_at_percent_mod (v x) (pow2 16) } - (v x) % pow2 16; - }; - Math.Lemmas.modulo_add (pow2 16) (- (v k_times_modulus)) (v x) (v k_times_modulus); - assert ((v x - v k_times_modulus) % pow2 16 == 0); - calc ( == ) { - v result % 3329; - ( == ) { } - (v x / pow2 16 - v k_times_modulus / pow2 16) % 3329; - ( == ) { Math.Lemmas.lemma_div_exact (v x - v k_times_modulus) (pow2 16) } - ((v x - v k_times_modulus) / pow2 16) % 3329; - ( == ) { assert ((pow2 16 * 169) % 3329 == 1) } - (((v x - v k_times_modulus) / pow2 16) * ((pow2 16 * 169) % 3329)) % 3329; - ( == ) { Math.Lemmas.lemma_mod_mul_distr_r ((v x - v k_times_modulus) / pow2 16) - (pow2 16 * 169) - 3329 } - (((v x - v k_times_modulus) / pow2 16) * pow2 16 * 169) % 3329; - ( == ) { Math.Lemmas.lemma_div_exact (v x - v k_times_modulus) (pow2 16) } - ((v x - v k_times_modulus) * 169) % 3329; - ( == ) { assert (v k_times_modulus == v k * 3329) } - ((v x * 169) - (v k * 3329 * 169)) % 3329; - ( == ) { Math.Lemmas.lemma_mod_sub (v x * 169) 3329 (v k * 169) } - (v x * 169) % 3329; - } - -val lemma_mont_mul_red_i16_int (x y:i16): Lemma - (requires (is_intb (3326 * pow2 15) (v x * v y))) - (ensures ( - let result:i16 = mont_mul_red_i16 x y in - is_i16b 3328 result /\ - v result % 3329 == (v x * v y * 169) % 3329)) - -let lemma_mont_mul_red_i16_int (x y:i16) = - let vlow = x *. y in - let prod = v x * v y in - assert (v vlow == prod @% pow2 16); - let k = vlow *. (neg 3327s) in - assert (v k == (((prod) @% pow2 16) * (- 3327)) @% pow2 16); - let k_times_modulus = (cast k <: i32) *. 3329l in - assert (v k_times_modulus == (v k * 3329)); - let c = cast (k_times_modulus >>! 16l) <: i16 in - assert (v c == (((v k * 3329) / pow2 16) @% pow2 16)); - lemma_div_at_percent (v k * 3329) (pow2 16); - assert (v c == (((v k * 3329) / pow2 16))); - assert (is_i16b 1665 c); - let vhigh = cast (((cast x <: i32) *. (cast y <: i32)) >>! 16l) <: i16 in - assert (v x @% pow2 32 == v x); - assert (v y @% pow2 32 == v y); - assert (v ((cast x <: i32) *. (cast y <: i32)) == (v x * v y) @% pow2 32); - assert (v vhigh == (((prod) @% pow2 32) / pow2 16) @% pow2 16); - assert_norm (pow2 15 * 3326 < pow2 31); - lemma_range_at_percent prod (pow2 32); - assert (v vhigh == (prod / pow2 16) @% pow2 16); - lemma_div_at_percent prod (pow2 16); - assert (v vhigh == prod / pow2 16); - let result = vhigh -. c in - assert (is_i16b 1663 vhigh); - lemma_sub_i16b 1663 1665 vhigh c; - assert (is_i16b 3328 result); - assert (v result = v vhigh - v c); - calc ( == ) { - v k_times_modulus % pow2 16; - ( == ) { assert (v k_times_modulus == v k * 3329) } - (v k * 3329) % pow2 16; - ( == ) { assert (v k = ((prod @% pow2 16) * (-3327)) @% pow2 16) } - ((((prod @% pow2 16) * (-3327)) @% pow2 16) * 3329) % pow2 16; - ( == ) { Math.Lemmas.lemma_mod_mul_distr_l (((prod @% pow2 16) * (-3327)) @% pow2 16) 3329 (pow2 16) } - (((((prod @% pow2 16) * (-3327)) @% pow2 16) % pow2 16) * 3329) % pow2 16; - ( == ) { lemma_at_percent_mod ((prod @% pow2 16) * (-3327)) (pow2 16)} - ((((prod @% pow2 16) * (-3327)) % pow2 16) * 3329) % pow2 16; - ( == ) { Math.Lemmas.lemma_mod_mul_distr_l ((prod @% pow2 16) * (-3327)) 3329 (pow2 16) } - (((prod @% pow2 16) * (-3327)) * 3329) % pow2 16; - ( == ) { } - ((prod @% pow2 16) * (-3327 * 3329)) % pow2 16; - ( == ) { Math.Lemmas.lemma_mod_mul_distr_r (prod @% pow2 16) (-3327 * 3329) (pow2 16) } - ((prod @% pow2 16) % pow2 16); - ( == ) { lemma_at_percent_mod (prod) (pow2 16) } - (prod) % pow2 16; - }; - Math.Lemmas.modulo_add (pow2 16) (- (v k_times_modulus)) ((prod)) (v k_times_modulus); - assert (((prod) - v k_times_modulus) % pow2 16 == 0); - calc ( == ) { - v result % 3329; - ( == ) { } - (((prod) / pow2 16) - ((v k * 3329) / pow2 16)) % 3329; - ( == ) { Math.Lemmas.lemma_div_exact ((prod) - (v k * 3329)) (pow2 16) } - ((prod - (v k * 3329)) / pow2 16) % 3329; - ( == ) { assert ((pow2 16 * 169) % 3329 == 1) } - (((prod - (v k * 3329)) / pow2 16) * ((pow2 16 * 169) % 3329)) % 3329; - ( == ) { Math.Lemmas.lemma_mod_mul_distr_r (((prod) - (v k * 3329)) / pow2 16) - (pow2 16 * 169) - 3329 } - ((((prod) - (v k * 3329)) / pow2 16) * pow2 16 * 169) % 3329; - ( == ) { Math.Lemmas.lemma_div_exact ((prod) - (v k * 3329)) (pow2 16) } - (((prod) - (v k * 3329)) * 169) % 3329; - ( == ) { Math.Lemmas.lemma_mod_sub ((prod) * 169) 3329 (v k * 169)} - ((prod) * 169) % 3329; - } - - -val lemma_mont_mul_red_i16 (x y:i16): Lemma - (requires (is_i16b 1664 y \/ is_intb (3326 * pow2 15) (v x * v y))) - (ensures ( - let result:i16 = mont_mul_red_i16 x y in - is_i16b 3328 result /\ - v result % 3329 == (v x * v y * 169) % 3329)) - [SMTPat (mont_mul_red_i16 x y)] -let lemma_mont_mul_red_i16 x y = - if is_i16b 1664 y then ( - lemma_mul_intb (pow2 15) 1664 (v x) (v y); - assert(is_intb (3326 * pow2 15) (v x * v y)); - lemma_mont_mul_red_i16_int x y) - else lemma_mont_mul_red_i16_int x y - -let barrett_red (x:i16) = - let t1 = cast (((cast x <: i32) *. (cast 20159s <: i32)) >>! 16l) <: i16 in - let t2 = t1 +. 512s in - let q = t2 >>! 10l in - let qm = q *. 3329s in - x -. qm - -let lemma_barrett_red (x:i16) : Lemma - (requires (is_i16b 28296 x)) - (ensures (let result = barrett_red x in - is_i16b 3328 result /\ - v result % 3329 == v x % 3329)) - [SMTPat (barrett_red x)] - = admit() - -let cond_sub (x:i16) = - let xm = x -. 3329s in - let mask = xm >>! 15l in - let mm = mask &. 3329s in - xm +. mm - -let lemma_cond_sub x: - Lemma (let r = cond_sub x in - if x >=. 3329s then r == x -! 3329s else r == x) - [SMTPat (cond_sub x)] - = admit() - - -let lemma_shift_right_15_i16 (x:i16): - Lemma (if v x >= 0 then (x >>! 15l) == 0s else (x >>! 15l) == -1s) = - Rust_primitives.Integers.mk_int_v_lemma #i16_inttype 0s; - Rust_primitives.Integers.mk_int_v_lemma #i16_inttype (-1s); - () - -val ntt_spec #len (vec_in: t_Array i16 len) (zeta: int) (i: nat{i < v len}) (j: nat{j < v len}) - (vec_out: t_Array i16 len) : Type0 -let ntt_spec vec_in zeta i j vec_out = - ((v (Seq.index vec_out i) % 3329) == - ((v (Seq.index vec_in i) + (v (Seq.index vec_in j) * zeta * 169)) % 3329)) /\ - ((v (Seq.index vec_out j) % 3329) == - ((v (Seq.index vec_in i) - (v (Seq.index vec_in j) * zeta * 169)) % 3329)) - -val inv_ntt_spec #len (vec_in: t_Array i16 len) (zeta: int) (i: nat{i < v len}) (j: nat{j < v len}) - (vec_out: t_Array i16 len) : Type0 -let inv_ntt_spec vec_in zeta i j vec_out = - ((v (Seq.index vec_out i) % 3329) == - ((v (Seq.index vec_in j) + v (Seq.index vec_in i)) % 3329)) /\ - ((v (Seq.index vec_out j) % 3329) == - (((v (Seq.index vec_in j) - v (Seq.index vec_in i)) * zeta * 169) % 3329)) -