Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add spec for Ind_cpa unpacked functions #612

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@ val v_PRF (v_LEN: usize) (input: t_Slice u8)
result == Spec.Utils.v_PRF v_LEN input)

val v_PRFxN (v_K v_LEN: usize) (input: t_Array (t_Array u8 (sz 33)) v_K)
: Prims.Pure (t_Array (t_Array u8 v_LEN) v_K) Prims.l_True (fun _ -> Prims.l_True)
: Prims.Pure (t_Array (t_Array u8 v_LEN) v_K)
(requires v v_LEN < pow2 32 /\ (v v_K == 2 \/ v v_K == 3 \/ v v_K == 4))
(ensures
fun result ->
let result:t_Array (t_Array u8 v_LEN) v_K = result in
result == Spec.Utils.v_PRFxN v_K v_LEN input)

/// The state.
/// It\'s only used for SHAKE128.
Expand Down Expand Up @@ -63,15 +68,19 @@ let impl (v_K: usize) : Libcrux_ml_kem.Hash_functions.t_Hash t_Simd256Hash v_K =
(fun (v_LEN: usize) (input: t_Slice u8) (out: t_Array u8 v_LEN) ->
v v_LEN < pow2 32 ==> out == Spec.Utils.v_PRF v_LEN input);
f_PRF = (fun (v_LEN: usize) (input: t_Slice u8) -> v_PRF v_LEN input);
f_PRFxN_pre = (fun (v_LEN: usize) (input: t_Array (t_Array u8 (sz 33)) v_K) -> true);
f_PRFxN_pre
=
(fun (v_LEN: usize) (input: t_Array (t_Array u8 (sz 33)) v_K) ->
v v_LEN < pow2 32 /\ (v v_K == 2 \/ v v_K == 3 \/ v v_K == 4));
f_PRFxN_post
=
(fun
(v_LEN: usize)
(input: t_Array (t_Array u8 (sz 33)) v_K)
(out: t_Array (t_Array u8 v_LEN) v_K)
->
true);
(v v_LEN < pow2 32 /\ (v v_K == 2 \/ v v_K == 3 \/ v v_K == 4)) ==>
out == Spec.Utils.v_PRFxN v_K v_LEN input);
f_PRFxN
=
(fun (v_LEN: usize) (input: t_Array (t_Array u8 (sz 33)) v_K) -> v_PRFxN v_K v_LEN input);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@ val v_PRF (v_LEN: usize) (input: t_Slice u8)
result == Spec.Utils.v_PRF v_LEN input)

val v_PRFxN (v_K v_LEN: usize) (input: t_Array (t_Array u8 (sz 33)) v_K)
: Prims.Pure (t_Array (t_Array u8 v_LEN) v_K) Prims.l_True (fun _ -> Prims.l_True)
: Prims.Pure (t_Array (t_Array u8 v_LEN) v_K)
(requires v v_LEN < pow2 32 /\ (v v_K == 2 \/ v v_K == 3 \/ v v_K == 4))
(ensures
fun result ->
let result:t_Array (t_Array u8 v_LEN) v_K = result in
result == Spec.Utils.v_PRFxN v_K v_LEN input)

/// The state.
/// It\'s only used for SHAKE128.
Expand Down Expand Up @@ -63,15 +68,19 @@ let impl (v_K: usize) : Libcrux_ml_kem.Hash_functions.t_Hash t_Simd128Hash v_K =
(fun (v_LEN: usize) (input: t_Slice u8) (out: t_Array u8 v_LEN) ->
v v_LEN < pow2 32 ==> out == Spec.Utils.v_PRF v_LEN input);
f_PRF = (fun (v_LEN: usize) (input: t_Slice u8) -> v_PRF v_LEN input);
f_PRFxN_pre = (fun (v_LEN: usize) (input: t_Array (t_Array u8 (sz 33)) v_K) -> true);
f_PRFxN_pre
=
(fun (v_LEN: usize) (input: t_Array (t_Array u8 (sz 33)) v_K) ->
v v_LEN < pow2 32 /\ (v v_K == 2 \/ v v_K == 3 \/ v v_K == 4));
f_PRFxN_post
=
(fun
(v_LEN: usize)
(input: t_Array (t_Array u8 (sz 33)) v_K)
(out: t_Array (t_Array u8 v_LEN) v_K)
->
true);
(v v_LEN < pow2 32 /\ (v v_K == 2 \/ v v_K == 3 \/ v v_K == 4)) ==>
out == Spec.Utils.v_PRFxN v_K v_LEN input);
f_PRFxN
=
(fun (v_LEN: usize) (input: t_Array (t_Array u8 (sz 33)) v_K) -> v_PRFxN v_K v_LEN input);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@ val v_PRF (v_LEN: usize) (input: t_Slice u8)
result == Spec.Utils.v_PRF v_LEN input)

val v_PRFxN (v_K v_LEN: usize) (input: t_Array (t_Array u8 (sz 33)) v_K)
: Prims.Pure (t_Array (t_Array u8 v_LEN) v_K) Prims.l_True (fun _ -> Prims.l_True)
: Prims.Pure (t_Array (t_Array u8 v_LEN) v_K)
(requires v v_LEN < pow2 32 /\ (v v_K == 2 \/ v v_K == 3 \/ v v_K == 4))
(ensures
fun result ->
let result:t_Array (t_Array u8 v_LEN) v_K = result in
result == Spec.Utils.v_PRFxN v_K v_LEN input)

/// The state.
/// It\'s only used for SHAKE128.
Expand Down Expand Up @@ -63,15 +68,19 @@ let impl (v_K: usize) : Libcrux_ml_kem.Hash_functions.t_Hash (t_PortableHash v_K
(fun (v_LEN: usize) (input: t_Slice u8) (out: t_Array u8 v_LEN) ->
v v_LEN < pow2 32 ==> out == Spec.Utils.v_PRF v_LEN input);
f_PRF = (fun (v_LEN: usize) (input: t_Slice u8) -> v_PRF v_LEN input);
f_PRFxN_pre = (fun (v_LEN: usize) (input: t_Array (t_Array u8 (sz 33)) v_K) -> true);
f_PRFxN_pre
=
(fun (v_LEN: usize) (input: t_Array (t_Array u8 (sz 33)) v_K) ->
v v_LEN < pow2 32 /\ (v v_K == 2 \/ v v_K == 3 \/ v v_K == 4));
f_PRFxN_post
=
(fun
(v_LEN: usize)
(input: t_Array (t_Array u8 (sz 33)) v_K)
(out: t_Array (t_Array u8 v_LEN) v_K)
->
true);
(v v_LEN < pow2 32 /\ (v v_K == 2 \/ v v_K == 3 \/ v v_K == 4)) ==>
out == Spec.Utils.v_PRFxN v_K v_LEN input);
f_PRFxN
=
(fun (v_LEN: usize) (input: t_Array (t_Array u8 (sz 33)) v_K) -> v_PRFxN v_K v_LEN input);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,17 @@ class t_Hash (v_Self: Type0) (v_K: usize) = {
-> pred: Type0{pred ==> v v_LEN < pow2 32 ==> result == Spec.Utils.v_PRF v_LEN input};
f_PRF:v_LEN: usize -> x0: t_Slice u8
-> Prims.Pure (t_Array u8 v_LEN) (f_PRF_pre v_LEN x0) (fun result -> f_PRF_post v_LEN x0 result);
f_PRFxN_pre:v_LEN: usize -> input: t_Array (t_Array u8 (sz 33)) v_K -> pred: Type0{true ==> pred};
f_PRFxN_post:v_LEN: usize -> t_Array (t_Array u8 (sz 33)) v_K -> t_Array (t_Array u8 v_LEN) v_K
-> Type0;
f_PRFxN_pre:v_LEN: usize -> input: t_Array (t_Array u8 (sz 33)) v_K
-> pred: Type0{v v_LEN < pow2 32 /\ (v v_K == 2 \/ v v_K == 3 \/ v v_K == 4) ==> pred};
f_PRFxN_post:
v_LEN: usize ->
input: t_Array (t_Array u8 (sz 33)) v_K ->
result: t_Array (t_Array u8 v_LEN) v_K
-> pred:
Type0
{ pred ==>
(v v_LEN < pow2 32 /\ (v v_K == 2 \/ v v_K == 3 \/ v v_K == 4)) ==>
result == Spec.Utils.v_PRFxN v_K v_LEN input };
f_PRFxN:v_LEN: usize -> x0: t_Array (t_Array u8 (sz 33)) v_K
-> Prims.Pure (t_Array (t_Array u8 v_LEN) v_K)
(f_PRFxN_pre v_LEN x0)
Expand Down
103 changes: 76 additions & 27 deletions libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Ind_cpa.fst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ let _ =
let open Libcrux_ml_kem.Vector.Traits in
()

#push-options "--max_fuel 10 --z3rlimit 1000 --ext context_pruning --z3refresh --split_queries always"

let sample_ring_element_cbd
(v_K v_ETA2_RANDOMNESS_SIZE v_ETA2: usize)
(#v_Vector #v_Hasher: Type0)
Expand All @@ -35,13 +37,22 @@ let sample_ring_element_cbd
in
let prf_inputs:t_Array (t_Array u8 (sz 33)) v_K = Rust_primitives.Hax.repeat prf_input v_K in
let v__domain_separator_init:u8 = domain_separator in
let v__prf_inputs_init:t_Array (t_Array u8 (sz 33)) v_K = prf_inputs in
let domain_separator, prf_inputs:(u8 & t_Array (t_Array u8 (sz 33)) v_K) =
Rust_primitives.Hax.Folds.fold_range (sz 0)
v_K
(fun temp_0_ i ->
let domain_separator, prf_inputs:(u8 & t_Array (t_Array u8 (sz 33)) v_K) = temp_0_ in
let i:usize = i in
v domain_separator == v v__domain_separator_init + v i)
v domain_separator == v v__domain_separator_init + v i /\
(v i < v v_K ==>
(forall (j: nat).
(j >= v i /\ j < v v_K) ==> prf_inputs.[ sz j ] == v__prf_inputs_init.[ sz j ])) /\
(forall (j: nat).
j < v i ==>
v (Seq.index (Seq.index prf_inputs j) 32) == v v__domain_separator_init + j /\
Seq.slice (Seq.index prf_inputs j) 0 32 ==
Seq.slice (Seq.index v__prf_inputs_init j) 0 32))
(domain_separator, prf_inputs <: (u8 & t_Array (t_Array u8 (sz 33)) v_K))
(fun temp_0_ i ->
let domain_separator, prf_inputs:(u8 & t_Array (t_Array u8 (sz 33)) v_K) = temp_0_ in
Expand All @@ -60,6 +71,28 @@ let sample_ring_element_cbd
let domain_separator:u8 = domain_separator +! 1uy in
domain_separator, prf_inputs <: (u8 & t_Array (t_Array u8 (sz 33)) v_K))
in
let _:Prims.unit =
let lemma_aux (i: nat{i < v v_K})
: Lemma
(prf_inputs.[ sz i ] ==
(Seq.append (Seq.slice prf_input 0 32)
(Seq.create 1
(mk_int #u8_inttype (v (v__domain_separator_init +! (mk_int #u8_inttype i))))))) =
Lib.Sequence.eq_intro #u8
#33
prf_inputs.[ sz i ]
(Seq.append (Seq.slice prf_input 0 32)
(Seq.create 1 (mk_int #u8_inttype (v v__domain_separator_init + i))))
in
Classical.forall_intro lemma_aux;
Lib.Sequence.eq_intro #(t_Array u8 (sz 33))
#(v v_K)
prf_inputs
(createi v_K
(Spec.MLKEM.sample_vector_cbd2_prf_input #v_K
(Seq.slice prf_input 0 32)
(sz (v v__domain_separator_init))))
in
let (prf_outputs: t_Array (t_Array u8 v_ETA2_RANDOMNESS_SIZE) v_K):t_Array
(t_Array u8 v_ETA2_RANDOMNESS_SIZE) v_K =
Libcrux_ml_kem.Hash_functions.f_PRFxN #v_Hasher
Expand All @@ -71,37 +104,45 @@ let sample_ring_element_cbd
let error_1_:t_Array (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector) v_K =
Rust_primitives.Hax.Folds.fold_range (sz 0)
v_K
(fun error_1_ temp_1_ ->
(fun error_1_ i ->
let error_1_:t_Array (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector) v_K =
error_1_
in
let _:usize = temp_1_ in
true)
let i:usize = i in
forall (j: nat).
j < v i ==>
Libcrux_ml_kem.Polynomial.to_spec_poly_t #v_Vector error_1_.[ sz j ] ==
Spec.MLKEM.sample_poly_cbd v_ETA2 prf_outputs.[ sz j ])
error_1_
(fun error_1_ i ->
let error_1_:t_Array (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector) v_K =
error_1_
in
let i:usize = i in
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize error_1_
i
(Libcrux_ml_kem.Sampling.sample_from_binomial_distribution v_ETA2
#v_Vector
(prf_outputs.[ i ] <: t_Slice u8)
<:
Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector)
<:
t_Array (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector) v_K)
in
let result:(t_Array (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector) v_K & u8) =
error_1_, domain_separator
<:
(t_Array (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector) v_K & u8)
let error_1_:t_Array (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector) v_K =
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize error_1_
i
(Libcrux_ml_kem.Sampling.sample_from_binomial_distribution v_ETA2
#v_Vector
(prf_outputs.[ i ] <: t_Slice u8)
<:
Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector)
in
error_1_)
in
let _:Prims.unit = admit () (* Panic freedom *) in
result
let _:Prims.unit =
Lib.Sequence.eq_intro #(Spec.MLKEM.polynomial)
#(v v_K)
(Libcrux_ml_kem.Polynomial.to_spec_vector_t #v_K #v_Vector error_1_)
(Spec.MLKEM.sample_vector_cbd2 #v_K
(Seq.slice prf_input 0 32)
(sz (v v__domain_separator_init)))
in
error_1_, domain_separator
<:
(t_Array (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector) v_K & u8)

#push-options "--admit_smt_queries true"
#pop-options

let sample_vector_cbd_then_ntt
(v_K v_ETA v_ETA_RANDOMNESS_SIZE: usize)
Expand All @@ -124,7 +165,9 @@ let sample_vector_cbd_then_ntt
(fun temp_0_ i ->
let domain_separator, prf_inputs:(u8 & t_Array (t_Array u8 (sz 33)) v_K) = temp_0_ in
let i:usize = i in
v domain_separator == v v__domain_separator_init + v i)
v domain_separator == v v__domain_separator_init + v i /\
(forall (j: nat). j < v i ==> v (Seq.index prf_input j) == v v__domain_separator_init + j)
)
(domain_separator, prf_inputs <: (u8 & t_Array (t_Array u8 (sz 33)) v_K))
(fun temp_0_ i ->
let domain_separator, prf_inputs:(u8 & t_Array (t_Array u8 (sz 33)) v_K) = temp_0_ in
Expand Down Expand Up @@ -185,13 +228,13 @@ let sample_vector_cbd_then_ntt
in
re_as_ntt)
in
let hax_temp_output:u8 = domain_separator in
let result:u8 = domain_separator in
let _:Prims.unit = admit () (* Panic freedom *) in
let hax_temp_output:u8 = result in
re_as_ntt, hax_temp_output
<:
(t_Array (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector) v_K & u8)

#pop-options

let sample_vector_cbd_then_ntt_out
(v_K v_ETA v_ETA_RANDOMNESS_SIZE: usize)
(#v_Vector #v_Hasher: Type0)
Expand Down Expand Up @@ -577,7 +620,11 @@ let decrypt_unpacked
secret_key.Libcrux_ml_kem.Ind_cpa.Unpacked.f_secret_as_ntt
u_as_ntt
in
Libcrux_ml_kem.Serialize.compress_then_serialize_message #v_Vector message
let result:t_Array u8 (sz 32) =
Libcrux_ml_kem.Serialize.compress_then_serialize_message #v_Vector message
in
let _:Prims.unit = admit () (* Panic freedom *) in
result

let decrypt
(v_K v_CIPHERTEXT_SIZE v_VECTOR_U_ENCODED_SIZE v_U_COMPRESSION_FACTOR v_V_COMPRESSION_FACTOR:
Expand Down Expand Up @@ -717,7 +764,9 @@ let encrypt_unpacked
<:
t_Slice u8)
in
ciphertext
let result:t_Array u8 v_CIPHERTEXT_SIZE = ciphertext in
let _:Prims.unit = admit () (* Panic freedom *) in
result

#pop-options

Expand Down
39 changes: 35 additions & 4 deletions libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Ind_cpa.fsti
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,17 @@ val sample_ring_element_cbd
: Prims.Pure (t_Array (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector) v_K & u8)
(requires
Spec.MLKEM.is_rank v_K /\ v_ETA2_RANDOMNESS_SIZE == Spec.MLKEM.v_ETA2_RANDOMNESS_SIZE v_K /\
v_ETA2 == Spec.MLKEM.v_ETA2 v_K /\ range (v domain_separator + v v_K) u8_inttype)
(fun _ -> Prims.l_True)
v_ETA2 == Spec.MLKEM.v_ETA2 v_K /\ v domain_separator < 2 * v v_K /\
range (v domain_separator + v v_K) u8_inttype)
(ensures
fun temp_0_ ->
let err1, ds:(t_Array (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector) v_K &
u8) =
temp_0_
in
v ds == v domain_separator + v v_K /\
Libcrux_ml_kem.Polynomial.to_spec_vector_t #v_K #v_Vector err1 ==
Spec.MLKEM.sample_vector_cbd2 #v_K (Seq.slice prf_input 0 32) (sz (v domain_separator)))

/// Sample a vector of ring elements from a centered binomial distribution and
/// convert them into their NTT representations.
Expand Down Expand Up @@ -233,7 +242,13 @@ val decrypt_unpacked
v_U_COMPRESSION_FACTOR == Spec.MLKEM.v_VECTOR_U_COMPRESSION_FACTOR v_K /\
v_V_COMPRESSION_FACTOR == Spec.MLKEM.v_VECTOR_V_COMPRESSION_FACTOR v_K /\
v_VECTOR_U_ENCODED_SIZE == Spec.MLKEM.v_C1_SIZE v_K)
(fun _ -> Prims.l_True)
(ensures
fun result ->
let result:t_Array u8 (sz 32) = result in
result ==
Spec.MLKEM.ind_cpa_decrypt_unpacked v_K
ciphertext
(Libcrux_ml_kem.Polynomial.to_spec_vector_t #v_K #v_Vector secret_key.f_secret_as_ntt))

val decrypt
(v_K v_CIPHERTEXT_SIZE v_VECTOR_U_ENCODED_SIZE v_U_COMPRESSION_FACTOR v_V_COMPRESSION_FACTOR:
Expand Down Expand Up @@ -310,7 +325,15 @@ val encrypt_unpacked
v_BLOCK_LEN == Spec.MLKEM.v_C1_BLOCK_SIZE v_K /\
v_CIPHERTEXT_SIZE == Spec.MLKEM.v_CPA_CIPHERTEXT_SIZE v_K /\
length randomness == Spec.MLKEM.v_SHARED_SECRET_SIZE)
(fun _ -> Prims.l_True)
(ensures
fun result ->
let result:t_Array u8 v_CIPHERTEXT_SIZE = result in
result ==
Spec.MLKEM.ind_cpa_encrypt_unpacked v_K
message
randomness
(Libcrux_ml_kem.Polynomial.to_spec_vector_t #v_K #v_Vector public_key.f_t_as_ntt)
(Libcrux_ml_kem.Polynomial.to_spec_matrix_t #v_K #v_Vector public_key.f_A))

val encrypt
(v_K v_CIPHERTEXT_SIZE v_T_AS_NTT_ENCODED_SIZE v_C1_LEN v_C2_LEN v_U_COMPRESSION_FACTOR v_V_COMPRESSION_FACTOR v_BLOCK_LEN v_ETA1 v_ETA1_RANDOMNESS_SIZE v_ETA2 v_ETA2_RANDOMNESS_SIZE:
Expand Down Expand Up @@ -396,6 +419,14 @@ val generate_keypair_unpacked
Libcrux_ml_kem.Ind_cpa.Unpacked.t_IndCpaPublicKeyUnpacked v_K v_Vector) =
temp_0_
in
let ((t_as_ntt, seed_for_A), secret_as_ntt), valid =
Spec.MLKEM.ind_cpa_generate_keypair_unpacked v_K key_generation_seed
in
(valid ==>
((Libcrux_ml_kem.Polynomial.to_spec_vector_t #v_K #v_Vector public_key.f_t_as_ntt) ==
t_as_ntt) /\ (public_key.f_seed_for_A == seed_for_A) /\
((Libcrux_ml_kem.Polynomial.to_spec_vector_t #v_K #v_Vector private_key.f_secret_as_ntt) ==
secret_as_ntt)) /\
(forall (i: nat).
i < v v_K ==>
Libcrux_ml_kem.Serialize.coefficients_field_modulus_range (Seq.index private_key_future
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ val inv_ntt_layer_int_vec_step_reduce
: Prims.Pure (v_Vector & v_Vector)
(requires
Spec.Utils.is_i16b 1664 zeta_r /\
(forall i.
(forall (i: nat).
i < 16 ==>
Spec.Utils.is_intb (pow2 15 - 1)
(v (Seq.index (Libcrux_ml_kem.Vector.Traits.f_to_i16_array b) i) -
v (Seq.index (Libcrux_ml_kem.Vector.Traits.f_to_i16_array a) i))) /\
(forall i.
(forall (i: nat).
i < 16 ==>
Spec.Utils.is_intb (pow2 15 - 1)
(v (Seq.index (Libcrux_ml_kem.Vector.Traits.f_to_i16_array a) i) +
Expand Down
Loading
Loading