diff --git a/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Types.Unpacked.fsti b/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Types.Unpacked.fsti
deleted file mode 100644
index 1910c0b08..000000000
--- a/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Types.Unpacked.fsti
+++ /dev/null
@@ -1,48 +0,0 @@
-module Libcrux_ml_kem.Types.Unpacked
-#set-options "--fuel 0 --ifuel 1 --z3rlimit 15"
-open Core
-open FStar.Mul
-
-let _ =
- (* This module has implicit dependencies, here we make them explicit. *)
- (* The implicit dependencies arise from typeclasses instances. *)
- let open Libcrux_ml_kem.Vector.Traits in
- ()
-
-/// An unpacked ML-KEM IND-CPA Private Key
-type t_IndCpaPrivateKeyUnpacked
- (v_K: usize) (v_Vector: Type0) {| i1: Libcrux_ml_kem.Vector.Traits.t_Operations v_Vector |}
- = { f_secret_as_ntt:t_Array (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector) v_K }
-
-/// An unpacked ML-KEM IND-CPA Private Key
-type t_IndCpaPublicKeyUnpacked
- (v_K: usize) (v_Vector: Type0) {| i1: Libcrux_ml_kem.Vector.Traits.t_Operations v_Vector |}
- = {
- f_t_as_ntt:t_Array (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector) v_K;
- f_seed_for_A:t_Array u8 (sz 32);
- f_A:t_Array (t_Array (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector) v_K) v_K
-}
-
-/// An unpacked ML-KEM IND-CCA Private Key
-type t_MlKemPrivateKeyUnpacked
- (v_K: usize) (v_Vector: Type0) {| i1: Libcrux_ml_kem.Vector.Traits.t_Operations v_Vector |}
- = {
- f_ind_cpa_private_key:t_IndCpaPrivateKeyUnpacked v_K v_Vector;
- f_implicit_rejection_value:t_Array u8 (sz 32)
-}
-
-/// An unpacked ML-KEM IND-CCA Private Key
-type t_MlKemPublicKeyUnpacked
- (v_K: usize) (v_Vector: Type0) {| i1: Libcrux_ml_kem.Vector.Traits.t_Operations v_Vector |}
- = {
- f_ind_cpa_public_key:t_IndCpaPublicKeyUnpacked v_K v_Vector;
- f_public_key_hash:t_Array u8 (sz 32)
-}
-
-/// An unpacked ML-KEM KeyPair
-type t_MlKemKeyPairUnpacked
- (v_K: usize) (v_Vector: Type0) {| i1: Libcrux_ml_kem.Vector.Traits.t_Operations v_Vector |}
- = {
- f_private_key:t_MlKemPrivateKeyUnpacked v_K v_Vector;
- f_public_key:t_MlKemPublicKeyUnpacked v_K v_Vector
-}
diff --git a/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Vector.Avx2.Portable.fsti b/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Vector.Avx2.Portable.fsti
deleted file mode 100644
index fe64003c4..000000000
--- a/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Vector.Avx2.Portable.fsti
+++ /dev/null
@@ -1,30 +0,0 @@
-module Libcrux_ml_kem.Vector.Avx2.Portable
-#set-options "--fuel 0 --ifuel 1 --z3rlimit 15"
-open Core
-open FStar.Mul
-
-val deserialize_11_int (bytes: t_Slice u8)
- : Prims.Pure (i16 & i16 & i16 & i16 & i16 & i16 & i16 & i16)
- Prims.l_True
- (fun _ -> Prims.l_True)
-
-val serialize_11_int (v: t_Slice i16)
- : Prims.Pure (u8 & u8 & u8 & u8 & u8 & u8 & u8 & u8 & u8 & u8 & u8)
- Prims.l_True
- (fun _ -> Prims.l_True)
-
-type t_PortableVector = { f_elements:t_Array i16 (sz 16) }
-
-val from_i16_array (array: t_Array i16 (sz 16))
- : Prims.Pure t_PortableVector Prims.l_True (fun _ -> Prims.l_True)
-
-val serialize_11_ (v: t_PortableVector)
- : Prims.Pure (t_Array u8 (sz 22)) Prims.l_True (fun _ -> Prims.l_True)
-
-val to_i16_array (v: t_PortableVector)
- : Prims.Pure (t_Array i16 (sz 16)) Prims.l_True (fun _ -> Prims.l_True)
-
-val zero: Prims.unit -> Prims.Pure t_PortableVector Prims.l_True (fun _ -> Prims.l_True)
-
-val deserialize_11_ (bytes: t_Slice u8)
- : Prims.Pure t_PortableVector Prims.l_True (fun _ -> Prims.l_True)
diff --git a/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Vector.Portable.Serialize.Edited.fsti b/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Vector.Portable.Serialize.Edited.fsti
deleted file mode 100644
index 4ed69770d..000000000
--- a/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Vector.Portable.Serialize.Edited.fsti
+++ /dev/null
@@ -1,100 +0,0 @@
-module Libcrux_ml_kem.Vector.Portable.Serialize.Edited
-// #set-options "--fuel 0 --ifuel 1 --z3rlimit 15"
-// open Core
-// open FStar.Mul
-
-// val deserialize_10_int (bytes: t_Slice u8)
-// : Prims.Pure (i16 & i16 & i16 & i16 & i16 & i16 & i16 & i16)
-// Prims.l_True
-// (fun _ -> Prims.l_True)
-
-// val deserialize_11_int (bytes: t_Slice u8)
-// : Prims.Pure (i16 & i16 & i16 & i16 & i16 & i16 & i16 & i16)
-// Prims.l_True
-// (fun _ -> Prims.l_True)
-
-// val deserialize_12_int (bytes: t_Slice u8)
-// : Prims.Pure (i16 & i16) Prims.l_True (fun _ -> Prims.l_True)
-
-// val deserialize_4_int (bytes: t_Slice u8)
-// : Prims.Pure (i16 & i16 & i16 & i16 & i16 & i16 & i16 & i16)
-// Prims.l_True
-// (fun _ -> Prims.l_True)
-
-// val deserialize_5_int (bytes: t_Slice u8)
-// : Prims.Pure (i16 & i16 & i16 & i16 & i16 & i16 & i16 & i16)
-// Prims.l_True
-// (fun _ -> Prims.l_True)
-
-// val serialize_10_int (v: t_Slice i16)
-// : Prims.Pure (u8 & u8 & u8 & u8 & u8)
-// (requires (Core.Slice.impl__len #i16 v <: usize) =. sz 4)
-// (ensures
-// fun tuple ->
-// let tuple:(u8 & u8 & u8 & u8 & u8) = tuple in
-// BitVecEq.int_t_array_bitwise_eq' (v <: t_Array i16 (sz 4)) 10 (MkSeq.create5 tuple) 8)
-
-// val serialize_11_int (v: t_Slice i16)
-// : Prims.Pure (u8 & u8 & u8 & u8 & u8 & u8 & u8 & u8 & u8 & u8 & u8)
-// (requires Seq.length v == 8 /\ (forall i. Rust_primitives.bounded (Seq.index v i) 11))
-// (ensures
-// fun tuple ->
-// let tuple:(u8 & u8 & u8 & u8 & u8 & u8 & u8 & u8 & u8 & u8 & u8) = tuple in
-// BitVecEq.int_t_array_bitwise_eq' (v <: t_Array i16 (sz 8)) 11 (MkSeq.create11 tuple) 8)
-
-// val serialize_12_int (v: t_Slice i16)
-// : Prims.Pure (u8 & u8 & u8) Prims.l_True (fun _ -> Prims.l_True)
-
-// val serialize_4_int (v: t_Slice i16)
-// : Prims.Pure (u8 & u8 & u8 & u8) Prims.l_True (fun _ -> Prims.l_True)
-
-// val serialize_5_int (v: t_Slice i16)
-// : Prims.Pure (u8 & u8 & u8 & u8 & u8) Prims.l_True (fun _ -> Prims.l_True)
-
-// val serialize_1_ (v: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector)
-// : Prims.Pure (t_Array u8 (sz 2)) Prims.l_True (fun _ -> Prims.l_True)
-
-// val serialize_10_ (v: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector)
-// : Prims.Pure (t_Array u8 (sz 20)) Prims.l_True (fun _ -> Prims.l_True)
-
-// val serialize_11_ (v: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector)
-// : Prims.Pure (t_Array u8 (sz 22)) Prims.l_True (fun _ -> Prims.l_True)
-
-// val serialize_12_ (v: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector)
-// : Prims.Pure (t_Array u8 (sz 24)) Prims.l_True (fun _ -> Prims.l_True)
-
-// val serialize_4_ (v: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector)
-// : Prims.Pure (t_Array u8 (sz 8)) Prims.l_True (fun _ -> Prims.l_True)
-
-// val serialize_5_ (v: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector)
-// : Prims.Pure (t_Array u8 (sz 10)) Prims.l_True (fun _ -> Prims.l_True)
-
-// val deserialize_1_ (v: t_Slice u8)
-// : Prims.Pure Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector
-// Prims.l_True
-// (fun _ -> Prims.l_True)
-
-// val deserialize_10_ (bytes: t_Slice u8)
-// : Prims.Pure Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector
-// Prims.l_True
-// (fun _ -> Prims.l_True)
-
-// val deserialize_11_ (bytes: t_Slice u8)
-// : Prims.Pure Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector
-// Prims.l_True
-// (fun _ -> Prims.l_True)
-
-// val deserialize_12_ (bytes: t_Slice u8)
-// : Prims.Pure Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector
-// Prims.l_True
-// (fun _ -> Prims.l_True)
-
-// val deserialize_4_ (bytes: t_Slice u8)
-// : Prims.Pure Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector
-// Prims.l_True
-// (fun _ -> Prims.l_True)
-
-// val deserialize_5_ (bytes: t_Slice u8)
-// : Prims.Pure Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector
-// Prims.l_True
-// (fun _ -> Prims.l_True)
diff --git a/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Vector.Portable.fst b/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Vector.Portable.fst
deleted file mode 100644
index 0ca12f7ff..000000000
--- a/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Vector.Portable.fst
+++ /dev/null
@@ -1,59 +0,0 @@
-module Libcrux_ml_kem.Vector.Portable
-#set-options "--fuel 0 --ifuel 1 --z3rlimit 100"
-open Core
-open FStar.Mul
-
-let _ =
- (* This module has implicit dependencies, here we make them explicit. *)
- (* The implicit dependencies arise from typeclasses instances. *)
- let open Libcrux_ml_kem.Vector.Portable.Vector_type in
- let open Libcrux_ml_kem.Vector.Traits in
- ()
-
-let deserialize_11_ (a: t_Slice u8) = Libcrux_ml_kem.Vector.Portable.Serialize.deserialize_11_ a
-
-let deserialize_5_ (a: t_Slice u8) = Libcrux_ml_kem.Vector.Portable.Serialize.deserialize_5_ a
-
-let serialize_11_ (a: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector) =
- Libcrux_ml_kem.Vector.Portable.Serialize.serialize_11_ a
-
-let serialize_5_ (a: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector) =
- Libcrux_ml_kem.Vector.Portable.Serialize.serialize_5_ a
-
-let deserialize_1_ (a: t_Slice u8) =
- let _:Prims.unit = Libcrux_ml_kem.Vector.Portable.Serialize.deserialize_1_lemma a in
- let _:Prims.unit = Libcrux_ml_kem.Vector.Portable.Serialize.deserialize_1_bounded_lemma a in
- Libcrux_ml_kem.Vector.Portable.Serialize.deserialize_1_ a
-
-let deserialize_10_ (a: t_Slice u8) =
- let _:Prims.unit = Libcrux_ml_kem.Vector.Portable.Serialize.deserialize_10_lemma a in
- let _:Prims.unit = Libcrux_ml_kem.Vector.Portable.Serialize.deserialize_10_bounded_lemma a in
- Libcrux_ml_kem.Vector.Portable.Serialize.deserialize_10_ a
-
-let deserialize_12_ (a: t_Slice u8) =
- let _:Prims.unit = Libcrux_ml_kem.Vector.Portable.Serialize.deserialize_12_lemma a in
- let _:Prims.unit = Libcrux_ml_kem.Vector.Portable.Serialize.deserialize_12_bounded_lemma a in
- Libcrux_ml_kem.Vector.Portable.Serialize.deserialize_12_ a
-
-let deserialize_4_ (a: t_Slice u8) =
- let _:Prims.unit = Libcrux_ml_kem.Vector.Portable.Serialize.deserialize_4_lemma a in
- let _:Prims.unit = Libcrux_ml_kem.Vector.Portable.Serialize.deserialize_4_bounded_lemma a in
- Libcrux_ml_kem.Vector.Portable.Serialize.deserialize_4_ a
-
-let serialize_1_ (a: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector) =
- let _:Prims.unit = assert (forall i. Rust_primitives.bounded (Seq.index a.f_elements i) 1) in
- let _:Prims.unit = Libcrux_ml_kem.Vector.Portable.Serialize.serialize_1_lemma a in
- Libcrux_ml_kem.Vector.Portable.Serialize.serialize_1_ a
-
-let serialize_10_ (a: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector) =
- let _:Prims.unit = Libcrux_ml_kem.Vector.Portable.Serialize.serialize_10_lemma a in
- Libcrux_ml_kem.Vector.Portable.Serialize.serialize_10_ a
-
-let serialize_12_ (a: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector) =
- let _:Prims.unit = Libcrux_ml_kem.Vector.Portable.Serialize.serialize_12_lemma a in
- Libcrux_ml_kem.Vector.Portable.Serialize.serialize_12_ a
-
-let serialize_4_ (a: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector) =
- let _:Prims.unit = assert (forall i. Rust_primitives.bounded (Seq.index a.f_elements i) 4) in
- let _:Prims.unit = Libcrux_ml_kem.Vector.Portable.Serialize.serialize_4_lemma a in
- Libcrux_ml_kem.Vector.Portable.Serialize.serialize_4_ a
diff --git a/libcrux-ml-kem/proofs/fstar/spec/Makefile b/libcrux-ml-kem/proofs/fstar/spec/Makefile
deleted file mode 100644
index b4ce70a38..000000000
--- a/libcrux-ml-kem/proofs/fstar/spec/Makefile
+++ /dev/null
@@ -1 +0,0 @@
-include $(shell git rev-parse --show-toplevel)/fstar-helpers/Makefile.base
diff --git a/libcrux-ml-kem/proofs/fstar/spec/Spec.MLKEM.Instances.fst b/libcrux-ml-kem/proofs/fstar/spec/Spec.MLKEM.Instances.fst
deleted file mode 100644
index f598ee0ff..000000000
--- a/libcrux-ml-kem/proofs/fstar/spec/Spec.MLKEM.Instances.fst
+++ /dev/null
@@ -1,64 +0,0 @@
-module Spec.MLKEM.Instances
-#set-options "--fuel 0 --ifuel 1 --z3rlimit 30"
-open FStar.Mul
-open Core
-open Spec.Utils
-open Spec.MLKEM.Math
-open Spec.MLKEM
-
-
-(** MLKEM-768 Instantiation *)
-
-let mlkem768_rank : rank = sz 3
-
-#push-options "--z3rlimit 300"
-let mlkem768_generate_keypair (randomness:t_Array u8 (sz 64)):
- (t_Array u8 (sz 2400) & t_Array u8 (sz 1184)) & bool =
- ind_cca_generate_keypair mlkem768_rank randomness
-
-let mlkem768_encapsulate (public_key: t_Array u8 (sz 1184)) (randomness: t_Array u8 (sz 32)):
- (t_Array u8 (sz 1088) & t_Array u8 (sz 32)) & bool =
- ind_cca_encapsulate mlkem768_rank public_key randomness
-
-let mlkem768_decapsulate (secret_key: t_Array u8 (sz 2400)) (ciphertext: t_Array u8 (sz 1088)):
- t_Array u8 (sz 32) & bool =
- ind_cca_decapsulate mlkem768_rank secret_key ciphertext
-
-(** MLKEM-1024 Instantiation *)
-
-let mlkem1024_rank = sz 4
-
-let mlkem1024_generate_keypair (randomness:t_Array u8 (sz 64)):
- (t_Array u8 (sz 3168) & t_Array u8 (sz 1568)) & bool =
- ind_cca_generate_keypair mlkem1024_rank randomness
-
-#set-options "--z3rlimit 100"
-let mlkem1024_encapsulate (public_key: t_Array u8 (sz 1568)) (randomness: t_Array u8 (sz 32)):
- (t_Array u8 (sz 1568) & t_Array u8 (sz 32)) & bool =
- assert (v_CPA_CIPHERTEXT_SIZE mlkem1024_rank == sz 1568);
- ind_cca_encapsulate mlkem1024_rank public_key randomness
-
-let mlkem1024_decapsulate (secret_key: t_Array u8 (sz 3168)) (ciphertext: t_Array u8 (sz 1568)):
- t_Array u8 (sz 32) & bool =
- ind_cca_decapsulate mlkem1024_rank secret_key ciphertext
-
-(** MLKEM-512 Instantiation *)
-
-let mlkem512_rank : rank = sz 2
-
-let mlkem512_generate_keypair (randomness:t_Array u8 (sz 64)):
- (t_Array u8 (sz 1632) & t_Array u8 (sz 800)) & bool =
- ind_cca_generate_keypair mlkem512_rank randomness
-
-let mlkem512_encapsulate (public_key: t_Array u8 (sz 800)) (randomness: t_Array u8 (sz 32)):
- (t_Array u8 (sz 768) & t_Array u8 (sz 32)) & bool =
- assert (v_CPA_CIPHERTEXT_SIZE mlkem512_rank == sz 768);
- ind_cca_encapsulate mlkem512_rank public_key randomness
-
-
-let mlkem512_decapsulate (secret_key: t_Array u8 (sz 1632)) (ciphertext: t_Array u8 (sz 768)):
- t_Array u8 (sz 32) & bool =
- ind_cca_decapsulate mlkem512_rank secret_key ciphertext
-
-
-
diff --git a/libcrux-ml-kem/proofs/fstar/spec/Spec.MLKEM.Math.fst b/libcrux-ml-kem/proofs/fstar/spec/Spec.MLKEM.Math.fst
deleted file mode 100644
index 571e879fb..000000000
--- a/libcrux-ml-kem/proofs/fstar/spec/Spec.MLKEM.Math.fst
+++ /dev/null
@@ -1,293 +0,0 @@
-module Spec.MLKEM.Math
-#set-options "--fuel 0 --ifuel 1 --z3rlimit 80"
-
-open FStar.Mul
-open Core
-open Spec.Utils
-
-let v_FIELD_MODULUS: i32 = 3329l
-let is_rank (r:usize) = v r == 2 \/ v r == 3 \/ v r == 4
-
-type rank = r:usize{is_rank r}
-
-(** MLKEM Math and Sampling *)
-
-type field_element = n:nat{n < v v_FIELD_MODULUS}
-type polynomial = t_Array field_element (sz 256)
-type vector (r:rank) = t_Array polynomial r
-type matrix (r:rank) = t_Array (vector r) r
-
-val field_add: field_element -> field_element -> field_element
-let field_add a b = (a + b) % v v_FIELD_MODULUS
-
-val field_sub: field_element -> field_element -> field_element
-let field_sub a b = (a - b) % v v_FIELD_MODULUS
-
-val field_neg: field_element -> field_element
-let field_neg a = (0 - a) % v v_FIELD_MODULUS
-
-val field_mul: field_element -> field_element -> field_element
-let field_mul a b = (a * b) % v v_FIELD_MODULUS
-
-val poly_add: polynomial -> polynomial -> polynomial
-let poly_add a b = map2 field_add a b
-
-val poly_sub: polynomial -> polynomial -> polynomial
-let poly_sub a b = map2 field_sub a b
-
-let int_to_spec_fe (m:int) : field_element =
- let m_v = m % v v_FIELD_MODULUS in
- assert (m_v > - v v_FIELD_MODULUS);
- if m_v < 0 then
- m_v + v v_FIELD_MODULUS
- else m_v
-
-(* Convert concrete code types to spec types *)
-
-let to_spec_fe (m:i16) : field_element =
- int_to_spec_fe (v m)
-
-let to_spec_array #len (m:t_Array i16 len) : t_Array field_element len =
- createi #field_element len (fun i -> to_spec_fe (m.[i]))
-
-let to_spec_poly (m:t_Array i16 (sz 256)) : polynomial =
- to_spec_array m
-
-let to_spec_vector (#r:rank)
- (m:t_Array (t_Array i16 (sz 256)) r)
- : (vector r) =
- createi r (fun i -> to_spec_poly (m.[i]))
-
-let to_spec_matrix (#r:rank)
- (m:t_Array (t_Array (t_Array i16 (sz 256)) r) r)
- : (matrix r) =
- createi r (fun i -> to_spec_vector (m.[i]))
-
-(* Specifying NTT:
-bitrev7 = [int('{:07b}'.format(x)[::-1], 2) for x in range(0,128)]
-zetas = [pow(17,x) % 3329 for x in bitrev7]
-zetas_mont = [pow(2,16) * x % 3329 for x in zetas]
-zetas_mont_r = [(x - 3329 if x > 1664 else x) for x in zetas_mont]
-
-bitrev7 is
-[0, 64, 32, 96, 16, 80, 48, 112, 8, 72, 40, 104, 24, 88, 56, 120, 4, 68, 36, 100, 20, 84, 52, 116, 12, 76, 44, 108, 28, 92, 60, 124, 2, 66, 34, 98, 18, 82, 50, 114, 10, 74, 42, 106, 26, 90, 58, 122, 6, 70, 38, 102, 22, 86, 54, 118, 14, 78, 46, 110, 30, 94, 62, 126, 1, 65, 33, 97, 17, 81, 49, 113, 9, 73, 41, 105, 25, 89, 57, 121, 5, 69, 37, 101, 21, 85, 53, 117, 13, 77, 45, 109, 29, 93, 61, 125, 3, 67, 35, 99, 19, 83, 51, 115, 11, 75, 43, 107, 27, 91, 59, 123, 7, 71, 39, 103, 23, 87, 55, 119, 15, 79, 47, 111, 31, 95, 63, 127]
-
-zetas = 17^bitrev7 is
-[1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797, 2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333, 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756, 1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915, 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648, 2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100, 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789, 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641, 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757, 2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886, 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154]
-
-zetas_mont = zetas * 2^16 is
-[2285, 2571, 2970, 1812, 1493, 1422, 287, 202, 3158, 622, 1577, 182, 962, 2127, 1855, 1468, 573, 2004, 264, 383, 2500, 1458, 1727, 3199, 2648, 1017, 732, 608, 1787, 411, 3124, 1758, 1223, 652, 2777, 1015, 2036, 1491, 3047, 1785, 516, 3321, 3009, 2663, 1711, 2167, 126, 1469, 2476, 3239, 3058, 830, 107, 1908, 3082, 2378, 2931, 961, 1821, 2604, 448, 2264, 677, 2054, 2226, 430, 555, 843, 2078, 871, 1550, 105, 422, 587, 177, 3094, 3038, 2869, 1574, 1653, 3083, 778, 1159, 3182, 2552, 1483, 2727, 1119, 1739, 644, 2457, 349, 418, 329, 3173, 3254, 817, 1097, 603, 610, 1322, 2044, 1864, 384, 2114, 3193, 1218, 1994, 2455, 220, 2142, 1670, 2144, 1799, 2051, 794, 1819, 2475, 2459, 478, 3221, 3021, 996, 991, 958, 1869, 1522, 1628]
-
-zetas_mont_r = zetas_mont - 3329 if zetas_mont > 1664 else zetas_mont is
-[-1044, -758, -359, -1517, 1493, 1422, 287, 202, -171, 622, 1577, 182, 962, -1202, -1474, 1468, 573, -1325, 264, 383, -829, 1458, -1602, -130, -681, 1017, 732, 608, -1542, 411, -205, -1571, 1223, 652, -552, 1015, -1293, 1491, -282, -1544, 516, -8, -320, -666, -1618, -1162, 126, 1469, -853, -90, -271, 830, 107, -1421, -247, -951, -398, 961, -1508, -725, 448, -1065, 677, -1275, -1103, 430, 555, 843, -1251, 871, 1550, 105, 422, 587, 177, -235, -291, -460, 1574, 1653, -246, 778, 1159, -147, -777, 1483, -602, 1119, -1590, 644, -872, 349, 418, 329, -156, -75, 817, 1097, 603, 610, 1322, -1285, -1465, 384, -1215, -136, 1218, -1335, -874, 220, -1187, -1659, -1185, -1530, -1278, 794, -1510, -854, -870, 478, -108, -308, 996, 991, 958, -1460, 1522, 1628]
-*)
-
-let zetas_list : list field_element = [1; 1729; 2580; 3289; 2642; 630; 1897; 848; 1062; 1919; 193; 797; 2786; 3260; 569; 1746; 296; 2447; 1339; 1476; 3046; 56; 2240; 1333; 1426; 2094; 535; 2882; 2393; 2879; 1974; 821; 289; 331; 3253; 1756; 1197; 2304; 2277; 2055; 650; 1977; 2513; 632; 2865; 33; 1320; 1915; 2319; 1435; 807; 452; 1438; 2868; 1534; 2402; 2647; 2617; 1481; 648; 2474; 3110; 1227; 910; 17; 2761; 583; 2649; 1637; 723; 2288; 1100; 1409; 2662; 3281; 233; 756; 2156; 3015; 3050; 1703; 1651; 2789; 1789; 1847; 952; 1461; 2687; 939; 2308; 2437; 2388; 733; 2337; 268; 641; 1584; 2298; 2037; 3220; 375; 2549; 2090; 1645; 1063; 319; 2773; 757; 2099; 561; 2466; 2594; 2804; 1092; 403; 1026; 1143; 2150; 2775; 886; 1722; 1212; 1874; 1029; 2110; 2935; 885; 2154]
-
-let zetas : t_Array field_element (sz 128) =
- assert_norm(List.Tot.length zetas_list == 128);
- Rust_primitives.Arrays.of_list zetas_list
-
-let poly_ntt_step (a:field_element) (b:field_element) (i:nat{i < 128}) =
- let t = field_mul b zetas.[sz i] in
- let b = field_sub a t in
- let a = field_add a t in
- (a,b)
-
-#push-options "--split_queries always"
-let poly_ntt_layer (p:polynomial) (l:nat{l > 0 /\ l < 8}) : polynomial =
- let len = pow2 l in
- let k = (128 / len) - 1 in
- Rust_primitives.Arrays.createi (sz 256) (fun i ->
- let round = v i / (2 * len) in
- let idx = v i % (2 * len) in
- let (idx0, idx1) = if idx < len then (idx, idx+len) else (idx-len,idx) in
- let (a_ntt, b_ntt) = poly_ntt_step p.[sz idx0] p.[sz idx1] (round + k) in
- if idx < len then a_ntt else b_ntt)
-#pop-options
-
-val poly_ntt: polynomial -> polynomial
-let poly_ntt p =
- let p = poly_ntt_layer p 7 in
- let p = poly_ntt_layer p 6 in
- let p = poly_ntt_layer p 5 in
- let p = poly_ntt_layer p 4 in
- let p = poly_ntt_layer p 3 in
- let p = poly_ntt_layer p 2 in
- let p = poly_ntt_layer p 1 in
- p
-
-let poly_inv_ntt_step (a:field_element) (b:field_element) (i:nat{i < 128}) =
- let b_minus_a = field_sub b a in
- let a = field_add a b in
- let b = field_mul b_minus_a zetas.[sz i] in
- (a,b)
-
-#push-options "--z3rlimit 150"
-let poly_inv_ntt_layer (p:polynomial) (l:nat{l > 0 /\ l < 8}) : polynomial =
- let len = pow2 l in
- let k = (256 / len) - 1 in
- Rust_primitives.Arrays.createi (sz 256) (fun i ->
- let round = v i / (2 * len) in
- let idx = v i % (2 * len) in
- let (idx0, idx1) = if idx < len then (idx, idx+len) else (idx-len,idx) in
- let (a_ntt, b_ntt) = poly_inv_ntt_step p.[sz idx0] p.[sz idx1] (k - round) in
- if idx < len then a_ntt else b_ntt)
-#pop-options
-
-val poly_inv_ntt: polynomial -> polynomial
-let poly_inv_ntt p =
- let p = poly_inv_ntt_layer p 1 in
- let p = poly_inv_ntt_layer p 2 in
- let p = poly_inv_ntt_layer p 3 in
- let p = poly_inv_ntt_layer p 4 in
- let p = poly_inv_ntt_layer p 5 in
- let p = poly_inv_ntt_layer p 6 in
- let p = poly_inv_ntt_layer p 7 in
- p
-
-let poly_base_case_multiply (a0 a1 b0 b1 zeta:field_element) =
- let c0 = field_add (field_mul a0 b0) (field_mul (field_mul a1 b1) zeta) in
- let c1 = field_add (field_mul a0 b1) (field_mul a1 b0) in
- (c0,c1)
-
-val poly_mul_ntt: polynomial -> polynomial -> polynomial
-let poly_mul_ntt a b =
- Rust_primitives.Arrays.createi (sz 256) (fun i ->
- let a0 = a.[sz (2 * (v i / 2))] in
- let a1 = a.[sz (2 * (v i / 2) + 1)] in
- let b0 = b.[sz (2 * (v i / 2))] in
- let b1 = b.[sz (2 * (v i / 2) + 1)] in
- let zeta_4 = zetas.[sz (64 + (v i/4))] in
- let zeta = if v i % 4 < 2 then zeta_4 else field_neg zeta_4 in
- let (c0,c1) = poly_base_case_multiply a0 a1 b0 b1 zeta in
- if v i % 2 = 0 then c0 else c1)
-
-
-val vector_add: #r:rank -> vector r -> vector r -> vector r
-let vector_add #p a b = map2 poly_add a b
-
-val vector_ntt: #r:rank -> vector r -> vector r
-let vector_ntt #p v = map_array poly_ntt v
-
-val vector_inv_ntt: #r:rank -> vector r -> vector r
-let vector_inv_ntt #p v = map_array poly_inv_ntt v
-
-val vector_mul_ntt: #r:rank -> vector r -> vector r -> vector r
-let vector_mul_ntt #p a b = map2 poly_mul_ntt a b
-
-val vector_sum: #r:rank -> vector r -> polynomial
-let vector_sum #r a = repeati (r -! sz 1)
- (fun i x -> assert (v i < v r - 1); poly_add x (a.[i +! sz 1])) a.[sz 0]
-
-val vector_dot_product_ntt: #r:rank -> vector r -> vector r -> polynomial
-let vector_dot_product_ntt a b = vector_sum (vector_mul_ntt a b)
-
-val matrix_transpose: #r:rank -> matrix r -> matrix r
-let matrix_transpose #r m =
- createi r (fun i ->
- createi r (fun j ->
- m.[j].[i]))
-
-val matrix_vector_mul_ntt: #r:rank -> matrix r -> vector r -> vector r
-let matrix_vector_mul_ntt #r m v =
- createi r (fun i -> vector_dot_product_ntt m.[i] v)
-
-val compute_As_plus_e_ntt: #r:rank -> a:matrix r -> s:vector r -> e:vector r -> vector r
-let compute_As_plus_e_ntt #p a s e = vector_add (matrix_vector_mul_ntt a s) e
-
-
-
-type dT = d: nat {d = 1 \/ d = 4 \/ d = 5 \/ d = 10 \/ d = 11 \/ d = 12}
-let max_d (d:dT) = if d < 12 then pow2 d else v v_FIELD_MODULUS
-type field_element_d (d:dT) = n:nat{n < max_d d}
-type polynomial_d (d:dT) = t_Array (field_element_d d) (sz 256)
-type vector_d (r:rank) (d:dT) = t_Array (polynomial_d d) r
-
-let bits_to_bytes (#bytes: usize) (bv: bit_vec (v bytes * 8))
- : Pure (t_Array u8 bytes)
- (requires True)
- (ensures fun r -> (forall i. bit_vec_of_int_t_array r 8 i == bv i))
- = bit_vec_to_int_t_array 8 bv
-
-let bytes_to_bits (#bytes: usize) (r: t_Array u8 bytes)
- : Pure (i: bit_vec (v bytes * 8))
- (requires True)
- (ensures fun f -> (forall i. bit_vec_of_int_t_array r 8 i == f i))
- = bit_vec_of_int_t_array r 8
-
-unfold let retype_bit_vector #a #b (#_:unit{a == b}) (x: a): b = x
-
-
-let compress_d (d: dT {d <> 12}) (x: field_element): field_element_d d
- = let r = (pow2 d * x + 1664) / v v_FIELD_MODULUS in
- assert (r * v v_FIELD_MODULUS <= pow2 d * x + 1664);
- assert (r * v v_FIELD_MODULUS <= pow2 d * (v v_FIELD_MODULUS - 1) + 1664);
- Math.Lemmas.lemma_div_le (r * v v_FIELD_MODULUS) (pow2 d * (v v_FIELD_MODULUS - 1) + 1664) (v v_FIELD_MODULUS);
- Math.Lemmas.cancel_mul_div r (v v_FIELD_MODULUS);
- assert (r <= (pow2 d * (v v_FIELD_MODULUS - 1) + 1664) / v v_FIELD_MODULUS);
- Math.Lemmas.lemma_div_mod_plus (1664 - pow2 d) (pow2 d) (v v_FIELD_MODULUS);
- assert (r <= pow2 d + (1664 - pow2 d) / v v_FIELD_MODULUS);
- assert (r <= pow2 d);
- if r = pow2 d then 0 else r
-
-let decompress_d (d: dT {d <> 12}) (x: field_element_d d): field_element
- = let r = (x * v v_FIELD_MODULUS + 1664) / pow2 d in
- r
-
-
-let byte_encode (d: dT) (coefficients: polynomial_d d): t_Array u8 (sz (32 * d))
- = let coefficients' : t_Array nat (sz 256) = map_array #(field_element_d d) (fun x -> x <: nat) coefficients in
- bits_to_bytes #(sz (32 * d))
- (retype_bit_vector (bit_vec_of_nat_array coefficients' d))
-
-let byte_decode (d: dT) (coefficients: t_Array u8 (sz (32 * d))): polynomial_d d
- = let bv = bytes_to_bits coefficients in
- let arr: t_Array nat (sz 256) = bit_vec_to_nat_array d (retype_bit_vector bv) in
- let p: polynomial_d d =
- createi (sz 256) (fun i ->
- let x_f : field_element = arr.[i] % v v_FIELD_MODULUS in
- assert (d < 12 ==> arr.[i] < pow2 d);
- let x_m : field_element_d d = x_f in
- x_m)
- in
- p
-
-let coerce_polynomial_12 (p:polynomial): polynomial_d 12 = p
-let coerce_vector_12 (#r:rank) (v:vector r): vector_d r 12 = v
-
-let compress_then_byte_encode (d: dT {d <> 12}) (coefficients: polynomial): t_Array u8 (sz (32 * d))
- = let coefs: t_Array (field_element_d d) (sz 256) = map_array (compress_d d) coefficients
- in
- byte_encode d coefs
-
-let byte_decode_then_decompress (d: dT {d <> 12}) (b:t_Array u8 (sz (32 * d))): polynomial
- = map_array (decompress_d d) (byte_decode d b)
-
-
-(**** Definitions to move or to rework *)
-let serialize_pre
- (d1: dT)
- (coefficients: t_Array i16 (sz 16))
- = forall i. i < 16 ==> bounded (Seq.index coefficients i) d1
-
-// TODO: this is an alternative version of byte_encode
-// rename to encoded bytes
-#push-options "--z3rlimit 80 --split_queries always"
-let serialize_post
- (d1: dT)
- (coefficients: t_Array i16 (sz 16) { serialize_pre d1 coefficients })
- (output: t_Array u8 (sz (d1 * 2)))
- = BitVecEq.int_t_array_bitwise_eq coefficients d1
- output 8
-
-// TODO: this is an alternative version of byte_decode
-// rename to decoded bytes
-let deserialize_post
- (d1: dT)
- (bytes: t_Array u8 (sz (d1 * 2)))
- (output: t_Array i16 (sz 16))
- = BitVecEq.int_t_array_bitwise_eq bytes 8
- output d1 /\
- forall (i:nat). i < 16 ==> bounded (Seq.index output i) d1
-#pop-options
diff --git a/libcrux-ml-kem/proofs/fstar/spec/Spec.MLKEM.fst b/libcrux-ml-kem/proofs/fstar/spec/Spec.MLKEM.fst
deleted file mode 100644
index 07c9216ae..000000000
--- a/libcrux-ml-kem/proofs/fstar/spec/Spec.MLKEM.fst
+++ /dev/null
@@ -1,343 +0,0 @@
-module Spec.MLKEM
-#set-options "--fuel 0 --ifuel 1 --z3rlimit 80"
-open FStar.Mul
-open Core
-
-include Spec.Utils
-include Spec.MLKEM.Math
-
-(** ML-KEM Constants *)
-let v_BITS_PER_COEFFICIENT: usize = sz 12
-
-let v_COEFFICIENTS_IN_RING_ELEMENT: usize = sz 256
-
-let v_BITS_PER_RING_ELEMENT: usize = sz 3072 // v_COEFFICIENTS_IN_RING_ELEMENT *! sz 12
-
-let v_BYTES_PER_RING_ELEMENT: usize = sz 384 // v_BITS_PER_RING_ELEMENT /! sz 8
-
-let v_CPA_KEY_GENERATION_SEED_SIZE: usize = sz 32
-
-let v_H_DIGEST_SIZE: usize = sz 32
-// same as Libcrux.Digest.digest_size (Libcrux.Digest.Algorithm_Sha3_256_ <: Libcrux.Digest.t_Algorithm)
-
-let v_REJECTION_SAMPLING_SEED_SIZE: usize = sz 840 // sz 168 *! sz 5
-
-let v_SHARED_SECRET_SIZE: usize = v_H_DIGEST_SIZE
-
-val v_ETA1 (r:rank) : u:usize{u == sz 3 \/ u == sz 2}
-let v_ETA1 (r:rank) : usize =
- if r = sz 2 then sz 3 else
- if r = sz 3 then sz 2 else
- if r = sz 4 then sz 2
-
-let v_ETA2 (r:rank) : usize = sz 2
-
-val v_VECTOR_U_COMPRESSION_FACTOR (r:rank) : u:usize{u == sz 10 \/ u == sz 11}
-let v_VECTOR_U_COMPRESSION_FACTOR (r:rank) : usize =
- if r = sz 2 then sz 10 else
- if r = sz 3 then sz 10 else
- if r = sz 4 then sz 11
-
-val v_VECTOR_V_COMPRESSION_FACTOR (r:rank) : u:usize{u == sz 4 \/ u == sz 5}
-let v_VECTOR_V_COMPRESSION_FACTOR (r:rank) : usize =
- if r = sz 2 then sz 4 else
- if r = sz 3 then sz 4 else
- if r = sz 4 then sz 5
-
-val v_ETA1_RANDOMNESS_SIZE (r:rank) : u:usize{u == sz 128 \/ u == sz 192}
-let v_ETA1_RANDOMNESS_SIZE (r:rank) = v_ETA1 r *! sz 64
-
-val v_ETA2_RANDOMNESS_SIZE (r:rank) : u:usize{u == sz 128}
-let v_ETA2_RANDOMNESS_SIZE (r:rank) = v_ETA2 r *! sz 64
-
-val v_RANKED_BYTES_PER_RING_ELEMENT (r:rank) : u:usize{u = sz 768 \/ u = sz 1152 \/ u = sz 1536}
-let v_RANKED_BYTES_PER_RING_ELEMENT (r:rank) = r *! v_BYTES_PER_RING_ELEMENT
-
-let v_T_AS_NTT_ENCODED_SIZE (r:rank) = v_RANKED_BYTES_PER_RING_ELEMENT r
-let v_CPA_PRIVATE_KEY_SIZE (r:rank) = v_RANKED_BYTES_PER_RING_ELEMENT r
-
-val v_CPA_PUBLIC_KEY_SIZE (r:rank) : u:usize{u = sz 800 \/ u = sz 1184 \/ u = sz 1568}
-let v_CPA_PUBLIC_KEY_SIZE (r:rank) = v_RANKED_BYTES_PER_RING_ELEMENT r +! sz 32
-
-val v_CCA_PRIVATE_KEY_SIZE (r:rank) : u:usize{u = sz 1632 \/ u = sz 2400 \/ u = sz 3168}
-let v_CCA_PRIVATE_KEY_SIZE (r:rank) =
- (v_CPA_PRIVATE_KEY_SIZE r +! v_CPA_PUBLIC_KEY_SIZE r +! v_H_DIGEST_SIZE +! v_SHARED_SECRET_SIZE)
-
-let v_CCA_PUBLIC_KEY_SIZE (r:rank) = v_CPA_PUBLIC_KEY_SIZE r
-
-val v_C1_BLOCK_SIZE (r:rank): u:usize{(u = sz 320 \/ u = sz 352) /\ v u == 32 * v (v_VECTOR_U_COMPRESSION_FACTOR r)}
-let v_C1_BLOCK_SIZE (r:rank) = sz 32 *! v_VECTOR_U_COMPRESSION_FACTOR r
-
-val v_C1_SIZE (r:rank) : u:usize{(u >=. sz 640 /\ u <=. sz 1448) /\
- v u == v (v_C1_BLOCK_SIZE r) * v r}
-let v_C1_SIZE (r:rank) = v_C1_BLOCK_SIZE r *! r
-
-val v_C2_SIZE (r:rank) : u:usize{(u = sz 128 \/ u = sz 160) /\ v u == 32 * v (v_VECTOR_V_COMPRESSION_FACTOR r)}
-let v_C2_SIZE (r:rank) = sz 32 *! v_VECTOR_V_COMPRESSION_FACTOR r
-
-val v_CPA_CIPHERTEXT_SIZE (r:rank) : u:usize {v u = v (v_C1_SIZE r) + v (v_C2_SIZE r)}
-let v_CPA_CIPHERTEXT_SIZE (r:rank) = v_C1_SIZE r +! v_C2_SIZE r
-
-let v_CCA_CIPHERTEXT_SIZE (r:rank) = v_CPA_CIPHERTEXT_SIZE r
-
-val v_IMPLICIT_REJECTION_HASH_INPUT_SIZE (r:rank): u:usize{v u == v v_SHARED_SECRET_SIZE +
- v (v_CPA_CIPHERTEXT_SIZE r)}
-let v_IMPLICIT_REJECTION_HASH_INPUT_SIZE (r:rank) =
- v_SHARED_SECRET_SIZE +! v_CPA_CIPHERTEXT_SIZE r
-
-val v_KEY_GENERATION_SEED_SIZE: u:usize{u = sz 64}
-let v_KEY_GENERATION_SEED_SIZE: usize =
- v_CPA_KEY_GENERATION_SEED_SIZE +!
- v_SHARED_SECRET_SIZE
-
-
-(** ML-KEM Types *)
-
-type t_MLKEMPublicKey (r:rank) = t_Array u8 (v_CPA_PUBLIC_KEY_SIZE r)
-type t_MLKEMPrivateKey (r:rank) = t_Array u8 (v_CCA_PRIVATE_KEY_SIZE r)
-type t_MLKEMKeyPair (r:rank) = t_MLKEMPrivateKey r & t_MLKEMPublicKey r
-
-type t_MLKEMCPAPrivateKey (r:rank) = t_Array u8 (v_CPA_PRIVATE_KEY_SIZE r)
-type t_MLKEMCPAKeyPair (r:rank) = t_MLKEMCPAPrivateKey r & t_MLKEMPublicKey r
-
-type t_MLKEMCiphertext (r:rank) = t_Array u8 (v_CPA_CIPHERTEXT_SIZE r)
-type t_MLKEMSharedSecret = t_Array u8 (v_SHARED_SECRET_SIZE)
-
-
-assume val sample_max: n:usize{v n < pow2 32 /\ v n >= 128 * 3 /\ v n % 3 = 0}
-
-val sample_polynomial_ntt: seed:t_Array u8 (sz 34) -> (polynomial & bool)
-let sample_polynomial_ntt seed =
- let randomness = v_XOF sample_max seed in
- let bv = bytes_to_bits randomness in
- assert (v sample_max * 8 == (((v sample_max / 3) * 2) * 12));
- let bv: bit_vec ((v (sz ((v sample_max / 3) * 2))) * 12) = retype_bit_vector bv in
- let i16s = bit_vec_to_nat_array #(sz ((v sample_max / 3) * 2)) 12 bv in
- assert ((v sample_max / 3) * 2 >= 256);
- let poly0: polynomial = Seq.create 256 0 in
- let index_t = n:nat{n <= 256} in
- let (sampled, poly1) =
- repeati #(index_t & polynomial) (sz ((v sample_max / 3) * 2))
- (fun i (sampled,acc) ->
- if sampled < 256 then
- let sample = Seq.index i16s (v i) in
- if sample < 3329 then
- (sampled+1, Rust_primitives.Hax.update_at acc (sz sampled) sample)
- else (sampled, acc)
- else (sampled, acc))
- (0,poly0) in
- if sampled < 256 then poly0, false else poly1, true
-
-let sample_polynomial_ntt_at_index (seed:t_Array u8 (sz 32)) (i j: (x:usize{v x <= 4})) : polynomial & bool =
- let seed34 = Seq.append seed (Seq.create 2 0uy) in
- let seed34 = Rust_primitives.Hax.update_at seed34 (sz 32) (mk_int #u8_inttype (v i)) in
- let seed34 = Rust_primitives.Hax.update_at seed34 (sz 33) (mk_int #u8_inttype (v j)) in
- sample_polynomial_ntt seed34
-
-val sample_matrix_A_ntt: #r:rank -> seed:t_Array u8 (sz 32) -> (matrix r & bool)
-let sample_matrix_A_ntt #r seed =
- let m =
- createi r (fun i ->
- createi r (fun j ->
- let (p,b) = sample_polynomial_ntt_at_index seed i j in
- p))
- in
- let sufficient_randomness =
- repeati r (fun i b ->
- repeati r (fun j b ->
- let (p,v) = sample_polynomial_ntt_at_index seed i j in
- b && v) b) true in
- (m, sufficient_randomness)
-
-assume val sample_poly_cbd: v_ETA:usize{v v_ETA == 2 \/ v v_ETA == 3} -> t_Array u8 (v_ETA *! sz 64) -> polynomial
-
-open Rust_primitives.Integers
-
-val sample_poly_cbd2: #r:rank -> seed:t_Array u8 (sz 32) -> domain_sep:usize{v domain_sep < 256} -> polynomial
-let sample_poly_cbd2 #r seed domain_sep =
- let prf_input = Seq.append seed (Seq.create 1 (mk_int #u8_inttype (v domain_sep))) in
- let prf_output = v_PRF (v_ETA2_RANDOMNESS_SIZE r) prf_input in
- sample_poly_cbd (v_ETA2 r) prf_output
-
-val sample_poly_cbd1: #r:rank -> seed:t_Array u8 (sz 32) -> domain_sep:usize{v domain_sep < 256} -> polynomial
-let sample_poly_cbd1 #r seed domain_sep =
- let prf_input = Seq.append seed (Seq.create 1 (mk_int #u8_inttype (v domain_sep))) in
- let prf_output = v_PRF (v_ETA1_RANDOMNESS_SIZE r) prf_input in
- sample_poly_cbd (v_ETA1 r) prf_output
-
-let sample_vector_cbd1 (#r:rank) (seed:t_Array u8 (sz 32)) (domain_sep:usize{v domain_sep < 2 * v r}) : vector r =
- createi r (fun i -> sample_poly_cbd1 #r seed (domain_sep +! i))
-
-let sample_vector_cbd2 (#r:rank) (seed:t_Array u8 (sz 32)) (domain_sep:usize{v domain_sep < 2 * v r}) : vector r =
- createi r (fun i -> sample_poly_cbd2 #r seed (domain_sep +! i))
-
-let sample_vector_cbd_then_ntt (#r:rank) (seed:t_Array u8 (sz 32)) (domain_sep:usize{v domain_sep < 2 * v r}) : vector r =
- vector_ntt (sample_vector_cbd1 #r seed domain_sep)
-
-let vector_encode_12 (#r:rank) (v: vector r) : t_Array u8 (v_T_AS_NTT_ENCODED_SIZE r)
- = let s: t_Array (t_Array _ (sz 384)) r = map_array (byte_encode 12) (coerce_vector_12 v) in
- flatten s
-
-let vector_decode_12 (#r:rank) (arr: t_Array u8 (v_T_AS_NTT_ENCODED_SIZE r)): vector r
- = createi r (fun block ->
- let block_size = (sz (32 * 12)) in
- let slice = Seq.slice arr (v block * v block_size)
- (v block * v block_size + v block_size) in
- byte_decode 12 slice
- )
-
-let compress_then_encode_message (p:polynomial) : t_Array u8 v_SHARED_SECRET_SIZE
- = compress_then_byte_encode 1 p
-
-let decode_then_decompress_message (b:t_Array u8 v_SHARED_SECRET_SIZE): polynomial
- = byte_decode_then_decompress 1 b
-
-let compress_then_encode_u (#r:rank) (vec: vector r): t_Array u8 (v_C1_SIZE r)
- = let d = v (v_VECTOR_U_COMPRESSION_FACTOR r) in
- flatten (map_array (compress_then_byte_encode d) vec)
-
-let decode_then_decompress_u (#r:rank) (arr: t_Array u8 (v_C1_SIZE r)): vector r
- = let d = v_VECTOR_U_COMPRESSION_FACTOR r in
- createi r (fun block ->
- let block_size = v_C1_BLOCK_SIZE r in
- let slice = Seq.slice arr (v block * v block_size)
- (v block * v block_size + v block_size) in
- byte_decode_then_decompress (v d) slice
- )
-
-let compress_then_encode_v (#r:rank): polynomial -> t_Array u8 (v_C2_SIZE r)
- = compress_then_byte_encode (v (v_VECTOR_V_COMPRESSION_FACTOR r))
-
-let decode_then_decompress_v (#r:rank): t_Array u8 (v_C2_SIZE r) -> polynomial
- = byte_decode_then_decompress (v (v_VECTOR_V_COMPRESSION_FACTOR r))
-
-(** IND-CPA Functions *)
-
-/// This function implements most of Algorithm 12 of the
-/// NIST FIPS 203 specification; this is the MLKEM CPA-PKE key generation algorithm.
-///
-/// We say "most of" since Algorithm 12 samples the required randomness within
-/// the function itself, whereas this implementation expects it to be provided
-/// through the `key_generation_seed` parameter.
-
-val ind_cpa_generate_keypair (r:rank) (randomness:t_Array u8 v_CPA_KEY_GENERATION_SEED_SIZE) :
- (t_MLKEMCPAKeyPair r & bool)
-let ind_cpa_generate_keypair r randomness =
- let hashed = v_G randomness in
- let (seed_for_A, seed_for_secret_and_error) = split hashed (sz 32) in
- let (matrix_A_as_ntt, sufficient_randomness) = sample_matrix_A_ntt #r seed_for_A in
- let secret_as_ntt = sample_vector_cbd_then_ntt #r seed_for_secret_and_error (sz 0) in
- let error_as_ntt = sample_vector_cbd_then_ntt #r seed_for_secret_and_error r in
- let t_as_ntt = compute_As_plus_e_ntt #r matrix_A_as_ntt secret_as_ntt error_as_ntt in
- let public_key_serialized = Seq.append (vector_encode_12 #r t_as_ntt) seed_for_A in
- let secret_key_serialized = vector_encode_12 #r secret_as_ntt in
- ((secret_key_serialized,public_key_serialized), sufficient_randomness)
-
-/// This function implements Algorithm 13 of the
-/// NIST FIPS 203 specification; this is the MLKEM CPA-PKE encryption algorithm.
-
-val ind_cpa_encrypt (r:rank) (public_key: t_MLKEMPublicKey r)
- (message: t_Array u8 v_SHARED_SECRET_SIZE)
- (randomness:t_Array u8 v_SHARED_SECRET_SIZE) :
- (t_MLKEMCiphertext r & bool)
-
-[@ "opaque_to_smt"]
-let ind_cpa_encrypt r public_key message randomness =
- let (t_as_ntt_bytes, seed_for_A) = split public_key (v_T_AS_NTT_ENCODED_SIZE r) in
- let t_as_ntt = vector_decode_12 #r t_as_ntt_bytes in
- let matrix_A_as_ntt, sufficient_randomness = sample_matrix_A_ntt #r seed_for_A in
- let r_as_ntt = sample_vector_cbd_then_ntt #r randomness (sz 0) in
- let error_1 = sample_vector_cbd2 #r randomness r in
- let error_2 = sample_poly_cbd2 #r randomness (r +! r) in
- let u = vector_add (vector_inv_ntt (matrix_vector_mul_ntt (matrix_transpose matrix_A_as_ntt) r_as_ntt)) error_1 in
- let mu = decode_then_decompress_message message in
- let v = poly_add (poly_add (vector_dot_product_ntt t_as_ntt r_as_ntt) error_2) mu in
- let c1 = compress_then_encode_u #r u in
- let c2 = compress_then_encode_v #r v in
- (concat c1 c2, sufficient_randomness)
-
-/// This function implements Algorithm 14 of the
-/// NIST FIPS 203 specification; this is the MLKEM CPA-PKE decryption algorithm.
-
-val ind_cpa_decrypt (r:rank) (secret_key: t_MLKEMCPAPrivateKey r)
- (ciphertext: t_MLKEMCiphertext r):
- t_MLKEMSharedSecret
-
-[@ "opaque_to_smt"]
-let ind_cpa_decrypt r secret_key ciphertext =
- let (c1,c2) = split ciphertext (v_C1_SIZE r) in
- let u = decode_then_decompress_u #r c1 in
- let v = decode_then_decompress_v #r c2 in
- let secret_as_ntt = vector_decode_12 #r secret_key in
- let w = poly_sub v (poly_inv_ntt (vector_dot_product_ntt secret_as_ntt (vector_ntt u))) in
- compress_then_encode_message w
-
-(** IND-CCA Functions *)
-
-
-/// This function implements most of Algorithm 15 of the
-/// NIST FIPS 203 specification; this is the MLKEM CCA-KEM key generation algorithm.
-///
-/// We say "most of" since Algorithm 15 samples the required randomness within
-/// the function itself, whereas this implementation expects it to be provided
-/// through the `randomness` parameter.
-///
-/// TODO: input validation
-
-val ind_cca_generate_keypair (r:rank) (randomness:t_Array u8 v_KEY_GENERATION_SEED_SIZE) :
- t_MLKEMKeyPair r & bool
-let ind_cca_generate_keypair p randomness =
- let (ind_cpa_keypair_randomness, implicit_rejection_value) =
- split randomness v_CPA_KEY_GENERATION_SEED_SIZE in
-
- let (ind_cpa_secret_key,ind_cpa_public_key), sufficient_randomness = ind_cpa_generate_keypair p ind_cpa_keypair_randomness in
- let ind_cca_secret_key = Seq.append ind_cpa_secret_key (
- Seq.append ind_cpa_public_key (
- Seq.append (v_H ind_cpa_public_key) implicit_rejection_value)) in
- (ind_cca_secret_key, ind_cpa_public_key), sufficient_randomness
-
-/// This function implements most of Algorithm 16 of the
-/// NIST FIPS 203 specification; this is the MLKEM CCA-KEM encapsulation algorithm.
-///
-/// We say "most of" since Algorithm 16 samples the required randomness within
-/// the function itself, whereas this implementation expects it to be provided
-/// through the `randomness` parameter.
-///
-/// TODO: input validation
-
-val ind_cca_encapsulate (r:rank) (public_key: t_MLKEMPublicKey r)
- (randomness:t_Array u8 v_SHARED_SECRET_SIZE) :
- (t_MLKEMCiphertext r & t_MLKEMSharedSecret) & bool
-let ind_cca_encapsulate p public_key randomness =
- let to_hash = concat randomness (v_H public_key) in
- let hashed = v_G to_hash in
- let (shared_secret, pseudorandomness) = split hashed v_SHARED_SECRET_SIZE in
- let ciphertext, sufficient_randomness = ind_cpa_encrypt p public_key randomness pseudorandomness in
- (ciphertext,shared_secret), sufficient_randomness
-
-
-/// This function implements Algorithm 17 of the
-/// NIST FIPS 203 specification; this is the MLKEM CCA-KEM encapsulation algorithm.
-
-val ind_cca_decapsulate (r:rank) (secret_key: t_MLKEMPrivateKey r)
- (ciphertext: t_MLKEMCiphertext r):
- t_MLKEMSharedSecret & bool
-let ind_cca_decapsulate p secret_key ciphertext =
- let (ind_cpa_secret_key,rest) = split secret_key (v_CPA_PRIVATE_KEY_SIZE p) in
- let (ind_cpa_public_key,rest) = split rest (v_CPA_PUBLIC_KEY_SIZE p) in
- let (ind_cpa_public_key_hash,implicit_rejection_value) = split rest v_H_DIGEST_SIZE in
-
- let decrypted = ind_cpa_decrypt p ind_cpa_secret_key ciphertext in
- let to_hash = concat decrypted ind_cpa_public_key_hash in
- let hashed = v_G to_hash in
- let (success_shared_secret, pseudorandomness) = split hashed v_SHARED_SECRET_SIZE in
-
- assert (Seq.length implicit_rejection_value = 32);
- let to_hash = concat implicit_rejection_value ciphertext in
- let rejection_shared_secret = v_J to_hash in
-
- let reencrypted, sufficient_randomness = ind_cpa_encrypt p ind_cpa_public_key decrypted pseudorandomness in
- if reencrypted = ciphertext
- then success_shared_secret, sufficient_randomness
- else rejection_shared_secret, sufficient_randomness
-
diff --git a/libcrux-ml-kem/proofs/fstar/spec/Spec.Utils.fst b/libcrux-ml-kem/proofs/fstar/spec/Spec.Utils.fst
deleted file mode 100644
index 1c6ed14b1..000000000
--- a/libcrux-ml-kem/proofs/fstar/spec/Spec.Utils.fst
+++ /dev/null
@@ -1,493 +0,0 @@
-module Spec.Utils
-#set-options "--fuel 0 --ifuel 1 --z3rlimit 100"
-open FStar.Mul
-open Core
-
-(** Utils *)
-let map_slice #a #b
- (f:(x:a -> b))
- (s: t_Slice a): t_Slice b
- = createi (length s) (fun i -> f (Seq.index s (v i)))
-
-let map_array #a #b #len
- (f:(x:a -> b))
- (s: t_Array a len): t_Array b len
- = createi (length s) (fun i -> f (Seq.index s (v i)))
-
-let map2 #a #b #c #len
- (f:a -> b -> c)
- (x: t_Array a len) (y: t_Array b len): t_Array c len
- = createi (length x) (fun i -> f (Seq.index x (v i)) (Seq.index y (v i)))
-
-let create len c = createi len (fun i -> c)
-
-let repeati #acc (l:usize) (f:(i:usize{v i < v l}) -> acc -> acc) acc0 : acc = Lib.LoopCombinators.repeati (v l) (fun i acc -> f (sz i) acc) acc0
-
-let createL len l = Rust_primitives.Hax.array_of_list len l
-
-let create16 v15 v14 v13 v12 v11 v10 v9 v8 v7 v6 v5 v4 v3 v2 v1 v0 =
- let l = [v15; v14; v13; v12; v11; v10; v9; v8; v7; v6; v5; v4; v3; v2; v1; v0] in
- assert_norm (List.Tot.length l == 16);
- createL 16 l
-
-
-val lemma_createL_index #a len l i :
- Lemma (Seq.index (createL #a len l) i == List.Tot.index l i)
- [SMTPat (Seq.index (createL #a len l) i)]
-let lemma_createL_index #a len l i = ()
-
-val lemma_create16_index #a v15 v14 v13 v12 v11 v10 v9 v8 v7 v6 v5 v4 v3 v2 v1 v0 i :
- Lemma (Seq.index (create16 #a v15 v14 v13 v12 v11 v10 v9 v8 v7 v6 v5 v4 v3 v2 v1 v0) i ==
- (if i = 0 then v15 else
- if i = 1 then v14 else
- if i = 2 then v13 else
- if i = 3 then v12 else
- if i = 4 then v11 else
- if i = 5 then v10 else
- if i = 6 then v9 else
- if i = 7 then v8 else
- if i = 8 then v7 else
- if i = 9 then v6 else
- if i = 10 then v5 else
- if i = 11 then v4 else
- if i = 12 then v3 else
- if i = 13 then v2 else
- if i = 14 then v1 else
- if i = 15 then v0))
- [SMTPat (Seq.index (create16 #a v15 v14 v13 v12 v11 v10 v9 v8 v7 v6 v5 v4 v3 v2 v1 v0) i)]
-let lemma_create16_index #a v15 v14 v13 v12 v11 v10 v9 v8 v7 v6 v5 v4 v3 v2 v1 v0 i =
- let l = [v15; v14; v13; v12; v11; v10; v9; v8; v7; v6; v5; v4; v3; v2; v1; v0] in
- assert_norm (List.Tot.index l 0 == v15);
- assert_norm (List.Tot.index l 1 == v14);
- assert_norm (List.Tot.index l 2 == v13);
- assert_norm (List.Tot.index l 3 == v12);
- assert_norm (List.Tot.index l 4 == v11);
- assert_norm (List.Tot.index l 5 == v10);
- assert_norm (List.Tot.index l 6 == v9);
- assert_norm (List.Tot.index l 7 == v8);
- assert_norm (List.Tot.index l 8 == v7);
- assert_norm (List.Tot.index l 9 == v6);
- assert_norm (List.Tot.index l 10 == v5);
- assert_norm (List.Tot.index l 11 == v4);
- assert_norm (List.Tot.index l 12 == v3);
- assert_norm (List.Tot.index l 13 == v2);
- assert_norm (List.Tot.index l 14 == v1);
- assert_norm (List.Tot.index l 15 == v0)
-
-
-val lemma_createi_index #a len f i :
- Lemma (Seq.index (createi #a len f) i == f (sz i))
- [SMTPat (Seq.index (createi #a len f) i)]
-let lemma_createi_index #a len f i = ()
-
-val lemma_create_index #a len c i:
- Lemma (Seq.index (create #a len c) i == c)
- [SMTPat (Seq.index (create #a len c) i)]
-let lemma_create_index #a len c i = ()
-
-val lemma_map_index #a #b #len f x i:
- Lemma (Seq.index (map_array #a #b #len f x) i == f (Seq.index x i))
- [SMTPat (Seq.index (map_array #a #b #len f x) i)]
-let lemma_map_index #a #b #len f x i = ()
-
-val lemma_map2_index #a #b #c #len f x y i:
- Lemma (Seq.index (map2 #a #b #c #len f x y) i == f (Seq.index x i) (Seq.index y i))
- [SMTPat (Seq.index (map2 #a #b #c #len f x y) i)]
-let lemma_map2_index #a #b #c #len f x y i = ()
-
-let lemma_bitand_properties #t (x:int_t t) :
- Lemma ((x &. ones) == x /\ (x &. mk_int #t 0) == mk_int #t 0 /\ (ones #t &. x) == x /\ (mk_int #t 0 &. x) == mk_int #t 0) =
- logand_lemma #t x x
-
-#push-options "--z3rlimit 250"
-let flatten #t #n
- (#m: usize {range (v n * v m) usize_inttype})
- (x: t_Array (t_Array t m) n)
- : t_Array t (m *! n)
- = createi (m *! n) (fun i -> Seq.index (Seq.index x (v i / v m)) (v i % v m))
-#pop-options
-
-type t_Error = | Error_RejectionSampling : t_Error
-
-type t_Result a b =
- | Ok: a -> t_Result a b
- | Err: b -> t_Result a b
-
-(** Hash Function *)
-open Spec.SHA3
-
-val v_G (input: t_Slice u8) : t_Array u8 (sz 64)
-let v_G input = map_slice Lib.RawIntTypes.u8_to_UInt8 (sha3_512 (Seq.length input) (map_slice Lib.IntTypes.secret input))
-
-val v_H (input: t_Slice u8) : t_Array u8 (sz 32)
-let v_H input = map_slice Lib.RawIntTypes.u8_to_UInt8 (sha3_256 (Seq.length input) (map_slice Lib.IntTypes.secret input))
-
-val v_PRF (v_LEN: usize{v v_LEN < pow2 32}) (input: t_Slice u8) : t_Array u8 v_LEN
-let v_PRF v_LEN input = map_slice Lib.RawIntTypes.u8_to_UInt8 (
- shake256 (Seq.length input) (map_slice Lib.IntTypes.secret input) (v v_LEN))
-
-let v_J (input: t_Slice u8) : t_Array u8 (sz 32) = v_PRF (sz 32) input
-
-val v_XOF (v_LEN: usize{v v_LEN < pow2 32}) (input: t_Slice u8) : t_Array u8 v_LEN
-let v_XOF v_LEN input = map_slice Lib.RawIntTypes.u8_to_UInt8 (
- shake128 (Seq.length input) (map_slice Lib.IntTypes.secret input) (v v_LEN))
-
-let update_at_range_lemma #n
- (s: t_Slice 't)
- (i: Core.Ops.Range.t_Range (int_t n) {(Core.Ops.Range.impl_index_range_slice 't n).f_index_pre s i})
- (x: t_Slice 't)
- : Lemma
- (requires (Seq.length x == v i.f_end - v i.f_start))
- (ensures (
- let s' = Rust_primitives.Hax.Monomorphized_update_at.update_at_range s i x in
- let len = v i.f_start in
- forall (i: nat). i < len ==> Seq.index s i == Seq.index s' i
- ))
- [SMTPat (Rust_primitives.Hax.Monomorphized_update_at.update_at_range s i x)]
- = let s' = Rust_primitives.Hax.Monomorphized_update_at.update_at_range s i x in
- let len = v i.f_start in
- introduce forall (i:nat {i < len}). Seq.index s i == Seq.index s' i
- with (assert ( Seq.index (Seq.slice s 0 len) i == Seq.index s i
- /\ Seq.index (Seq.slice s' 0 len) i == Seq.index s' i ))
-
-
-/// Bounded integers
-
-let is_intb (l:nat) (x:int) = (x <= l) && (x >= -l)
-let is_i16b (l:nat) (x:i16) = is_intb l (v x)
-let is_i16b_array (l:nat) (x:t_Slice i16) = forall i. i < Seq.length x ==> is_i16b l (Seq.index x i)
-let is_i16b_vector (l:nat) (r:usize) (x:t_Array (t_Array i16 (sz 256)) r) = forall i. i < v r ==> is_i16b_array l (Seq.index x i)
-let is_i16b_matrix (l:nat) (r:usize) (x:t_Array (t_Array (t_Array i16 (sz 256)) r) r) = forall i. i < v r ==> is_i16b_vector l r (Seq.index x i)
-
-[@ "opaque_to_smt"]
-let is_i16b_array_opaque (l:nat) (x:t_Slice i16) = is_i16b_array l x
-
-let is_i32b (l:nat) (x:i32) = is_intb l (v x)
-let is_i32b_array (l:nat) (x:t_Slice i32) = forall i. i < Seq.length x ==> is_i32b l (Seq.index x i)
-
-let nat_div_ceil (x:nat) (y:pos) : nat = if (x % y = 0) then x/y else (x/y)+1
-
-val lemma_intb_le b b'
- : Lemma (requires (b <= b'))
- (ensures (forall n. is_intb b n ==> is_intb b' n))
-let lemma_intb_le b b' = ()
-
-#push-options "--z3rlimit 200"
-val lemma_mul_intb (b1 b2: nat) (n1 n2: int)
- : Lemma (requires (is_intb b1 n1 /\ is_intb b2 n2))
- (ensures (is_intb (b1 * b2) (n1 * n2)))
-let lemma_mul_intb (b1 b2: nat) (n1 n2: int) =
- if n1 = 0 || n2 = 0
- then ()
- else
- let open FStar.Math.Lemmas in
- lemma_abs_bound n1 b1;
- lemma_abs_bound n2 b2;
- lemma_abs_mul n1 n2;
- lemma_mult_le_left (abs n1) (abs n2) b2;
- lemma_mult_le_right b2 (abs n1) b1;
- lemma_abs_bound (n1 * n2) (b1 * b2)
-#pop-options
-
-#push-options "--z3rlimit 200"
-val lemma_mul_i16b (b1 b2: nat) (n1 n2: i16)
- : Lemma (requires (is_i16b b1 n1 /\ is_i16b b2 n2 /\ b1 * b2 < pow2 31))
- (ensures (range (v n1 * v n2) i32_inttype /\
- is_i32b (b1 * b2) ((cast n1 <: i32) *! (cast n2 <: i32)) /\
- v ((cast n1 <: i32) *! (cast n2 <: i32)) == v n1 * v n2))
-
-let lemma_mul_i16b (b1 b2: nat) (n1 n2: i16) =
- if v n1 = 0 || v n2 = 0
- then ()
- else
- let open FStar.Math.Lemmas in
- lemma_abs_bound (v n1) b1;
- lemma_abs_bound (v n2) b2;
- lemma_abs_mul (v n1) (v n2);
- lemma_mult_le_left (abs (v n1)) (abs (v n2)) b2;
- lemma_mult_le_right b2 (abs (v n1)) b1;
- lemma_abs_bound (v n1 * v n2) (b1 * b2)
-#pop-options
-
-val lemma_add_i16b (b1 b2:nat) (n1 n2:i16) :
- Lemma (requires (is_i16b b1 n1 /\ is_i16b b2 n2 /\ b1 + b2 < pow2 15))
- (ensures (range (v n1 + v n2) i16_inttype /\
- is_i16b (b1 + b2) (n1 +! n2)))
-let lemma_add_i16b (b1 b2:nat) (n1 n2:i16) = ()
-
-#push-options "--z3rlimit 100 --split_queries always"
-let lemma_range_at_percent (v:int) (p:int{p>0/\ p%2=0 /\ v < p/2 /\ v >= -p / 2}):
- Lemma (v @% p == v) =
- let m = v % p in
- if v < 0 then (
- Math.Lemmas.lemma_mod_plus v 1 p;
- assert ((v + p) % p == v % p);
- assert (v + p >= 0);
- assert (v + p < p);
- Math.Lemmas.modulo_lemma (v+p) p;
- assert (m == v + p);
- assert (m >= p/2);
- assert (v @% p == m - p);
- assert (v @% p == v))
- else (
- assert (v >= 0 /\ v < p);
- Math.Lemmas.modulo_lemma v p;
- assert (v % p == v);
- assert (m < p/2);
- assert (v @% p == v)
- )
-#pop-options
-
-val lemma_sub_i16b (b1 b2:nat) (n1 n2:i16) :
- Lemma (requires (is_i16b b1 n1 /\ is_i16b b2 n2 /\ b1 + b2 < pow2 15))
- (ensures (range (v n1 - v n2) i16_inttype /\
- is_i16b (b1 + b2) (n1 -. n2) /\
- v (n1 -. n2) == v n1 - v n2))
-let lemma_sub_i16b (b1 b2:nat) (n1 n2:i16) = ()
-
-let mont_mul_red_i16 (x:i16) (y:i16) : i16=
- let vlow = x *. y in
- let k = vlow *. (neg 3327s) in
- let k_times_modulus = cast (((cast k <: i32) *. 3329l) >>! 16l) <: i16 in
- let vhigh = cast (((cast x <: i32) *. (cast y <: i32)) >>! 16l) <: i16 in
- vhigh -. k_times_modulus
-
-let mont_red_i32 (x:i32) : i16 =
- let vlow = cast x <: i16 in
- let k = vlow *. (neg 3327s) in
- let k_times_modulus = cast (((cast k <: i32) *. 3329l) >>! 16l) <: i16 in
- let vhigh = cast (x >>! 16l) <: i16 in
- vhigh -. k_times_modulus
-
-#push-options "--z3rlimit 100"
-let lemma_at_percent_mod (v:int) (p:int{p>0/\ p%2=0}):
- Lemma ((v @% p) % p == v % p) =
- let m = v % p in
- assert (m >= 0 /\ m < p);
- if m >= p/2 then (
- assert ((v @%p) % p == (m - p) %p);
- Math.Lemmas.lemma_mod_plus m (-1) p;
- assert ((v @%p) % p == m %p);
- Math.Lemmas.lemma_mod_mod m v p;
- assert ((v @%p) % p == v % p)
- ) else (
- assert ((v @%p) % p == m%p);
- Math.Lemmas.lemma_mod_mod m v p;
- assert ((v @%p) % p == v % p)
- )
-#pop-options
-
-let lemma_div_at_percent (v:int) (p:int{p>0/\ p%2=0 /\ (v/p) < p/2 /\ (v/p) >= -p / 2}):
- Lemma ((v / p) @% p == v / p) =
- lemma_range_at_percent (v/p) p
-
-val lemma_mont_red_i32 (x:i32): Lemma
- (requires (is_i32b (3328 * pow2 16) x))
- (ensures (
- let result:i16 = mont_red_i32 x in
- is_i16b (3328 + 1665) result /\
- (is_i32b (3328 * pow2 15) x ==> is_i16b 3328 result) /\
- v result % 3329 == (v x * 169) % 3329))
-
-let lemma_mont_red_i32 (x:i32) =
- let vlow = cast x <: i16 in
- assert (v vlow == v x @% pow2 16);
- let k = vlow *. (neg 3327s) in
- assert (v k == ((v x @% pow2 16) * (- 3327)) @% pow2 16);
- let k_times_modulus = (cast k <: i32) *. 3329l in
- assert (v k_times_modulus == (v k * 3329));
- let c = cast (k_times_modulus >>! 16l) <: i16 in
- assert (v c == (((v k * 3329) / pow2 16) @% pow2 16));
- lemma_div_at_percent (v k * 3329) (pow2 16);
- assert (v c == (((v k * 3329) / pow2 16)));
- assert (is_i16b 1665 c);
- let vhigh = cast (x >>! 16l) <: i16 in
- lemma_div_at_percent (v x) (pow2 16);
- assert (v vhigh == v x / pow2 16);
- assert (is_i16b 3328 vhigh);
- let result = vhigh -. c in
- lemma_sub_i16b 3328 1665 vhigh c;
- assert (is_i16b (3328 + 1665) result);
- assert (v result = v vhigh - v c);
- assert (is_i16b (3328 + 1665) result);
- assert (is_i32b (3328 * pow2 15) x ==> is_i16b 3328 result);
- calc ( == ) {
- v k_times_modulus % pow2 16;
- ( == ) { assert (v k_times_modulus == v k * 3329) }
- (v k * 3329) % pow2 16;
- ( == ) { assert (v k = ((v x @% pow2 16) * (-3327)) @% pow2 16) }
- ((((v x @% pow2 16) * (-3327)) @% pow2 16) * 3329) % pow2 16;
- ( == ) { Math.Lemmas.lemma_mod_mul_distr_l (((v x @% pow2 16) * (-3327)) @% pow2 16) 3329 (pow2 16) }
- (((((v x @% pow2 16) * (-3327)) @% pow2 16) % pow2 16) * 3329) % pow2 16;
- ( == ) { lemma_at_percent_mod ((v x @% pow2 16) * (-3327)) (pow2 16)}
- ((((v x @% pow2 16) * (-3327)) % pow2 16) * 3329) % pow2 16;
- ( == ) { Math.Lemmas.lemma_mod_mul_distr_l ((v x @% pow2 16) * (-3327)) 3329 (pow2 16) }
- (((v x @% pow2 16) * (-3327)) * 3329) % pow2 16;
- ( == ) { }
- ((v x @% pow2 16) * (-3327 * 3329)) % pow2 16;
- ( == ) { Math.Lemmas.lemma_mod_mul_distr_r (v x @% pow2 16) (-3327 * 3329) (pow2 16) }
- ((v x @% pow2 16) % pow2 16);
- ( == ) { lemma_at_percent_mod (v x) (pow2 16) }
- (v x) % pow2 16;
- };
- Math.Lemmas.modulo_add (pow2 16) (- (v k_times_modulus)) (v x) (v k_times_modulus);
- assert ((v x - v k_times_modulus) % pow2 16 == 0);
- calc ( == ) {
- v result % 3329;
- ( == ) { }
- (v x / pow2 16 - v k_times_modulus / pow2 16) % 3329;
- ( == ) { Math.Lemmas.lemma_div_exact (v x - v k_times_modulus) (pow2 16) }
- ((v x - v k_times_modulus) / pow2 16) % 3329;
- ( == ) { assert ((pow2 16 * 169) % 3329 == 1) }
- (((v x - v k_times_modulus) / pow2 16) * ((pow2 16 * 169) % 3329)) % 3329;
- ( == ) { Math.Lemmas.lemma_mod_mul_distr_r ((v x - v k_times_modulus) / pow2 16)
- (pow2 16 * 169)
- 3329 }
- (((v x - v k_times_modulus) / pow2 16) * pow2 16 * 169) % 3329;
- ( == ) { Math.Lemmas.lemma_div_exact (v x - v k_times_modulus) (pow2 16) }
- ((v x - v k_times_modulus) * 169) % 3329;
- ( == ) { assert (v k_times_modulus == v k * 3329) }
- ((v x * 169) - (v k * 3329 * 169)) % 3329;
- ( == ) { Math.Lemmas.lemma_mod_sub (v x * 169) 3329 (v k * 169) }
- (v x * 169) % 3329;
- }
-
-val lemma_mont_mul_red_i16_int (x y:i16): Lemma
- (requires (is_intb (3326 * pow2 15) (v x * v y)))
- (ensures (
- let result:i16 = mont_mul_red_i16 x y in
- is_i16b 3328 result /\
- v result % 3329 == (v x * v y * 169) % 3329))
-
-let lemma_mont_mul_red_i16_int (x y:i16) =
- let vlow = x *. y in
- let prod = v x * v y in
- assert (v vlow == prod @% pow2 16);
- let k = vlow *. (neg 3327s) in
- assert (v k == (((prod) @% pow2 16) * (- 3327)) @% pow2 16);
- let k_times_modulus = (cast k <: i32) *. 3329l in
- assert (v k_times_modulus == (v k * 3329));
- let c = cast (k_times_modulus >>! 16l) <: i16 in
- assert (v c == (((v k * 3329) / pow2 16) @% pow2 16));
- lemma_div_at_percent (v k * 3329) (pow2 16);
- assert (v c == (((v k * 3329) / pow2 16)));
- assert (is_i16b 1665 c);
- let vhigh = cast (((cast x <: i32) *. (cast y <: i32)) >>! 16l) <: i16 in
- assert (v x @% pow2 32 == v x);
- assert (v y @% pow2 32 == v y);
- assert (v ((cast x <: i32) *. (cast y <: i32)) == (v x * v y) @% pow2 32);
- assert (v vhigh == (((prod) @% pow2 32) / pow2 16) @% pow2 16);
- assert_norm (pow2 15 * 3326 < pow2 31);
- lemma_range_at_percent prod (pow2 32);
- assert (v vhigh == (prod / pow2 16) @% pow2 16);
- lemma_div_at_percent prod (pow2 16);
- assert (v vhigh == prod / pow2 16);
- let result = vhigh -. c in
- assert (is_i16b 1663 vhigh);
- lemma_sub_i16b 1663 1665 vhigh c;
- assert (is_i16b 3328 result);
- assert (v result = v vhigh - v c);
- calc ( == ) {
- v k_times_modulus % pow2 16;
- ( == ) { assert (v k_times_modulus == v k * 3329) }
- (v k * 3329) % pow2 16;
- ( == ) { assert (v k = ((prod @% pow2 16) * (-3327)) @% pow2 16) }
- ((((prod @% pow2 16) * (-3327)) @% pow2 16) * 3329) % pow2 16;
- ( == ) { Math.Lemmas.lemma_mod_mul_distr_l (((prod @% pow2 16) * (-3327)) @% pow2 16) 3329 (pow2 16) }
- (((((prod @% pow2 16) * (-3327)) @% pow2 16) % pow2 16) * 3329) % pow2 16;
- ( == ) { lemma_at_percent_mod ((prod @% pow2 16) * (-3327)) (pow2 16)}
- ((((prod @% pow2 16) * (-3327)) % pow2 16) * 3329) % pow2 16;
- ( == ) { Math.Lemmas.lemma_mod_mul_distr_l ((prod @% pow2 16) * (-3327)) 3329 (pow2 16) }
- (((prod @% pow2 16) * (-3327)) * 3329) % pow2 16;
- ( == ) { }
- ((prod @% pow2 16) * (-3327 * 3329)) % pow2 16;
- ( == ) { Math.Lemmas.lemma_mod_mul_distr_r (prod @% pow2 16) (-3327 * 3329) (pow2 16) }
- ((prod @% pow2 16) % pow2 16);
- ( == ) { lemma_at_percent_mod (prod) (pow2 16) }
- (prod) % pow2 16;
- };
- Math.Lemmas.modulo_add (pow2 16) (- (v k_times_modulus)) ((prod)) (v k_times_modulus);
- assert (((prod) - v k_times_modulus) % pow2 16 == 0);
- calc ( == ) {
- v result % 3329;
- ( == ) { }
- (((prod) / pow2 16) - ((v k * 3329) / pow2 16)) % 3329;
- ( == ) { Math.Lemmas.lemma_div_exact ((prod) - (v k * 3329)) (pow2 16) }
- ((prod - (v k * 3329)) / pow2 16) % 3329;
- ( == ) { assert ((pow2 16 * 169) % 3329 == 1) }
- (((prod - (v k * 3329)) / pow2 16) * ((pow2 16 * 169) % 3329)) % 3329;
- ( == ) { Math.Lemmas.lemma_mod_mul_distr_r (((prod) - (v k * 3329)) / pow2 16)
- (pow2 16 * 169)
- 3329 }
- ((((prod) - (v k * 3329)) / pow2 16) * pow2 16 * 169) % 3329;
- ( == ) { Math.Lemmas.lemma_div_exact ((prod) - (v k * 3329)) (pow2 16) }
- (((prod) - (v k * 3329)) * 169) % 3329;
- ( == ) { Math.Lemmas.lemma_mod_sub ((prod) * 169) 3329 (v k * 169)}
- ((prod) * 169) % 3329;
- }
-
-
-val lemma_mont_mul_red_i16 (x y:i16): Lemma
- (requires (is_i16b 1664 y \/ is_intb (3326 * pow2 15) (v x * v y)))
- (ensures (
- let result:i16 = mont_mul_red_i16 x y in
- is_i16b 3328 result /\
- v result % 3329 == (v x * v y * 169) % 3329))
- [SMTPat (mont_mul_red_i16 x y)]
-let lemma_mont_mul_red_i16 x y =
- if is_i16b 1664 y then (
- lemma_mul_intb (pow2 15) 1664 (v x) (v y);
- assert(is_intb (3326 * pow2 15) (v x * v y));
- lemma_mont_mul_red_i16_int x y)
- else lemma_mont_mul_red_i16_int x y
-
-let barrett_red (x:i16) =
- let t1 = cast (((cast x <: i32) *. (cast 20159s <: i32)) >>! 16l) <: i16 in
- let t2 = t1 +. 512s in
- let q = t2 >>! 10l in
- let qm = q *. 3329s in
- x -. qm
-
-let lemma_barrett_red (x:i16) : Lemma
- (requires (is_i16b 28296 x))
- (ensures (let result = barrett_red x in
- is_i16b 3328 result /\
- v result % 3329 == v x % 3329))
- [SMTPat (barrett_red x)]
- = admit()
-
-let cond_sub (x:i16) =
- let xm = x -. 3329s in
- let mask = xm >>! 15l in
- let mm = mask &. 3329s in
- xm +. mm
-
-let lemma_cond_sub x:
- Lemma (let r = cond_sub x in
- if x >=. 3329s then r == x -! 3329s else r == x)
- [SMTPat (cond_sub x)]
- = admit()
-
-
-let lemma_shift_right_15_i16 (x:i16):
- Lemma (if v x >= 0 then (x >>! 15l) == 0s else (x >>! 15l) == -1s) =
- Rust_primitives.Integers.mk_int_v_lemma #i16_inttype 0s;
- Rust_primitives.Integers.mk_int_v_lemma #i16_inttype (-1s);
- ()
-
-val ntt_spec #len (vec_in: t_Array i16 len) (zeta: int) (i: nat{i < v len}) (j: nat{j < v len})
- (vec_out: t_Array i16 len) : Type0
-let ntt_spec vec_in zeta i j vec_out =
- ((v (Seq.index vec_out i) % 3329) ==
- ((v (Seq.index vec_in i) + (v (Seq.index vec_in j) * zeta * 169)) % 3329)) /\
- ((v (Seq.index vec_out j) % 3329) ==
- ((v (Seq.index vec_in i) - (v (Seq.index vec_in j) * zeta * 169)) % 3329))
-
-val inv_ntt_spec #len (vec_in: t_Array i16 len) (zeta: int) (i: nat{i < v len}) (j: nat{j < v len})
- (vec_out: t_Array i16 len) : Type0
-let inv_ntt_spec vec_in zeta i j vec_out =
- ((v (Seq.index vec_out i) % 3329) ==
- ((v (Seq.index vec_in j) + v (Seq.index vec_in i)) % 3329)) /\
- ((v (Seq.index vec_out j) % 3329) ==
- (((v (Seq.index vec_in j) - v (Seq.index vec_in i)) * zeta * 169) % 3329))
-